## 3. ~~Data Type Mismatch Error~~ ✅ RESOLVED with bf16 conversion in dataset
### Latest Error (2025-10-12 13:38)
```
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].
```
**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but our data is being loaded as `f32` (float32).
**Analysis**:
- We enabled `bf16` mixed precision in the Accelerator configuration (see the sketch after this list)
- Model parameters are automatically converted to `bf16`
- Input data, however, remains `f32`, causing a type mismatch during the forward pass
- The TPU XLA compiler is strict about type matching
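For reference, a minimal sketch of the Accelerator setup that imposes this constraint (`Accelerator(mixed_precision='bf16')` and `prepare()` are standard Accelerate API; the `model`, `optimizer`, and `train_loader` names are illustrative, not necessarily the project's):
```python
from accelerate import Accelerator

# Per the analysis above: bf16 mixed precision converts the model's parameters,
# but tensors coming out of the DataLoader keep whatever dtype the dataset set.
accelerator = Accelerator(mixed_precision='bf16')
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
```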
### Solution: Data Type Conversion in Dataset
Fixed in `dataset.py:130` by converting neural data to `bf16`:
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:])  # inherits float32 from the HDF5 array

# After (TPU compatible):
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16)  # cast to bf16 for TPU compatibility
```

The DataLoader setup also needed a custom `collate_fn`, because the dataset yields complete batches rather than individual samples:

```python
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
    # The dataset returns full batches, so `batch` is a list containing a
    # single batch dict; unwrap it instead of stacking individual samples.
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    # Fallback for unexpected batch structure
    return batch

# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
    self.train_dataset,
    batch_size=1,  # the dataset already returns full batches
    shuffle=shuffle_setting,
    num_workers=workers_setting,
    pin_memory=True,
    collate_fn=collate_fn,
)
```
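As a quick sanity check (a hypothetical snippet, assuming each batch dict carries an `input_features` tensor as in the dataset fix above), the loader should now yield unwrapped `bf16` batches:
```python
batch = next(iter(self.train_loader))
assert isinstance(batch, dict)                          # collate_fn unwrapped the list
assert batch['input_features'].dtype == torch.bfloat16  # the bf16 cast took effect
```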
**Key Insight**: Our dataset's `__getitem__()` returns complete batches, while the PyTorch `DataLoader` (and therefore Accelerate) expects it to return individual samples. The solution is `batch_size=1` plus a custom `collate_fn` that unwraps the pre-batched data.
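For context, here is a minimal sketch of the pre-batched dataset pattern in question (the class name and dimensions are illustrative, not the project's actual code):
```python
import torch
from torch.utils.data import Dataset

class PreBatchedDataset(Dataset):
    """Each __getitem__ call returns a complete batch dict, not one sample."""
    def __init__(self, num_batches: int, batch_size: int, feature_dim: int):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.feature_dim = feature_dim

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, idx: int) -> dict:
        # Already batched and already bf16: with batch_size=1, the DataLoader's
        # own batching just wraps this dict in a single-element list, which the
        # custom collate_fn above unwraps.
        return {
            'input_features': torch.zeros(
                self.batch_size, self.feature_dim, dtype=torch.bfloat16
            )
        }
```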