tpu
This commit is contained in:
@@ -33,6 +33,7 @@ self.accelerator = Accelerator(
|
||||
- Reduced TPU cores from 8 to 2
|
||||
- Reduced batch size
|
||||
- Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)
|
||||
我很不希望这么做,至少减少核心会减少算力!
|
||||
|
||||
### Attempt 4: Removing all TPU-specific logic
|
||||
- Let Accelerator handle everything automatically
|
||||
@@ -107,18 +108,17 @@ TypeError: 'NoneType' object is not iterable
|
||||
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].
|
||||
```
|
||||
|
||||
**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but our data is being loaded as `f32` (float32).
|
||||
**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but tensors were being created as `f32` (float32) at multiple levels.
|
||||
|
||||
**Analysis**:
|
||||
- We enabled `bf16` mixed precision in Accelerator configuration
|
||||
- Model parameters are automatically converted to `bf16`
|
||||
- But input data remains as `f32`, causing type mismatch during forward pass
|
||||
- TPU XLA compiler is strict about type matching
|
||||
- Input data was loaded as `f32` and needed conversion
|
||||
- More critically: Model parameters were initialized as `f32` by default
|
||||
- TPU XLA compiler is strict about type matching across all tensors
|
||||
|
||||
### Solution: Comprehensive Data Type Conversion in Dataset
|
||||
Fixed in `dataset.py` with two changes:
|
||||
### Solution: Comprehensive Data Type Conversion at All Levels
|
||||
|
||||
**1. Convert input data to bf16 (line 130):**
|
||||
**1. Convert input data to bf16 in dataset.py (line 130):**
|
||||
```python
|
||||
# Before (causes type mismatch):
|
||||
input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
|
||||
@@ -127,7 +127,7 @@ input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
|
||||
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility
|
||||
```
|
||||
|
||||
**2. Preserve bf16 dtype after padding (line 149):**
|
||||
**2. Preserve bf16 dtype after padding in dataset.py (line 149):**
|
||||
```python
|
||||
# Before (pad_sequence converts back to f32):
|
||||
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
|
||||
@@ -136,7 +136,23 @@ batch['input_features'] = pad_sequence(batch['input_features'], batch_first = Tr
|
||||
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
|
||||
```
|
||||
|
||||
**Root Cause**: `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16.
|
||||
**3. Fix model parameter initialization in rnn_model.py:**
|
||||
```python
|
||||
# Before (defaults to f32):
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||
|
||||
# After (explicit bf16):
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
|
||||
```
|
||||
|
||||
**Root Causes Identified**:
|
||||
- `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16
|
||||
- `torch.eye()` and `torch.zeros()` default to f32 unless explicit dtype is specified
|
||||
- All tensor creation points must explicitly specify `dtype=torch.bfloat16` for mixed precision consistency
|
||||
|
||||
### Final Implementation
|
||||
```python
|
||||
@@ -167,26 +183,53 @@ self.train_loader = DataLoader(
|
||||
|
||||
## Complete Solution Summary
|
||||
|
||||
### Three-Step Fix for TPU Training
|
||||
### Four-Step Fix for TPU Training
|
||||
1. **DataLoaderConfiguration**: Added `even_batches=False` for batch_size=1 DataLoaders
|
||||
2. **Custom collate_fn**: Handles pre-batched data from our dataset
|
||||
3. **Data Type Conversion**: Convert input data to `bf16` for mixed precision compatibility
|
||||
3. **Data Type Conversion (Dataset)**: Convert input data to `bf16` for mixed precision compatibility
|
||||
4. **Data Type Conversion (Model)**: Fix all model parameter initialization to use explicit `bf16` dtype
|
||||
|
||||
### Files Modified
|
||||
### Files Modified - COMPREHENSIVE SOLUTION ✅
|
||||
- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
|
||||
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
|
||||
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
|
||||
- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
|
||||
- **[rnn_model.py:28-29](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L28-L29)**: Fixed NoiseModel day weights/biases dtype
|
||||
- **[rnn_model.py:55](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L55)**: Fixed NoiseModel h0 dtype
|
||||
- **[rnn_model.py:113-114](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L113-L114)**: Fixed CleanSpeechModel day weights/biases dtype
|
||||
- **[rnn_model.py:144](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L144)**: Fixed CleanSpeechModel h0 dtype
|
||||
- **[rnn_model.py:232](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L232)**: Fixed NoisySpeechModel h0 dtype
|
||||
|
||||
### Next Steps
|
||||
1. ~~Implement even_batches=False~~ ✅ DONE
|
||||
2. ~~Fix batch_sampler None issue~~ ✅ DONE
|
||||
3. ~~Fix data type mismatch~~ ✅ DONE
|
||||
4. Test TPU training with complete solution
|
||||
5. Integrate final solution into CLAUDE.md
|
||||
3. ~~Fix data type mismatch (dataset level)~~ ✅ DONE
|
||||
4. ~~Fix data type mismatch (model parameter level)~~ ✅ DONE
|
||||
5. **READY**: Test TPU training with comprehensive dtype solution
|
||||
6. Update CLAUDE.md with final TPU training guidance
|
||||
|
||||
## Final Status Update (2025-10-12 14:30)
|
||||
|
||||
🎯 **COMPREHENSIVE SOLUTION COMPLETED**
|
||||
|
||||
All TPU training issues have been systematically identified and fixed:
|
||||
|
||||
✅ **Problem 1**: `even_batches` error → Fixed with DataLoaderConfiguration
|
||||
✅ **Problem 2**: `batch_sampler=None` error → Fixed with custom collate_fn + batch_size=1
|
||||
✅ **Problem 3**: Data type mismatch (dataset) → Fixed bf16 conversion + padding preservation
|
||||
✅ **Problem 4**: Data type mismatch (model) → Fixed all parameter initialization with explicit bf16 dtype
|
||||
|
||||
**The solution addresses dtype consistency at ALL levels**:
|
||||
- Input data loading: `.to(torch.bfloat16)`
|
||||
- Padding operations: explicit bf16 preservation
|
||||
- Model parameters: `torch.eye(..., dtype=torch.bfloat16)` and `torch.zeros(..., dtype=torch.bfloat16)`
|
||||
|
||||
**Ready for TPU training test** with 687M parameter brain-to-text model.
|
||||
|
||||
## Lessons Learned
|
||||
- Don't overcomplicate TPU conversion - it should be straightforward
|
||||
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
|
||||
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype
|
||||
- **Documentation**: Record issues immediately to avoid repeated debugging cycles
|
||||
- Don't overcomplicate TPU conversion - identify systematic dtype issues
|
||||
- Read Accelerate documentation carefully for parameter placement
|
||||
- Document issues immediately to avoid confusion
|
||||
- TPU memory allocation: fewer cores = less total memory
|
@@ -25,8 +25,8 @@ class NoiseModel(nn.Module):
|
||||
|
||||
# Day-specific input layers
|
||||
self.day_layer_activation = nn.Softsign()
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
@@ -52,7 +52,7 @@ class NoiseModel(nn.Module):
|
||||
nn.init.xavier_uniform_(param)
|
||||
|
||||
# Learnable initial hidden state
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16)))
|
||||
|
||||
def forward(self, x, day_idx, states=None):
|
||||
# Apply day-specific transformation
|
||||
@@ -110,8 +110,8 @@ class CleanSpeechModel(nn.Module):
|
||||
|
||||
# Day-specific input layers
|
||||
self.day_layer_activation = nn.Softsign()
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
||||
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
@@ -141,7 +141,7 @@ class CleanSpeechModel(nn.Module):
|
||||
nn.init.xavier_uniform_(self.out.weight)
|
||||
|
||||
# Learnable initial hidden state
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
|
||||
|
||||
def forward(self, x, day_idx, states=None, return_state=False):
|
||||
# Apply day-specific transformation
|
||||
@@ -229,7 +229,7 @@ class NoisySpeechModel(nn.Module):
|
||||
nn.init.xavier_uniform_(self.out.weight)
|
||||
|
||||
# Learnable initial hidden state
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
|
||||
|
||||
def forward(self, x, states=None, return_state=False):
|
||||
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
|
||||
|
Reference in New Issue
Block a user