From 5c941d9efaa75aa8a54af640f053e257b0e417b8 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Sun, 12 Oct 2025 22:52:38 +0800 Subject: [PATCH] tpu --- TPU_ISSUES_RECORD.md | 31 +++++++++++++++++++++++++++++++ model_training_nnn/rnn_model.py | 18 ++++++++++-------- model_training_nnn/rnn_trainer.py | 10 ++++++---- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/TPU_ISSUES_RECORD.md b/TPU_ISSUES_RECORD.md index d4a94a0..992ac7b 100644 --- a/TPU_ISSUES_RECORD.md +++ b/TPU_ISSUES_RECORD.md @@ -319,6 +319,37 @@ if xm.get_xla_supported_devices(): **预期改进**: XLA图编译时间从5-15分钟缩短到2-8分钟 +## New Issue: DType Mismatch in adjusted_lens Calculation (2025-10-12 16:45) + +### Error Description +``` +Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 1 shape: f32[21504], argument shape: bf16[21504]. +``` + +### Root Cause +The `adjusted_lens` calculation was causing dtype mismatches in TPU mixed precision (bf16) training. When `n_time_steps` is processed under `accelerator.autocast()`, it becomes bfloat16, but the arithmetic operations were creating float32 results. + +### Problem Code +```python +# Before (causes f32/bf16 mismatch): +adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) +``` + +### Solution +Explicit float conversion before dtype casting: + +```python +# After (explicit dtype control): +adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) +``` + +### Fixed Locations +- `rnn_trainer.py:577` - Training loop +- `rnn_trainer.py:753` - Validation loop +- `rnn_trainer.py:851` - Inference batch function + +**Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations. + ## Lessons Learned - **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors - **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index 8666aed..e5b99bf 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -25,8 +25,9 @@ class NoiseModel(nn.Module): # Day-specific input layers self.day_layer_activation = nn.Softsign() - self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) - self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) + # Let Accelerator handle dtype automatically for TPU compatibility + self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]) + self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]) self.day_layer_dropout = nn.Dropout(input_dropout) # Calculate input size after patching @@ -51,8 +52,8 @@ class NoiseModel(nn.Module): if "weight_ih" in name: nn.init.xavier_uniform_(param) - # Learnable initial hidden state - self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16))) + # Learnable initial hidden state - let Accelerator handle dtype + self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size))) def forward(self, x, day_idx, states=None): # Apply day-specific transformation @@ -110,8 +111,9 @@ class CleanSpeechModel(nn.Module): # Day-specific input layers self.day_layer_activation = nn.Softsign() - self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) - self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)]) + # Let Accelerator handle dtype automatically for TPU compatibility + self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]) + self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]) self.day_layer_dropout = nn.Dropout(input_dropout) # Calculate input size after patching @@ -141,7 +143,7 @@ class CleanSpeechModel(nn.Module): nn.init.xavier_uniform_(self.out.weight) # Learnable initial hidden state - self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16))) + self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units))) def forward(self, x, day_idx, states=None, return_state=False): # Apply day-specific transformation @@ -229,7 +231,7 @@ class NoisySpeechModel(nn.Module): nn.init.xavier_uniform_(self.out.weight) # Learnable initial hidden state - self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16))) + self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units))) def forward(self, x, states=None, return_state=False): # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index 6009717..82e3476 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -573,7 +573,8 @@ class BrainToTextDecoder_Trainer: # Apply augmentations to the data features, n_time_steps = self.transform_data(features, n_time_steps, 'train') - adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) + # Ensure proper dtype handling for TPU mixed precision + adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) # Get phoneme predictions using inference mode during training # (We use inference mode for simplicity - only clean logits are used for CTC loss) @@ -748,7 +749,8 @@ class BrainToTextDecoder_Trainer: with self.accelerator.autocast(): features, n_time_steps = self.transform_data(features, n_time_steps, 'val') - adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) + # Ensure proper dtype handling for TPU mixed precision + adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) logits = self.model(features, day_indicies, None, False, 'inference') @@ -845,8 +847,8 @@ class BrainToTextDecoder_Trainer: # Apply data transformations (no augmentation for inference) features, n_time_steps = self.transform_data(features, n_time_steps, 'val') - # Calculate adjusted sequence lengths for CTC - adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) + # Calculate adjusted sequence lengths for CTC with proper dtype handling + adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) # Get phoneme predictions logits = self.model(features, day_indicies, None, False, mode)