tpu
This commit is contained in:
@@ -74,10 +74,65 @@ self.accelerator = Accelerator(
|
|||||||
- TPU v5e-8: 8 cores × 16GB = 128GB total
|
- TPU v5e-8: 8 cores × 16GB = 128GB total
|
||||||
- Fewer cores = LESS total memory (not more per core)
|
- Fewer cores = LESS total memory (not more per core)
|
||||||
|
|
||||||
|
## Latest Status (2025-10-12)
|
||||||
|
|
||||||
|
### After DataLoaderConfiguration Fix
|
||||||
|
✅ **even_batches Error RESOLVED** - No more `ValueError: You need to use 'even_batches=False'`
|
||||||
|
|
||||||
|
❌ **NEW ERROR**: `TypeError: 'NoneType' object is not iterable`
|
||||||
|
```
|
||||||
|
File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
|
||||||
|
for idx, batch in enumerate(self.batch_sampler):
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
TypeError: 'NoneType' object is not iterable
|
||||||
|
```
|
||||||
|
|
||||||
|
**Root Cause**: `batch_sampler` becomes `None` when our DataLoader has `batch_size=None`
|
||||||
|
|
||||||
|
### Current Investigation
|
||||||
|
- The issue is in Accelerate's data_loader.py line 221
|
||||||
|
- Our custom dataset returns full batches, so we use `batch_size=None` in DataLoader
|
||||||
|
- But Accelerate expects a proper batch_sampler when iterating
|
||||||
|
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
|
||||||
|
|
||||||
|
## FINAL SOLUTION ✅
|
||||||
|
|
||||||
|
### Problem Resolution
|
||||||
|
1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
|
||||||
|
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
|
||||||
|
|
||||||
|
### Final Implementation
|
||||||
|
```python
|
||||||
|
# In rnn_trainer.py prepare_dataloaders()
|
||||||
|
|
||||||
|
# Custom collate function that handles pre-batched data from our dataset
|
||||||
|
def collate_fn(batch):
|
||||||
|
# Our dataset returns full batches, so batch will be a list of single batch dict
|
||||||
|
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
|
||||||
|
if len(batch) == 1 and isinstance(batch[0], dict):
|
||||||
|
return batch[0]
|
||||||
|
else:
|
||||||
|
# Fallback for unexpected batch structure
|
||||||
|
return batch
|
||||||
|
|
||||||
|
# DataLoader configuration compatible with Accelerate
|
||||||
|
self.train_loader = DataLoader(
|
||||||
|
self.train_dataset,
|
||||||
|
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||||
|
shuffle = shuffle_setting,
|
||||||
|
num_workers = workers_setting,
|
||||||
|
pin_memory = True,
|
||||||
|
collate_fn = collate_fn
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Insight**: Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data.
|
||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
1. Implement correct even_batches=False in accelerator.prepare()
|
1. ~~Implement even_batches=False~~ ✅ DONE
|
||||||
2. Test TPU training without overengineering
|
2. ~~Fix batch_sampler None issue~~ ✅ DONE
|
||||||
3. Verify memory usage with 8 cores configuration
|
3. Test TPU training with complete solution
|
||||||
|
4. Integrate final solution into CLAUDE.md
|
||||||
|
|
||||||
## Lessons Learned
|
## Lessons Learned
|
||||||
- Don't overcomplicate TPU conversion - it should be straightforward
|
- Don't overcomplicate TPU conversion - it should be straightforward
|
||||||
|
@@ -189,13 +189,24 @@ class BrainToTextDecoder_Trainer:
|
|||||||
random_seed = self.args['dataset']['seed'],
|
random_seed = self.args['dataset']['seed'],
|
||||||
feature_subset = feature_subset
|
feature_subset = feature_subset
|
||||||
)
|
)
|
||||||
# Standard DataLoader configuration - let Accelerator handle device-specific optimizations
|
# Custom collate function that handles pre-batched data from our dataset
|
||||||
|
def collate_fn(batch):
|
||||||
|
# Our dataset returns full batches, so batch will be a list of single batch dict
|
||||||
|
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
|
||||||
|
if len(batch) == 1 and isinstance(batch[0], dict):
|
||||||
|
return batch[0]
|
||||||
|
else:
|
||||||
|
# Fallback for unexpected batch structure
|
||||||
|
return batch
|
||||||
|
|
||||||
|
# DataLoader configuration compatible with Accelerate
|
||||||
self.train_loader = DataLoader(
|
self.train_loader = DataLoader(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||||
shuffle = self.args['dataset']['loader_shuffle'],
|
shuffle = self.args['dataset']['loader_shuffle'],
|
||||||
num_workers = self.args['dataset']['num_dataloader_workers'],
|
num_workers = self.args['dataset']['num_dataloader_workers'],
|
||||||
pin_memory = True
|
pin_memory = True,
|
||||||
|
collate_fn = collate_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
# val dataset and dataloader
|
# val dataset and dataloader
|
||||||
@@ -209,13 +220,14 @@ class BrainToTextDecoder_Trainer:
|
|||||||
random_seed = self.args['dataset']['seed'],
|
random_seed = self.args['dataset']['seed'],
|
||||||
feature_subset = feature_subset
|
feature_subset = feature_subset
|
||||||
)
|
)
|
||||||
# Standard validation DataLoader configuration
|
# Validation DataLoader with same collate function
|
||||||
self.val_loader = DataLoader(
|
self.val_loader = DataLoader(
|
||||||
self.val_dataset,
|
self.val_dataset,
|
||||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||||
shuffle = False,
|
shuffle = False,
|
||||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||||
pin_memory = True
|
pin_memory = True,
|
||||||
|
collate_fn = collate_fn # Use same collate function
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.info("Successfully initialized datasets")
|
self.logger.info("Successfully initialized datasets")
|
||||||
|
Reference in New Issue
Block a user