diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py
index bb6798a..8cc0219 100644
--- a/model_training_nnn/rnn_model.py
+++ b/model_training_nnn/rnn_model.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from typing import cast
 
 class GradientReversalFn(torch.autograd.Function):
     """
@@ -106,9 +107,15 @@ class NoiseModel(nn.Module):
         # Ensure dtype consistency after patch processing operations
         x = x.to(original_dtype)
 
+        gru_dtype = next(self.gru.parameters()).dtype
+        if x.dtype != gru_dtype:
+            x = x.to(gru_dtype)
+
         # XLA-friendly hidden state initialization - avoid dynamic allocation
         if states is None:
             states = self.h0.expand(2, batch_size, self.input_size).contiguous()
+            if states.dtype != gru_dtype:
+                states = states.to(gru_dtype)
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
@@ -208,9 +215,15 @@ class CleanSpeechModel(nn.Module):
         # Ensure dtype consistency after patch processing operations
         x = x.to(original_dtype)
 
+        gru_dtype = next(self.gru.parameters()).dtype
+        if x.dtype != gru_dtype:
+            x = x.to(gru_dtype)
+
         # XLA-friendly hidden state initialization
         if states is None:
             states = self.h0.expand(3, batch_size, self.n_units).contiguous()
+            if states.dtype != gru_dtype:
+                states = states.to(gru_dtype)
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
@@ -280,9 +293,15 @@ class NoisySpeechModel(nn.Module):
         # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
         batch_size = x.size(0)
 
+        gru_dtype = next(self.gru.parameters()).dtype
+        if x.dtype != gru_dtype:
+            x = x.to(gru_dtype)
+
         # XLA-friendly hidden state initialization
         if states is None:
             states = self.h0.expand(2, batch_size, self.n_units).contiguous()
+            if states.dtype != gru_dtype:
+                states = states.to(gru_dtype)
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
@@ -407,11 +426,16 @@ class TripleGRUDecoder(nn.Module):
         '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
         batch_size = x_processed.size(0)
 
+        clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
+        if x_processed.dtype != clean_gru_dtype:
+            x_processed = x_processed.to(clean_gru_dtype)
+
         # XLA-friendly hidden state initialization with dtype consistency
         if states is None:
             states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
             # Ensure hidden states match input dtype for mixed precision training
-            states = states.to(x_processed.dtype)
+            if states.dtype != clean_gru_dtype:
+                states = states.to(clean_gru_dtype)
 
         # GRU forward pass (skip preprocessing since input is already processed)
         output, hidden_states = self.clean_speech_model.gru(x_processed, states)
@@ -424,11 +448,16 @@ class TripleGRUDecoder(nn.Module):
         '''Forward pass for NoisySpeechModel with already processed input'''
         batch_size = x_processed.size(0)
 
+        noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
+        if x_processed.dtype != noisy_gru_dtype:
+            x_processed = x_processed.to(noisy_gru_dtype)
+
         # XLA-friendly hidden state initialization with dtype consistency
         if states is None:
             states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
             # Ensure hidden states match input dtype for mixed precision training
-            states = states.to(x_processed.dtype)
+            if states.dtype != noisy_gru_dtype:
+                states = states.to(noisy_gru_dtype)
 
         # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
         output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -458,9 +487,13 @@ class TripleGRUDecoder(nn.Module):
         # 2. For residual connection, we need x in the same space as noise_output
         # Apply the same preprocessing that the models use internally
         x_processed = self._apply_preprocessing(x, day_idx)
+        clean_dtype = next(self.clean_speech_model.parameters()).dtype
+        if x_processed.dtype != clean_dtype:
+            x_processed = x_processed.to(clean_dtype)
 
         # Ensure dtype consistency between processed input and noise output
-        noise_output = noise_output.to(x_processed.dtype)
+        if noise_output.dtype != clean_dtype:
+            noise_output = noise_output.to(clean_dtype)
 
         # 3. Clean speech model processes denoised signal
         denoised_input = x_processed - noise_output  # Residual connection in processed space
@@ -473,9 +506,10 @@ class TripleGRUDecoder(nn.Module):
         # 4. Noisy speech model processes noise signal directly (no day layers needed)
         # Optionally apply Gradient Reversal to enforce adversarial training on noise output
         noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
-        # Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility
-        # Use x_processed.dtype as reference since it's the main data flow dtype
-        noisy_input = noisy_input.to(x_processed.dtype)
+        noisy_input = cast(torch.Tensor, noisy_input)
+        noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
+        if noisy_input.dtype != noisy_dtype:
+            noisy_input = noisy_input.to(noisy_dtype)
 
         noisy_logits = self._noisy_forward_with_processed_input(noisy_input, states['noisy'] if states else None)
 
@@ -493,9 +527,13 @@ class TripleGRUDecoder(nn.Module):
         # 2. For residual connection, we need x in the same space as noise_output
         x_processed = self._apply_preprocessing(x, day_idx)
+        clean_dtype = next(self.clean_speech_model.parameters()).dtype
+        if x_processed.dtype != clean_dtype:
+            x_processed = x_processed.to(clean_dtype)
 
         # Ensure dtype consistency for mixed precision residual connection
-        noise_output = noise_output.to(x_processed.dtype)
+        if noise_output.dtype != clean_dtype:
+            noise_output = noise_output.to(clean_dtype)
 
         denoised_input = x_processed - noise_output
         clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, states['clean'] if states else None)
 
@@ -514,10 +552,6 @@ class TripleGRUDecoder(nn.Module):
             clean_grad (tensor) - gradients from clean speech model output layer
             noisy_grad (tensor) - gradients from noisy speech model output layer
 
-        if grl_lambda and grl_lambda != 0.0:
-            noisy_input = gradient_reverse(noise_output, grl_lambda)
-        else:
-            noisy_input = noise_output
         '''
         # Combine gradients: negative from clean model, positive from noisy model
         combined_grad = -clean_grad + noisy_grad
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index a5217fd..6cba21b 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -589,9 +589,9 @@ class BrainToTextDecoder_Trainer:
             # Get phoneme predictions using inference mode during training
            # (We use inference mode for simplicity - only clean logits are used for CTC loss)
             # Ensure features tensor matches model parameter dtype for TPU compatibility
-            if self.accelerator.mixed_precision == 'bf16':
-                # In mixed precision mode, ensure features match the expected precision
-                features = features.to(torch.float32)
+            model_param = next(self.model.parameters()) if self.model is not None else None
+            if model_param is not None and features.dtype != model_param.dtype:
+                features = features.to(model_param.dtype)
 
             # Forward pass: enable full adversarial mode if configured and past warmup
             use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
@@ -621,7 +621,7 @@ class BrainToTextDecoder_Trainer:
                 noisy_loss = torch.mean(noisy_loss)
 
             # Optional noise energy regularization
-            noise_l2 = torch.tensor(0.0, device=self.device)
+            noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
             if self.adv_noise_l2_weight > 0.0:
                 noise_l2 = torch.mean(noise_output.pow(2))
 
@@ -799,9 +799,9 @@ class BrainToTextDecoder_Trainer:
         adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
         # Ensure features tensor matches model parameter dtype for TPU compatibility
-        if self.accelerator.mixed_precision == 'bf16':
-            # In mixed precision mode, ensure features match the expected precision
-            features = features.to(torch.float32)
+        model_param = next(self.model.parameters()) if self.model is not None else None
+        if model_param is not None and features.dtype != model_param.dtype:
+            features = features.to(model_param.dtype)
 
         logits = self.model(features, day_indicies, None, False, 'inference')
 
@@ -878,9 +878,9 @@ class BrainToTextDecoder_Trainer:
             features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
             # Ensure features tensor matches model parameter dtype for TPU compatibility
-            if self.accelerator.mixed_precision == 'bf16':
-                # In mixed precision mode, ensure features match the expected precision
-                features = features.to(torch.float32)
+            model_param = next(self.model.parameters()) if self.model is not None else None
+            if model_param is not None and features.dtype != model_param.dtype:
+                features = features.to(model_param.dtype)
 
             # Get phoneme predictions
             logits = self.model(features, day_indicies, None, False, mode)
@@ -907,9 +907,9 @@ class BrainToTextDecoder_Trainer:
             adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
             # Ensure features tensor matches model parameter dtype for TPU compatibility
-            if self.accelerator.mixed_precision == 'bf16':
-                # In mixed precision mode, ensure features match the expected precision
-                features = features.to(torch.float32)
+            model_param = next(self.model.parameters()) if self.model is not None else None
+            if model_param is not None and features.dtype != model_param.dtype:
+                features = features.to(model_param.dtype)
 
             # Get phoneme predictions
             logits = self.model(features, day_indicies, None, False, mode)
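Note: every hunk above applies one pattern at each GRU entry point: read the owning module's parameter dtype once and cast inputs and hidden states only when they actually differ, rather than branching on `accelerator.mixed_precision`. Below is a minimal standalone sketch of that pattern; the helper name `align_to_module_dtype` and the toy shapes are illustrative only and do not exist in this repo.

    import torch
    from torch import nn

    def align_to_module_dtype(t: torch.Tensor, module: nn.Module) -> torch.Tensor:
        """Cast `t` to `module`'s parameter dtype, but only when they differ.

        The equality guard keeps a no-op cast out of the traced graph, which
        matters under XLA/TPU where every recorded op has a cost.
        """
        param = next(module.parameters(), None)  # None for parameter-less modules
        if param is not None and t.dtype != param.dtype:
            t = t.to(param.dtype)
        return t

    if __name__ == "__main__":
        # A bf16 GRU fed fp32 tensors: the mismatch these hunks guard against.
        gru = nn.GRU(input_size=512, hidden_size=768, num_layers=2,
                     batch_first=True).to(torch.bfloat16)
        x = torch.randn(4, 100, 512)   # float32 features, as loaded from disk
        h0 = torch.zeros(2, 4, 768)    # float32 initial hidden state

        x = align_to_module_dtype(x, gru)
        h0 = align_to_module_dtype(h0, gru)
        assert x.dtype == h0.dtype == torch.bfloat16

The trainer-side hunks follow the same logic: the removed `mixed_precision == 'bf16'` branch force-cast features to float32 regardless of the parameters' actual dtype, while the replacement derives the target dtype from `next(self.model.parameters())`, so features track whatever precision the model was actually prepared with.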