tpu
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import cast
|
||||
|
||||
class GradientReversalFn(torch.autograd.Function):
|
||||
"""
|
||||
@@ -106,9 +107,15 @@ class NoiseModel(nn.Module):
|
||||
# Ensure dtype consistency after patch processing operations
|
||||
x = x.to(original_dtype)
|
||||
|
||||
gru_dtype = next(self.gru.parameters()).dtype
|
||||
if x.dtype != gru_dtype:
|
||||
x = x.to(gru_dtype)
|
||||
|
||||
# XLA-friendly hidden state initialization - avoid dynamic allocation
|
||||
if states is None:
|
||||
states = self.h0.expand(2, batch_size, self.input_size).contiguous()
|
||||
if states.dtype != gru_dtype:
|
||||
states = states.to(gru_dtype)
|
||||
|
||||
# GRU forward pass
|
||||
output, hidden_states = self.gru(x, states)
|
||||
@@ -208,9 +215,15 @@ class CleanSpeechModel(nn.Module):
|
||||
# Ensure dtype consistency after patch processing operations
|
||||
x = x.to(original_dtype)
|
||||
|
||||
gru_dtype = next(self.gru.parameters()).dtype
|
||||
if x.dtype != gru_dtype:
|
||||
x = x.to(gru_dtype)
|
||||
|
||||
# XLA-friendly hidden state initialization
|
||||
if states is None:
|
||||
states = self.h0.expand(3, batch_size, self.n_units).contiguous()
|
||||
if states.dtype != gru_dtype:
|
||||
states = states.to(gru_dtype)
|
||||
|
||||
# GRU forward pass
|
||||
output, hidden_states = self.gru(x, states)
|
||||
@@ -280,9 +293,21 @@ class NoisySpeechModel(nn.Module):
|
||||
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
|
||||
batch_size = x.size(0)
|
||||
|
||||
gru_dtype = next(self.gru.parameters()).dtype
|
||||
if x.dtype != gru_dtype:
|
||||
x = x.to(gru_dtype)
|
||||
|
||||
gru_dtype = next(self.gru.parameters()).dtype
|
||||
if x.dtype != gru_dtype:
|
||||
x = x.to(gru_dtype)
|
||||
|
||||
# XLA-friendly hidden state initialization
|
||||
if states is None:
|
||||
states = self.h0.expand(2, batch_size, self.n_units).contiguous()
|
||||
if states.dtype != gru_dtype:
|
||||
states = states.to(gru_dtype)
|
||||
if states.dtype != gru_dtype:
|
||||
states = states.to(gru_dtype)
|
||||
|
||||
# GRU forward pass
|
||||
output, hidden_states = self.gru(x, states)
|
||||
@@ -407,11 +432,16 @@ class TripleGRUDecoder(nn.Module):
|
||||
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
|
||||
batch_size = x_processed.size(0)
|
||||
|
||||
clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
|
||||
if x_processed.dtype != clean_gru_dtype:
|
||||
x_processed = x_processed.to(clean_gru_dtype)
|
||||
|
||||
# XLA-friendly hidden state initialization with dtype consistency
|
||||
if states is None:
|
||||
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
|
||||
# Ensure hidden states match input dtype for mixed precision training
|
||||
states = states.to(x_processed.dtype)
|
||||
if states.dtype != clean_gru_dtype:
|
||||
states = states.to(clean_gru_dtype)
|
||||
|
||||
# GRU forward pass (skip preprocessing since input is already processed)
|
||||
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
||||
@@ -424,11 +454,16 @@ class TripleGRUDecoder(nn.Module):
|
||||
'''Forward pass for NoisySpeechModel with already processed input'''
|
||||
batch_size = x_processed.size(0)
|
||||
|
||||
noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
|
||||
if x_processed.dtype != noisy_gru_dtype:
|
||||
x_processed = x_processed.to(noisy_gru_dtype)
|
||||
|
||||
# XLA-friendly hidden state initialization with dtype consistency
|
||||
if states is None:
|
||||
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
|
||||
# Ensure hidden states match input dtype for mixed precision training
|
||||
states = states.to(x_processed.dtype)
|
||||
if states.dtype != noisy_gru_dtype:
|
||||
states = states.to(noisy_gru_dtype)
|
||||
|
||||
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
||||
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
||||
@@ -458,9 +493,13 @@ class TripleGRUDecoder(nn.Module):
|
||||
# 2. For residual connection, we need x in the same space as noise_output
|
||||
# Apply the same preprocessing that the models use internally
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
clean_dtype = next(self.clean_speech_model.parameters()).dtype
|
||||
if x_processed.dtype != clean_dtype:
|
||||
x_processed = x_processed.to(clean_dtype)
|
||||
|
||||
# Ensure dtype consistency between processed input and noise output
|
||||
noise_output = noise_output.to(x_processed.dtype)
|
||||
if noise_output.dtype != clean_dtype:
|
||||
noise_output = noise_output.to(clean_dtype)
|
||||
|
||||
# 3. Clean speech model processes denoised signal
|
||||
denoised_input = x_processed - noise_output # Residual connection in processed space
|
||||
@@ -473,9 +512,10 @@ class TripleGRUDecoder(nn.Module):
|
||||
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
||||
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
|
||||
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
|
||||
# Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility
|
||||
# Use x_processed.dtype as reference since it's the main data flow dtype
|
||||
noisy_input = noisy_input.to(x_processed.dtype)
|
||||
noisy_input = cast(torch.Tensor, noisy_input)
|
||||
noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
|
||||
if noisy_input.dtype != noisy_dtype:
|
||||
noisy_input = noisy_input.to(noisy_dtype)
|
||||
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
|
||||
states['noisy'] if states else None)
|
||||
|
||||
@@ -493,9 +533,13 @@ class TripleGRUDecoder(nn.Module):
|
||||
|
||||
# 2. For residual connection, we need x in the same space as noise_output
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
clean_dtype = next(self.clean_speech_model.parameters()).dtype
|
||||
if x_processed.dtype != clean_dtype:
|
||||
x_processed = x_processed.to(clean_dtype)
|
||||
|
||||
# Ensure dtype consistency for mixed precision residual connection
|
||||
noise_output = noise_output.to(x_processed.dtype)
|
||||
if noise_output.dtype != clean_dtype:
|
||||
noise_output = noise_output.to(clean_dtype)
|
||||
denoised_input = x_processed - noise_output
|
||||
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||
states['clean'] if states else None)
|
||||
@@ -514,10 +558,6 @@ class TripleGRUDecoder(nn.Module):
|
||||
|
||||
clean_grad (tensor) - gradients from clean speech model output layer
|
||||
noisy_grad (tensor) - gradients from noisy speech model output layer
|
||||
if grl_lambda and grl_lambda != 0.0:
|
||||
noisy_input = gradient_reverse(noise_output, grl_lambda)
|
||||
else:
|
||||
noisy_input = noise_output
|
||||
'''
|
||||
# Combine gradients: negative from clean model, positive from noisy model
|
||||
combined_grad = -clean_grad + noisy_grad
|
||||
|
Reference in New Issue
Block a user