diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index bb6798a..f560caa 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -1,5 +1,6 @@ import torch from torch import nn +from typing import cast class GradientReversalFn(torch.autograd.Function): """ @@ -458,9 +459,13 @@ class TripleGRUDecoder(nn.Module): # 2. For residual connection, we need x in the same space as noise_output # Apply the same preprocessing that the models use internally x_processed = self._apply_preprocessing(x, day_idx) + clean_dtype = next(self.clean_speech_model.parameters()).dtype + if x_processed.dtype != clean_dtype: + x_processed = x_processed.to(clean_dtype) # Ensure dtype consistency between processed input and noise output - noise_output = noise_output.to(x_processed.dtype) + if noise_output.dtype != clean_dtype: + noise_output = noise_output.to(clean_dtype) # 3. Clean speech model processes denoised signal denoised_input = x_processed - noise_output # Residual connection in processed space @@ -473,9 +478,10 @@ class TripleGRUDecoder(nn.Module): # 4. Noisy speech model processes noise signal directly (no day layers needed) # Optionally apply Gradient Reversal to enforce adversarial training on noise output noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output - # Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility - # Use x_processed.dtype as reference since it's the main data flow dtype - noisy_input = noisy_input.to(x_processed.dtype) + noisy_input = cast(torch.Tensor, noisy_input) + noisy_dtype = next(self.noisy_speech_model.parameters()).dtype + if noisy_input.dtype != noisy_dtype: + noisy_input = noisy_input.to(noisy_dtype) noisy_logits = self._noisy_forward_with_processed_input(noisy_input, states['noisy'] if states else None) @@ -493,9 +499,13 @@ class TripleGRUDecoder(nn.Module): # 2. For residual connection, we need x in the same space as noise_output x_processed = self._apply_preprocessing(x, day_idx) + clean_dtype = next(self.clean_speech_model.parameters()).dtype + if x_processed.dtype != clean_dtype: + x_processed = x_processed.to(clean_dtype) # Ensure dtype consistency for mixed precision residual connection - noise_output = noise_output.to(x_processed.dtype) + if noise_output.dtype != clean_dtype: + noise_output = noise_output.to(clean_dtype) denoised_input = x_processed - noise_output clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, states['clean'] if states else None) @@ -514,10 +524,6 @@ class TripleGRUDecoder(nn.Module): clean_grad (tensor) - gradients from clean speech model output layer noisy_grad (tensor) - gradients from noisy speech model output layer - if grl_lambda and grl_lambda != 0.0: - noisy_input = gradient_reverse(noise_output, grl_lambda) - else: - noisy_input = noise_output ''' # Combine gradients: negative from clean model, positive from noisy model combined_grad = -clean_grad + noisy_grad