# TPU Training Issues Record ## Core Problem **Primary Error**: `ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size` This error occurs when using TPU with Hugging Face Accelerate framework and custom DataLoaders that have `batch_size=None`. ## Root Cause Analysis 1. Our custom dataset returns full batches (not individual samples) 2. DataLoader is created with `batch_size=None` because batching is handled by the dataset 3. TPU training with Accelerate requires `even_batches=False` for this configuration 4. The `even_batches` parameter needs to be set in the DataLoader preparation, not Accelerator initialization ## Failed Solution Attempts ### Attempt 1: Adding even_batches to Accelerator.__init__() ```python self.accelerator = Accelerator( mixed_precision='bf16', gradient_accumulation_steps=1, even_batches=False # ❌ WRONG - This parameter doesn't exist in Accelerator.__init__() ) ``` **Error**: `TypeError: Accelerator.__init__() got an unexpected keyword argument 'even_batches'` ### Attempt 2: Complex TPU-specific DataLoader handling - Created conditional TPU/GPU logic - Manual data movement with `to(device)` - Custom collate_fn modifications - Result: Overengineered solution that didn't address root cause ### Attempt 3: Memory optimization - Reduced TPU cores from 8 to 2 - Reduced batch size - Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core) 我很不希望这么做,至少减少核心会减少算力! ### Attempt 4: Removing all TPU-specific logic - Let Accelerator handle everything automatically - Result: Same even_batches error returned ## Correct Solution The `even_batches=False` parameter should be passed using `DataLoaderConfiguration` when initializing the Accelerator: ```python from accelerate import Accelerator, DataLoaderConfiguration # Configure DataLoader behavior for TPU dataloader_config = DataLoaderConfiguration( even_batches=False # Required for batch_size=None DataLoaders ) self.accelerator = Accelerator( mixed_precision='bf16' if args.get('use_amp', True) else 'no', gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1), log_with=None, project_dir=args.get('output_dir', './output'), dataloader_config=dataloader_config # ✅ CORRECT - Pass DataLoaderConfiguration ) ``` ## Technical Context - **Model**: Brain-to-text RNN with 687M parameters - **Dataset**: Custom dataset that returns full batches (batch_size=None in DataLoader) - **TPU Config**: 8 cores × 16GB = 128GB total memory - **Batch Size**: 64 - **Framework**: PyTorch XLA with Hugging Face Accelerate ## Key Files Modified - `model_training_nnn/rnn_trainer.py` - Main trainer class - `model_training_nnn/rnn_args.yaml` - Configuration file - `model_training_nnn/dataset.py` - Custom dataset class ## Memory Allocation Facts - TPU v5e-8: 8 cores × 16GB = 128GB total - Fewer cores = LESS total memory (not more per core) ## Latest Status (2025-10-12) ### After DataLoaderConfiguration Fix ✅ **even_batches Error RESOLVED** - No more `ValueError: You need to use 'even_batches=False'` ❌ **NEW ERROR**: `TypeError: 'NoneType' object is not iterable` ``` File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split for idx, batch in enumerate(self.batch_sampler): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: 'NoneType' object is not iterable ``` **Root Cause**: `batch_sampler` becomes `None` when our DataLoader has `batch_size=None` ### Current Investigation - The issue is in Accelerate's data_loader.py line 221 - Our custom dataset returns full batches, so we use `batch_size=None` in DataLoader - But Accelerate expects a proper batch_sampler when iterating - This is a fundamental incompatibility between our batching approach and Accelerate's expectations ## COMPREHENSIVE SOLUTION ✅ (v2.0) ### Problem Resolution Status 1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration 2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn 3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - Fixed both input conversion and padding dtype preservation ### Latest Error (2025-10-12 13:38) ``` INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168]. ``` **Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but tensors were being created as `f32` (float32) at multiple levels. **Analysis**: - We enabled `bf16` mixed precision in Accelerator configuration - Input data was loaded as `f32` and needed conversion - More critically: Model parameters were initialized as `f32` by default - TPU XLA compiler is strict about type matching across all tensors ### Solution: Comprehensive Data Type Conversion at All Levels **1. Convert input data to bf16 in dataset.py (line 130):** ```python # Before (causes type mismatch): input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32 # After (TPU compatible): input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility ``` **2. Preserve bf16 dtype after padding in dataset.py (line 149):** ```python # Before (pad_sequence converts back to f32): batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0) # After (explicitly maintain bf16): batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16) ``` **3. Fix model parameter initialization in rnn_model.py:** ```python # Before (defaults to f32): self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]) self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]) self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units))) # After (explicit bf16): self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16))) ``` **Root Causes Identified**: - `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16 - `torch.eye()` and `torch.zeros()` default to f32 unless explicit dtype is specified - All tensor creation points must explicitly specify `dtype=torch.bfloat16` for mixed precision consistency ### Final Implementation ```python # In rnn_trainer.py prepare_dataloaders() # Custom collate function that handles pre-batched data from our dataset def collate_fn(batch): # Our dataset returns full batches, so batch will be a list of single batch dict # Extract the first (and only) element since our dataset.__getitem__() returns a full batch if len(batch) == 1 and isinstance(batch[0], dict): return batch[0] else: # Fallback for unexpected batch structure return batch # DataLoader configuration compatible with Accelerate self.train_loader = DataLoader( self.train_dataset, batch_size = 1, # Use batch_size=1 since dataset returns full batches shuffle = shuffle_setting, num_workers = workers_setting, pin_memory = True, collate_fn = collate_fn ) ``` **Key Insight**: Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data. ## Complete Solution Summary ### Four-Step Fix for TPU Training 1. **DataLoaderConfiguration**: Added `even_batches=False` for batch_size=1 DataLoaders 2. **Custom collate_fn**: Handles pre-batched data from our dataset 3. **Data Type Conversion (Dataset)**: Convert input data to `bf16` for mixed precision compatibility 4. **Data Type Conversion (Model)**: Fix all model parameter initialization to use explicit `bf16` dtype ### Files Modified - COMPREHENSIVE SOLUTION ✅ - [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration - [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1 - [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16 - [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding - **[rnn_model.py:28-29](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L28-L29)**: Fixed NoiseModel day weights/biases dtype - **[rnn_model.py:55](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L55)**: Fixed NoiseModel h0 dtype - **[rnn_model.py:113-114](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L113-L114)**: Fixed CleanSpeechModel day weights/biases dtype - **[rnn_model.py:144](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L144)**: Fixed CleanSpeechModel h0 dtype - **[rnn_model.py:232](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_model.py#L232)**: Fixed NoisySpeechModel h0 dtype ### Next Steps 1. ~~Implement even_batches=False~~ ✅ DONE 2. ~~Fix batch_sampler None issue~~ ✅ DONE 3. ~~Fix data type mismatch (dataset level)~~ ✅ DONE 4. ~~Fix data type mismatch (model parameter level)~~ ✅ DONE 5. **READY**: Test TPU training with comprehensive dtype solution 6. Update CLAUDE.md with final TPU training guidance ## Final Status Update (2025-10-12 14:30) 🎯 **COMPREHENSIVE SOLUTION COMPLETED** All TPU training issues have been systematically identified and fixed: ✅ **Problem 1**: `even_batches` error → Fixed with DataLoaderConfiguration ✅ **Problem 2**: `batch_sampler=None` error → Fixed with custom collate_fn + batch_size=1 ✅ **Problem 3**: Data type mismatch (dataset) → Fixed bf16 conversion + padding preservation ✅ **Problem 4**: Data type mismatch (model) → Fixed all parameter initialization with explicit bf16 dtype **The solution addresses dtype consistency at ALL levels**: - Input data loading: `.to(torch.bfloat16)` - Padding operations: explicit bf16 preservation - Model parameters: `torch.eye(..., dtype=torch.bfloat16)` and `torch.zeros(..., dtype=torch.bfloat16)` **Ready for TPU training test** with 687M parameter brain-to-text model. --- ## New Issue: TPU Memory Exhaustion (2025-10-12 15:00) ``` RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 3.50M. That was not possible. There are 2.07M free.; (0x0x0_HBM0) ``` **Root Cause**: TPU HBM memory fragmentation with batch_size=64 - Single batch: 64 × (512 features × 14 patches) × 2 bytes = ~917KB per batch - Combined with 687M model parameters + gradients + activations → memory exhaustion - TPU memory allocation is stricter than GPU, requires contiguous blocks **Solution**: Memory-optimized configuration ```yaml # rnn_args.yaml optimizations: batch_size: 32 # reduced from 64 gradient_accumulation_steps: 2 # maintains effective batch size of 64 num_dataloader_workers: 0 # TPU compatibility ``` **Memory Calculation**: - New batch memory: 32 × 7168 × 2 bytes = ~458KB (50% reduction) - Gradient accumulation maintains training stability - Effective batch size unchanged: 2 steps × 32 = 64 samples ## Lessons Learned - **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors - **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype - **Documentation**: Record issues immediately to avoid repeated debugging cycles - Don't overcomplicate TPU conversion - identify systematic dtype issues - Read Accelerate documentation carefully for parameter placement - TPU memory allocation: fewer cores = less total memory