tpu
@@ -95,12 +95,12 @@ TypeError: 'NoneType' object is not iterable
- But Accelerate expects a proper batch_sampler when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations

-## COMPREHENSIVE SOLUTION ✅
+## COMPREHENSIVE SOLUTION ✅ (v2.0)

### Problem Resolution Status

1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration (see the sketch after this list)
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
-3. ~~Data Type Mismatch Error~~ ✅ RESOLVED with bf16 conversion in dataset
+3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - fixed both input conversion and padding dtype preservation
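
For reference, a minimal sketch of how fix 1 is wired up (assuming a recent accelerate release where `DataLoaderConfiguration` is importable from `accelerate.utils`; this is an illustration, not the exact code in rnn_trainer.py):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Fix 1: stop Accelerate from enforcing evenly sized batches across
# processes, which our pre-batched dataset cannot guarantee.
dl_config = DataLoaderConfiguration(even_batches=False)
accelerator = Accelerator(dataloader_config=dl_config)
```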

### Latest Error (2025-10-12 13:38)

@@ -115,8 +115,10 @@ INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32
- But input data remains as `f32`, causing a type mismatch during the forward pass
- The TPU XLA compiler is strict about type matching (see the illustration below)
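
The same dtype contract can be reproduced off-TPU in plain PyTorch, which makes the failure easy to see in isolation; a minimal illustration (layer shapes are arbitrary):

```python
import torch

# A bf16 model cannot consume f32 inputs without an explicit cast.
model = torch.nn.Linear(512, 41).to(torch.bfloat16)
x = torch.randn(2, 512)  # torch.randn defaults to f32

try:
    model(x)  # f32 input vs bf16 weights: dtype mismatch
except RuntimeError as err:
    print('f32 input rejected:', err)

print(model(x.to(torch.bfloat16)).dtype)  # torch.bfloat16
```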

-### Solution: Data Type Conversion in Dataset
-Fixed in `dataset.py:130` by converting neural data to `bf16`:
+### Solution: Comprehensive Data Type Conversion in Dataset
+Fixed in `dataset.py` with two changes:
+
+**1. Convert input data to bf16 (line 130):**
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:])  # defaults to f32

# After:
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16)  # convert to bf16 for TPU compatibility
```

**2. Preserve bf16 dtype after padding (line 149):**
```python
# Before (pad_sequence converts back to f32):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)

# After (explicitly maintain bf16):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
```

**Root Cause**: the `pad_sequence` call was resetting the batch dtype to the default (f32) even though the input tensors were bf16, so the cast must be applied after padding.
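
A quick standalone sanity check of fix 2 (shapes here are made up; `pad_sequence` is the same `torch.nn.utils.rnn` helper used in dataset.py):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length bf16 sequences, shaped (time_steps, channels).
seqs = [torch.randn(n, 512).to(torch.bfloat16) for n in (100, 80, 120)]

# Pad into one batch and cast explicitly, mirroring the dataset fix.
batch = pad_sequence(seqs, batch_first=True, padding_value=0).to(torch.bfloat16)

assert batch.dtype == torch.bfloat16  # guaranteed by the explicit .to() cast
print(batch.shape)  # torch.Size([3, 120, 512])
```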

### Final Implementation
```python
# In rnn_trainer.py prepare_dataloaders()
self.train_loader = DataLoader(
    ...  # remaining arguments not shown in this hunk
)
```
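
The truncated call above pairs with the collate change listed below; a hedged sketch of the plausible full wiring, based on these notes (`passthrough_collate` and the surrounding class are illustrative names, not the exact rnn_trainer.py code):

```python
from torch.utils.data import DataLoader


def passthrough_collate(samples):
    # With batch_size=1, each element is already a fully padded batch from
    # the dataset, so unwrap it instead of letting default_collate re-stack.
    return samples[0]


class TrainerSketch:
    """Illustrative stand-in for the trainer's prepare_dataloaders()."""

    def __init__(self, train_dataset):
        self.train_dataset = train_dataset
        self.train_loader = None

    def prepare_dataloaders(self):
        # Fix 2: batch_size=1 plus a passthrough collate_fn, so Accelerate
        # never needs a batch_sampler for our pre-batched dataset.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=1,
            collate_fn=passthrough_collate,
        )
```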

Modified files:
- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding

### Next Steps
1. ~~Implement even_batches=False~~ ✅ DONE

The accompanying code change in dataset.py:

```
@@ -145,8 +145,8 @@ class BrainToTextDataset(Dataset):
                 print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
                 continue

-        # Pad data to form a cohesive batch
-        batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
+        # Pad data to form a cohesive batch - ensure bf16 dtype is preserved
+        batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
         batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)

         batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
```