From 4dad570eea3f7b849c9fc21e338f0af940297f56 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 12 Oct 2025 21:47:30 +0800
Subject: [PATCH] tpu

---
 TPU_ISSUES_RECORD.md          | 22 ++++++++++++++++++----
 model_training_nnn/dataset.py |  4 ++--
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/TPU_ISSUES_RECORD.md b/TPU_ISSUES_RECORD.md
index e5aeceb..f02ad5b 100644
--- a/TPU_ISSUES_RECORD.md
+++ b/TPU_ISSUES_RECORD.md
@@ -95,12 +95,12 @@ TypeError: 'NoneType' object is not iterable
 - But Accelerate expects a proper batch_sampler when iterating
 - This is a fundamental incompatibility between our batching approach and Accelerate's expectations
 
-## COMPREHENSIVE SOLUTION ✅
+## COMPREHENSIVE SOLUTION ✅ (v2.0)
 
 ### Problem Resolution Status
 1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
 2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
-3. ~~Data Type Mismatch Error~~ ✅ RESOLVED with bf16 conversion in dataset
+3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - Fixed both input conversion and padding dtype preservation
 
 ### Latest Error (2025-10-12 13:38)
 ```
@@ -115,8 +115,10 @@ INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32
 - But input data remains as `f32`, causing type mismatch during forward pass
 - TPU XLA compiler is strict about type matching
 
-### Solution: Data Type Conversion in Dataset
-Fixed in `dataset.py:130` by converting neural data to `bf16`:
+### Solution: Comprehensive Data Type Conversion in Dataset
+Fixed in `dataset.py` with two changes:
+
+**1. Convert input data to bf16 (line 130):**
 ```python
 # Before (causes type mismatch):
 input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
@@ -125,6 +127,17 @@ input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility
 ```
 
+**2. Preserve bf16 dtype after padding (line 149):**
+```python
+# Before (pad_sequence converts back to f32):
+batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
+
+# After (explicitly maintain bf16):
+batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
+```
+
+**Root Cause**: `pad_sequence` function resets dtype to default (f32) even if input tensors are bf16.
+
 ### Final Implementation
 ```python
 # In rnn_trainer.py prepare_dataloaders()
 self.train_loader = DataLoader(
@@ -163,6 +176,7 @@ self.train_loader = DataLoader(
 - [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
 - [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
 - [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
+- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
 
 ### Next Steps
 1. ~~Implement even_batches=False~~ ✅ DONE
diff --git a/model_training_nnn/dataset.py b/model_training_nnn/dataset.py
index e964676..086370e 100644
--- a/model_training_nnn/dataset.py
+++ b/model_training_nnn/dataset.py
@@ -145,8 +145,8 @@ class BrainToTextDataset(Dataset):
                 print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
                 continue
 
-        # Pad data to form a cohesive batch
-        batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
+        # Pad data to form a cohesive batch - ensure bf16 dtype is preserved
+        batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
         batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
 
         batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
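
Reviewer note (not part of the patch): the two dtype changes above can be sanity-checked outside the training stack. The sketch below is a minimal reproduction under assumed shapes — three trials of random data with 512 feature channels standing in for the HDF5 neural features; only the torch calls mirror dataset.py.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Stand-in for g['input_features'][:] loaded from HDF5 (torch defaults to f32)
trials = [torch.randn(t, 512) for t in (120, 95, 143)]

# Change 1 (dataset.py:130): convert each trial to bf16 for TPU/XLA
trials = [t.to(torch.bfloat16) for t in trials]

# Change 2 (dataset.py:149): pad into one batch, casting explicitly so the
# collated tensor is guaranteed to reach the XLA graph as bf16
batch = pad_sequence(trials, batch_first=True, padding_value=0).to(torch.bfloat16)

assert batch.dtype == torch.bfloat16
print(batch.shape)  # torch.Size([3, 143, 512])
```

On stock PyTorch builds, `pad_sequence` keeps the dtype of its input tensors, so the trailing `.to(torch.bfloat16)` is best read as a cheap, defensive guarantee for the TPU path: it pins down the dtype the compiler sees regardless of what happened upstream.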
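Reviewer note (not part of the patch): the "Final Implementation" excerpt refers to the custom collate_fn with batch_size=1, needed because each __getitem__ already returns a full batch. The sketch below is an illustrative stand-in for that wiring, not the actual rnn_trainer.py code; the class name, shapes, and dict keys are assumptions.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Stand-in for BrainToTextDataset: __getitem__ returns an
    already-collated batch dict rather than a single example."""
    def __len__(self):
        return 4  # four pre-built batches per epoch

    def __getitem__(self, idx):
        return {
            'input_features': torch.zeros(8, 143, 512, dtype=torch.bfloat16),
            'n_time_steps': torch.full((8,), 143),
        }

# batch_size=1 plus an unwrapping collate_fn: the DataLoader passes each
# pre-built batch through unchanged instead of trying to re-collate it,
# which sidesteps the batch_sampler incompatibility described above
train_loader = DataLoader(
    PreBatchedDataset(),
    batch_size=1,
    shuffle=False,
    collate_fn=lambda items: items[0],
)

for batch in train_loader:
    assert batch['input_features'].dtype == torch.bfloat16  # bf16 survives loading
```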