This commit is contained in:
Zchen
2025-10-14 23:35:42 +08:00
parent cd52ba51ba
commit 4b6d680283
2 changed files with 118 additions and 23 deletions

View File

@@ -407,9 +407,11 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
batch_size = x_processed.size(0)
# XLA-friendly hidden state initialization
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# GRU forward pass (skip preprocessing since input is already processed)
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
@@ -422,9 +424,11 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for NoisySpeechModel with already processed input'''
batch_size = x_processed.size(0)
# XLA-friendly hidden state initialization
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -455,9 +459,11 @@ class TripleGRUDecoder(nn.Module):
# Apply the same preprocessing that the models use internally
x_processed = self._apply_preprocessing(x, day_idx)
# Ensure dtype consistency between processed input and noise output
noise_output = noise_output.to(x_processed.dtype)
# 3. Clean speech model processes denoised signal
# Ensure dtype consistency for mixed precision training in residual connection
denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space
denoised_input = x_processed - noise_output # Residual connection in processed space
# The clean speech model normally applies its own preprocessing to its input.
# One option would be to reverse our preprocessing first and let the clean model redo it;
# instead, it is simpler to pass the residual directly to the clean model and bypass its preprocessing.
@@ -467,6 +473,9 @@ class TripleGRUDecoder(nn.Module):
# 4. Noisy speech model processes noise signal directly (no day layers needed)
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
# GradientReversalFn should preserve dtype, but cast explicitly anyway to guarantee compatibility.
# x_processed.dtype is used as the reference since it is the dtype of the main data flow.
noisy_input = noisy_input.to(x_processed.dtype)
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
states['noisy'] if states else None)
@@ -485,9 +494,9 @@ class TripleGRUDecoder(nn.Module):
# 2. For residual connection, we need x in the same space as noise_output
x_processed = self._apply_preprocessing(x, day_idx)
# 3. Process denoised signal
# Ensure dtype consistency for mixed precision training in residual connection
denoised_input = x_processed - noise_output.to(x_processed.dtype)
# Ensure dtype consistency for mixed precision residual connection
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None)
@@ -505,7 +514,10 @@ class TripleGRUDecoder(nn.Module):
clean_grad (tensor) - gradients from clean speech model output layer
noisy_grad (tensor) - gradients from noisy speech model output layer
learning_rate (float) - learning rate for gradient update
if grl_lambda and grl_lambda != 0.0:
noisy_input = gradient_reverse(noise_output, grl_lambda)
else:
noisy_input = noise_output
'''
# Combine gradients: negative from clean model, positive from noisy model
combined_grad = -clean_grad + noisy_grad