`TypeError: 'NoneType' object is not iterable`

- But Accelerate expects a proper batch_sampler when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations (see the sketch below)
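For context, here is a minimal sketch of the batching pattern in play; `PreBatchedDataset` and `identity_collate` are illustrative names, not the actual implementation. The dataset hands out whole batches, so the DataLoader runs with `batch_size=1` and a pass-through `collate_fn` rather than relying on a `batch_sampler`:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    """Illustrative stand-in: each __getitem__ returns a full batch dict."""
    def __init__(self, batches):
        self.batches = batches  # list of dicts, one dict per batch

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]

def identity_collate(items):
    # batch_size=1 means items == [one pre-built batch]; just unwrap it.
    return items[0]

dataset = PreBatchedDataset([{'input_features': torch.zeros(4, 16)}])
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=identity_collate)
```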
## COMPREHENSIVE SOLUTION ✅ (v2.0)

### Problem Resolution Status
1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration (see the sketch below)
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - Fixed both input conversion and padding dtype preservation
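A minimal sketch of fix 1, assuming a recent Accelerate release where `DataLoaderConfiguration` is importable from `accelerate.utils` (on some older releases `even_batches` may instead be accepted directly by `Accelerator`):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# even_batches=False stops Accelerate from trying to equalize batch sizes
# across processes, which conflicts with the pre-batched dataset approach.
dataloader_config = DataLoaderConfiguration(even_batches=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
```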
### Latest Error (2025-10-12 13:38)

```
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32...
```

- But input data remains as `f32`, causing a type mismatch during the forward pass
- The TPU XLA compiler is strict about type matching (a quick pre-check is sketched below)
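One way to catch this class of error before XLA compilation is a cheap dtype assertion on the host; a sketch with placeholder `model` and `batch` names:

```python
def assert_dtypes_match(model, batch):
    # XLA rejects a forward pass whose input dtype differs from the
    # parameter dtype, so fail fast here with a readable message.
    param_dtype = next(model.parameters()).dtype
    input_dtype = batch['input_features'].dtype
    if param_dtype != input_dtype:
        raise TypeError(f'dtype mismatch: params={param_dtype}, inputs={input_dtype}')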
### Solution: Comprehensive Data Type Conversion in Dataset

Fixed in `dataset.py` with two changes:

**1. Convert input data to bf16 (line 130):**
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:])  # defaults to f32

# After:
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16)  # convert to bf16 for TPU compatibility
```
**2. Preserve bf16 dtype after padding (line 149):**

```python
# Before (pad_sequence converts back to f32):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0)

# After (explicitly maintain bf16):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0).to(torch.bfloat16)
```

**Root Cause**: in this pipeline the padded batch came back as `f32` even though the input tensors were `bf16`, so the dtype is explicitly re-asserted after `pad_sequence`.
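For anyone who wants to confirm the padding behavior in their own environment, a quick standalone check (the trailing `.to(torch.bfloat16)` is cheap insurance either way):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.zeros(3, 8, dtype=torch.bfloat16),
        torch.zeros(5, 8, dtype=torch.bfloat16)]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded.dtype)  # check whether bf16 survived padding on this setup
padded = padded.to(torch.bfloat16)  # explicit re-assert, as in the fix
```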
### Final Implementation

```python
# In rnn_trainer.py prepare_dataloaders()
self.train_loader = DataLoader(
    ...
)
```
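The diff preserves only the opening line of that call; below is a plausible reconstruction combining the fixes above. `identity_collate`, the `shuffle` flag, and the `accelerator` attribute are assumptions, not the repository's exact code:

```python
from torch.utils.data import DataLoader

def identity_collate(items):
    # Dataset items are already complete batch dicts; just unwrap.
    return items[0]

# Inside RNNTrainer.prepare_dataloaders() (sketch):
self.train_loader = DataLoader(
    self.train_dataset,           # assumed attribute name
    batch_size=1,                 # one pre-built batch per DataLoader item
    shuffle=True,                 # assumption
    collate_fn=identity_collate,
)
self.train_loader = self.accelerator.prepare(self.train_loader)  # assumed attribute
```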
Changed files:

- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
### Next Steps

1. ~~Implement even_batches=False~~ ✅ DONE
The corresponding code change in `dataset.py` (inside `BrainToTextDataset`, around line 145):

```python
        print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
        continue

# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0).to(torch.bfloat16)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first=True, padding_value=0)

batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
```
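Finally, a hedged smoke test, assuming a constructed `BrainToTextDataset` instance named `train_dataset`:

```python
import torch
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=1, collate_fn=lambda items: items[0])
batch = next(iter(loader))
# With both fixes applied, features should stay bf16 end to end.
assert batch['input_features'].dtype == torch.bfloat16
print({k: getattr(v, 'dtype', type(v)) for k, v in batch.items()})
```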