Author: Zchen
Date:   2025-10-14 23:54:53 +08:00
Parent: 4b6d680283
Commit: ec4f6a25ef


@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from typing import cast
 
 class GradientReversalFn(torch.autograd.Function):
     """
@@ -458,9 +459,13 @@ class TripleGRUDecoder(nn.Module):
         # 2. For residual connection, we need x in the same space as noise_output
         # Apply the same preprocessing that the models use internally
         x_processed = self._apply_preprocessing(x, day_idx)
+        clean_dtype = next(self.clean_speech_model.parameters()).dtype
+        if x_processed.dtype != clean_dtype:
+            x_processed = x_processed.to(clean_dtype)
 
         # Ensure dtype consistency between processed input and noise output
-        noise_output = noise_output.to(x_processed.dtype)
+        if noise_output.dtype != clean_dtype:
+            noise_output = noise_output.to(clean_dtype)
 
         # 3. Clean speech model processes denoised signal
         denoised_input = x_processed - noise_output  # Residual connection in processed space
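
This hunk replaces the unconditional noise_output.to(x_processed.dtype) with guarded casts toward clean_dtype, the dtype of the clean model's own parameters; the same pattern appears again in the hunk starting at line 499. A standalone sketch of the situation this targets, assuming a float32 clean model whose upstream noise estimate arrives in float16 (as can happen under torch.autocast); the names mirror the diff but the tensors are made up.

import torch

x_processed = torch.randn(2, 10, dtype=torch.float32)   # stand-in for the preprocessed input
noise_output = torch.randn(2, 10, dtype=torch.float16)  # stand-in for an autocast-produced noise estimate
clean_dtype = torch.float32  # in the diff: next(self.clean_speech_model.parameters()).dtype

# Guarded casts: the .to() call is skipped entirely when the dtypes already agree
if x_processed.dtype != clean_dtype:
    x_processed = x_processed.to(clean_dtype)
if noise_output.dtype != clean_dtype:
    noise_output = noise_output.to(clean_dtype)

denoised_input = x_processed - noise_output
print(denoised_input.dtype)  # torch.float32, matching the clean model's weights

The residual subtraction and the clean model's input thus stay in the parameter dtype explicitly, rather than relying on implicit type promotion inside the subtraction.
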
@@ -473,9 +478,10 @@ class TripleGRUDecoder(nn.Module):
         # 4. Noisy speech model processes noise signal directly (no day layers needed)
         # Optionally apply Gradient Reversal to enforce adversarial training on noise output
         noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
-        # Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility
-        # Use x_processed.dtype as reference since it's the main data flow dtype
-        noisy_input = noisy_input.to(x_processed.dtype)
+        noisy_input = cast(torch.Tensor, noisy_input)
+        noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
+        if noisy_input.dtype != noisy_dtype:
+            noisy_input = noisy_input.to(noisy_dtype)
 
         noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
                                                                 states['noisy'] if states else None)
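
Two things change in this hunk beyond the dtype guard. First, the result of gradient_reverse is wrapped in typing.cast because torch.autograd.Function.apply is not precisely typed, so a static checker can see the ternary's result as Any; cast() only re-labels the value for the checker and has no runtime effect. Second, the reference dtype is now taken from the noisy model's own parameters rather than from x_processed. A small sketch with hypothetical stand-ins:

from typing import Any, cast
import torch

def loosely_typed_reverse(x: torch.Tensor) -> Any:
    # stand-in for gradient_reverse(); its Function.apply call is what loses the precise type
    return x

noise_output = torch.randn(4, 8, dtype=torch.float16)
noisy_input = cast(torch.Tensor, loosely_typed_reverse(noise_output))
assert noisy_input is noise_output  # cast() returned the very same object

# Align to the consuming model's parameter dtype, not to the other branch's input
noisy_model = torch.nn.Linear(8, 8)                 # stand-in for self.noisy_speech_model
noisy_dtype = next(noisy_model.parameters()).dtype  # torch.float32 by default
if noisy_input.dtype != noisy_dtype:
    noisy_input = noisy_input.to(noisy_dtype)
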
@@ -493,9 +499,13 @@ class TripleGRUDecoder(nn.Module):
         # 2. For residual connection, we need x in the same space as noise_output
         x_processed = self._apply_preprocessing(x, day_idx)
+        clean_dtype = next(self.clean_speech_model.parameters()).dtype
+        if x_processed.dtype != clean_dtype:
+            x_processed = x_processed.to(clean_dtype)
 
         # Ensure dtype consistency for mixed precision residual connection
-        noise_output = noise_output.to(x_processed.dtype)
+        if noise_output.dtype != clean_dtype:
+            noise_output = noise_output.to(clean_dtype)
 
         denoised_input = x_processed - noise_output
         clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                 states['clean'] if states else None)
@@ -514,10 +524,6 @@ class TripleGRUDecoder(nn.Module):
             clean_grad (tensor) - gradients from clean speech model output layer
             noisy_grad (tensor) - gradients from noisy speech model output layer
-        if grl_lambda and grl_lambda != 0.0:
-            noisy_input = gradient_reverse(noise_output, grl_lambda)
-        else:
-            noisy_input = noise_output
         '''
 
         # Combine gradients: negative from clean model, positive from noisy model
         combined_grad = -clean_grad + noisy_grad
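
The four deleted lines appear to be a stray copy of the GRL branch that was sitting inside the docstring above the closing '''; the same logic lives as a one-liner in the hunk starting at line 478. The surviving combination step keeps the sign convention spelled out in the comment: the gradient arriving from the clean branch is negated, the one from the noisy branch is kept as-is, and their sum is the combined gradient. A toy sketch with hypothetical scalar losses standing in for the two branches:

import torch

noise_estimate = torch.randn(3, 5, requires_grad=True)  # stand-in for a shared noise output
clean_loss = (noise_estimate * 2.0).sum()               # stand-in for the clean branch's loss
noisy_loss = (noise_estimate ** 2).sum()                # stand-in for the noisy branch's loss

clean_grad, = torch.autograd.grad(clean_loss, noise_estimate)
noisy_grad, = torch.autograd.grad(noisy_loss, noise_estimate)

# Same sign convention as the code above: negative from the clean model, positive from the noisy model
combined_grad = -clean_grad + noisy_grad
print(combined_grad.shape)  # torch.Size([3, 5])
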