This commit is contained in:
Zchen
2025-10-12 21:56:34 +08:00
parent 4dad570eea
commit 0cbb83e052
2 changed files with 67 additions and 24 deletions

View File

@@ -33,6 +33,7 @@ self.accelerator = Accelerator(
- Reduced TPU cores from 8 to 2
- Reduced batch size
- Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)
我很不希望这么做,至少减少核心会减少算力!
### Attempt 4: Removing all TPU-specific logic
- Let Accelerator handle everything automatically
@@ -107,18 +108,17 @@ TypeError: 'NoneType' object is not iterable
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].
```
**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but our data is being loaded as `f32` (float32).
**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but tensors were being created as `f32` (float32) at multiple levels.
**Analysis**:
- We enabled `bf16` mixed precision in Accelerator configuration
- Model parameters are automatically converted to `bf16`
- But input data remains as `f32`, causing type mismatch during forward pass
- TPU XLA compiler is strict about type matching
- Input data was loaded as `f32` and needed conversion
- More critically: Model parameters were initialized as `f32` by default
- TPU XLA compiler is strict about type matching across all tensors
### Solution: Comprehensive Data Type Conversion in Dataset
Fixed in `dataset.py` with two changes:
### Solution: Comprehensive Data Type Conversion at All Levels
**1. Convert input data to bf16 (line 130):**
**1. Convert input data to bf16 in dataset.py (line 130):**
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
@@ -127,7 +127,7 @@ input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility
```
**2. Preserve bf16 dtype after padding (line 149):**
**2. Preserve bf16 dtype after padding in dataset.py (line 149):**
```python
# Before (pad_sequence converts back to f32):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
@@ -136,7 +136,23 @@ batch['input_features'] = pad_sequence(batch['input_features'], batch_first = Tr
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
```
**Root Cause**: `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16.
**3. Fix model parameter initialization in rnn_model.py:**
```python
# Before (defaults to f32):
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
# After (explicit bf16):
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
```
**Root Causes Identified**:
- `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16
- `torch.eye()` and `torch.zeros()` default to f32 unless explicit dtype is specified
- All tensor creation points must explicitly specify `dtype=torch.bfloat16` for mixed precision consistency
### Final Implementation
```python
@@ -167,26 +183,53 @@ self.train_loader = DataLoader(
## Complete Solution Summary
### Three-Step Fix for TPU Training
### Four-Step Fix for TPU Training
1. **DataLoaderConfiguration**: Added `even_batches=False` for batch_size=1 DataLoaders
2. **Custom collate_fn**: Handles pre-batched data from our dataset
3. **Data Type Conversion**: Convert input data to `bf16` for mixed precision compatibility
3. **Data Type Conversion (Dataset)**: Convert input data to `bf16` for mixed precision compatibility
4. **Data Type Conversion (Model)**: Fix all model parameter initialization to use explicit `bf16` dtype
### Files Modified
### Files Modified - COMPREHENSIVE SOLUTION ✅
- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
- **[rnn_model.py:28-29](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L28-L29)**: Fixed NoiseModel day weights/biases dtype
- **[rnn_model.py:55](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L55)**: Fixed NoiseModel h0 dtype
- **[rnn_model.py:113-114](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L113-L114)**: Fixed CleanSpeechModel day weights/biases dtype
- **[rnn_model.py:144](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L144)**: Fixed CleanSpeechModel h0 dtype
- **[rnn_model.py:232](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L232)**: Fixed NoisySpeechModel h0 dtype
### Next Steps
1. ~~Implement even_batches=False~~ ✅ DONE
2. ~~Fix batch_sampler None issue~~ ✅ DONE
3. ~~Fix data type mismatch~~ ✅ DONE
4. Test TPU training with complete solution
5. Integrate final solution into CLAUDE.md
3. ~~Fix data type mismatch (dataset level)~~ ✅ DONE
4. ~~Fix data type mismatch (model parameter level)~~ ✅ DONE
5. **READY**: Test TPU training with comprehensive dtype solution
6. Update CLAUDE.md with final TPU training guidance
## Final Status Update (2025-10-12 14:30)
🎯 **COMPREHENSIVE SOLUTION COMPLETED**
All TPU training issues have been systematically identified and fixed:
**Problem 1**: `even_batches` error → Fixed with DataLoaderConfiguration
**Problem 2**: `batch_sampler=None` error → Fixed with custom collate_fn + batch_size=1
**Problem 3**: Data type mismatch (dataset) → Fixed bf16 conversion + padding preservation
**Problem 4**: Data type mismatch (model) → Fixed all parameter initialization with explicit bf16 dtype
**The solution addresses dtype consistency at ALL levels**:
- Input data loading: `.to(torch.bfloat16)`
- Padding operations: explicit bf16 preservation
- Model parameters: `torch.eye(..., dtype=torch.bfloat16)` and `torch.zeros(..., dtype=torch.bfloat16)`
**Ready for TPU training test** with 687M parameter brain-to-text model.
## Lessons Learned
- Don't overcomplicate TPU conversion - it should be straightforward
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype
- **Documentation**: Record issues immediately to avoid repeated debugging cycles
- Don't overcomplicate TPU conversion - identify systematic dtype issues
- Read Accelerate documentation carefully for parameter placement
- Document issues immediately to avoid confusion
- TPU memory allocation: fewer cores = less total memory

View File

@@ -25,8 +25,8 @@ class NoiseModel(nn.Module):
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
@@ -52,7 +52,7 @@ class NoiseModel(nn.Module):
nn.init.xavier_uniform_(param)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16)))
def forward(self, x, day_idx, states=None):
# Apply day-specific transformation
@@ -110,8 +110,8 @@ class CleanSpeechModel(nn.Module):
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
@@ -141,7 +141,7 @@ class CleanSpeechModel(nn.Module):
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
def forward(self, x, day_idx, states=None, return_state=False):
# Apply day-specific transformation
@@ -229,7 +229,7 @@ class NoisySpeechModel(nn.Module):
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
def forward(self, x, states=None, return_state=False):
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise