# TPU Training Issues Record

## Core Problem

**Primary Error:** `ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size`

This error occurs when using TPU with the Hugging Face Accelerate framework and custom DataLoaders that have `batch_size=None`.
## Root Cause Analysis

- Our custom dataset returns full batches (not individual samples)
- The DataLoader is created with `batch_size=None` because batching is handled by the dataset
- TPU training with Accelerate requires `even_batches=False` for this configuration
- The `even_batches` parameter must be passed through the DataLoader configuration, not as a keyword argument to `Accelerator.__init__()`
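The setup that triggers this situation can be reproduced with a toy stand-in for the pattern (the class and sizes here are illustrative, not the project's actual code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Toy stand-in for a dataset whose __getitem__ returns a full batch."""
    def __init__(self, n_batches=4, batch_size=64, n_features=10):
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.n_features = n_features

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        # One "item" is already an entire batch of samples
        return torch.zeros(self.batch_size, self.n_features)

# batch_size=None disables automatic batching: the DataLoader yields
# each dataset item as-is, and no batch_sampler is created.
loader = DataLoader(PreBatchedDataset(), batch_size=None)
first = next(iter(loader))
print(first.shape)  # torch.Size([64, 10])
```

With this layout each DataLoader "item" is already a full batch, which is exactly the configuration that makes Accelerate demand `even_batches=False`.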
## Failed Solution Attempts

### Attempt 1: Adding `even_batches` to `Accelerator.__init__()`

```python
self.accelerator = Accelerator(
    mixed_precision='bf16',
    gradient_accumulation_steps=1,
    even_batches=False  # ❌ WRONG - this parameter doesn't exist in Accelerator.__init__()
)
```

**Error:** `TypeError: Accelerator.__init__() got an unexpected keyword argument 'even_batches'`
### Attempt 2: Complex TPU-specific DataLoader handling

- Created conditional TPU/GPU logic
- Moved data manually with `to(device)`
- Modified the custom `collate_fn`
- **Result:** Overengineered solution that didn't address the root cause
### Attempt 3: Memory optimization

- Reduced TPU cores from 8 to 2
- Reduced batch size
- Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)
### Attempt 4: Removing all TPU-specific logic

- Let Accelerator handle everything automatically
- **Result:** The same `even_batches` error returned
## Correct Solution

The `even_batches=False` parameter should be passed via `DataLoaderConfiguration` when initializing the Accelerator:

```python
from accelerate import Accelerator, DataLoaderConfiguration

# Configure DataLoader behavior for TPU
dataloader_config = DataLoaderConfiguration(
    even_batches=False  # Required for batch_size=None DataLoaders
)

self.accelerator = Accelerator(
    mixed_precision='bf16' if args.get('use_amp', True) else 'no',
    gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    log_with=None,
    project_dir=args.get('output_dir', './output'),
    dataloader_config=dataloader_config  # ✅ CORRECT - pass DataLoaderConfiguration
)
```
## Technical Context

- **Model:** Brain-to-text RNN with 687M parameters
- **Dataset:** Custom dataset that returns full batches (`batch_size=None` in the DataLoader)
- **TPU Config:** 8 cores × 16GB = 128GB total memory
- **Batch Size:** 64
- **Framework:** PyTorch XLA with Hugging Face Accelerate
## Key Files Modified

- `model_training_nnn/rnn_trainer.py` - main trainer class
- `model_training_nnn/rnn_args.yaml` - configuration file
- `model_training_nnn/dataset.py` - custom dataset class
## Memory Allocation Facts

- TPU v5e-8: 8 cores × 16GB = 128GB total
- Fewer cores = LESS total memory (not more per core)
## Latest Status (2025-10-12)

### After DataLoaderConfiguration Fix

✅ **`even_batches` error RESOLVED** - no more `ValueError: You need to use 'even_batches=False'`

❌ **NEW ERROR:** `TypeError: 'NoneType' object is not iterable`

```
File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
    for idx, batch in enumerate(self.batch_sampler):
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable
```

**Root Cause:** `batch_sampler` becomes `None` when our DataLoader has `batch_size=None`
### Current Investigation

- The failure is in Accelerate's `data_loader.py`, line 221
- Our custom dataset returns full batches, so we use `batch_size=None` in the DataLoader
- But Accelerate expects a proper `batch_sampler` when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
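The missing `batch_sampler` is easy to confirm in isolation: when automatic batching is disabled, PyTorch leaves it unset, which is exactly what the failing loop then tries to enumerate. A minimal check, using a throwaway tensor dataset rather than the project's real one:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 3))

# batch_size=None disables auto-batching, so no batch_sampler is built...
no_batching = DataLoader(dataset, batch_size=None)
print(no_batching.batch_sampler)  # None

# ...while any integer batch_size produces one, which is what
# Accelerate expects to iterate over.
with_batching = DataLoader(dataset, batch_size=2)
print(type(with_batching.batch_sampler).__name__)  # BatchSampler
```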
## COMPREHENSIVE SOLUTION ✅ (v2.0)

### Problem Resolution Status

| Issue | Status |
| --- | --- |
| `even_batches` error | ✅ RESOLVED with `DataLoaderConfiguration` |
| `batch_sampler` None error | ✅ RESOLVED with custom `collate_fn` |
| Data type mismatch error | ✅ RESOLVED - fixed both input conversion and padding dtype preservation |
### Latest Error (2025-10-12 13:38)

```
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].
```

**Root Cause:** Mixed-precision training with `mixed_precision='bf16'` expects all tensors to be bf16, but our data is being loaded as f32 (float32).
**Analysis:**

- We enabled `bf16` mixed precision in the Accelerator configuration
- Model parameters are automatically converted to `bf16`
- But input data remains `f32`, causing a type mismatch during the forward pass
- The TPU XLA compiler is strict about type matching
### Solution: Comprehensive Data Type Conversion in Dataset

Fixed in `dataset.py` with two changes:

**1. Convert input data to bf16 (line 130):**

```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:])  # defaults to f32

# After (TPU compatible):
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16)  # convert to bf16 for TPU compatibility
```

**2. Preserve bf16 dtype after padding (line 149):**

```python
# Before (dtype reverts to f32):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0)

# After (explicitly maintain bf16):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0).to(torch.bfloat16)
```
**Root Cause:** in our run, `pad_sequence` produced f32 output even though the input tensors were bf16, so the dtype has to be re-asserted after padding.
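Whatever the padding path does internally, appending an explicit cast makes the dtype invariant unconditional. A minimal sketch of the fix, with illustrative shapes rather than the real feature dimensions:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length bf16 feature tensors, as the dataset would produce them
features = [torch.zeros(5, 4, dtype=torch.bfloat16),
            torch.zeros(3, 4, dtype=torch.bfloat16)]

# Pad to a common length, then force bf16 so the dtype is guaranteed
# regardless of what pad_sequence returns
padded = pad_sequence(features, batch_first=True, padding_value=0).to(torch.bfloat16)
print(padded.shape)  # torch.Size([2, 5, 4])
print(padded.dtype)  # torch.bfloat16
```

The trailing `.to(torch.bfloat16)` is a no-op when the padded tensor is already bf16, so the cast is cheap insurance rather than extra work.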
### Final Implementation

```python
# In rnn_trainer.py prepare_dataloaders()

# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
    # Our dataset returns full batches, so `batch` is a list holding a single batch dict.
    # Extract the first (and only) element since our dataset.__getitem__() returns a full batch.
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    # Fallback for unexpected batch structure
    return batch

# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
    self.train_dataset,
    batch_size=1,  # use batch_size=1 since the dataset returns full batches
    shuffle=shuffle_setting,
    num_workers=workers_setting,
    pin_memory=True,
    collate_fn=collate_fn,
)
```
**Key Insight:** Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data.
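The unwrapping logic itself is plain Python and can be checked without a DataLoader: with `batch_size=1`, the default sampler hands `collate_fn` a one-element list containing the dict the dataset built. The sample data below is made up for illustration:

```python
def collate_fn(batch):
    """Unwrap the single pre-batched dict produced by the dataset."""
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    return batch  # fallback for unexpected structures

# What the DataLoader passes in with batch_size=1: a list of one full batch
pre_batched = [{"input_features": [[0.1, 0.2], [0.3, 0.4]], "n_time_steps": [2, 2]}]
print(collate_fn(pre_batched)["n_time_steps"])  # [2, 2]
```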
## Complete Solution Summary

### Three-Step Fix for TPU Training

1. **DataLoaderConfiguration:** added `even_batches=False` for the `batch_size=1` DataLoaders
2. **Custom collate_fn:** handles pre-batched data from our dataset
3. **Data type conversion:** convert input data to `bf16` for mixed-precision compatibility
### Files Modified

- `rnn_trainer.py:44-46` - added `DataLoaderConfiguration`
- `rnn_trainer.py:193-210` - custom `collate_fn` and `batch_size=1`
- `dataset.py:130` - convert neural data to bf16
- `dataset.py:149` - preserve bf16 dtype after padding
## Next Steps

- [x] Implement `even_batches=False` ✅ DONE
- [x] Fix `batch_sampler` None issue ✅ DONE
- [x] Fix data type mismatch ✅ DONE
- [ ] Test TPU training with the complete solution
- [ ] Integrate the final solution into CLAUDE.md
## Lessons Learned

- Don't overcomplicate TPU conversion - it should be straightforward
- Read the Accelerate documentation carefully for parameter placement
- Document issues immediately to avoid confusion
- TPU memory allocation: fewer cores = less total memory