Files
b2txt25/model_training_nnn/rnn_model.py

579 lines
24 KiB
Python
Raw Normal View History

2025-10-12 09:11:32 +08:00
import torch
from torch import nn
2025-10-14 23:54:53 +08:00
from typing import cast
2025-10-12 09:11:32 +08:00
2025-10-14 22:48:28 +08:00
class GradientReversalFn(torch.autograd.Function):
    """
    Gradient Reversal Layer (GRL).

    Forward:  returns the input unchanged (as a view of itself).
    Backward: scales the incoming gradient by ``-lambd``.
    """

    @staticmethod
    def forward(ctx, x, lambd: float):
        # Keep the scale on the context so backward() can read it back.
        ctx.lambd = lambd
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -ctx.lambd * grad_output
        # One entry per forward() input: x gets the reversed gradient,
        # lambd gets no gradient.
        return reversed_grad, None
def gradient_reverse(x, lambd: float = 1.0):
    """Run ``x`` through ``GradientReversalFn`` with scale ``lambd`` and return the result."""
    reversed_x = GradientReversalFn.apply(x, lambd)
    return reversed_x
2025-10-12 09:11:32 +08:00
class NoiseModel(nn.Module):
    '''
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data

    The input goes through a per-day affine layer (identity-initialised weight,
    zero-initialised bias), a Softsign activation and optional dropout. It is
    then optionally grouped into patches of consecutive timesteps and fed to a
    2-layer GRU whose hidden size equals its input size, so the output has the
    same feature width as the (patched) input.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim    (int)   - number of features in a single timestep
        n_units       (int)   - stored only; the GRU width is input_size, not n_units
        n_days        (int)   - number of day-specific input layers to create
        rnn_dropout   (float) - dropout between GRU layers
        input_dropout (float) - dropout applied after the day-specific layer
        patch_size    (int)   - timesteps concatenated per patch (0 disables patching)
        patch_stride  (int)   - timesteps to advance between consecutive patches

        Raises ValueError if patching is enabled with a non-positive stride.
        '''
        super(NoiseModel, self).__init__()

        # Fail early with a clear message; otherwise the bad stride only
        # surfaces as an error inside Tensor.unfold during forward().
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be > 0 when patch_size > 0 (got patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = nn.Softsign()
        # Let Accelerator handle dtype automatically for TPU compatibility
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.input_size,  # Output same dimension as input
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Learnable initial hidden state - let Accelerator handle dtype
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))

    def _extract_patches(self, x):
        '''
        Group consecutive timesteps into flattened patches.

        x is (batch, time, features). Returns (batch, n_patches,
        patch_size * features), where each patch lists its timesteps in order
        and, within a timestep, its features in order.
        '''
        windows = x.unfold(1, self.patch_size, self.patch_stride)  # (batch, n_patches, features, patch_size)
        windows = windows.permute(0, 1, 3, 2)                      # (batch, n_patches, patch_size, features)
        return windows.reshape(x.size(0), windows.size(1), -1)

    def forward(self, x, day_idx, states=None):
        '''
        x       (tensor) - (batch_size, time, neural_dim)
        day_idx (tensor) - index tensor with one day index per example (used with index_select)
        states  (tensor) - optional initial GRU hidden state; the learned h0 is used when None

        Returns (output, hidden_states) exactly as produced by the GRU.
        '''
        batch_size = x.size(0)

        # Stack all day parameters once and gather per-example rows with
        # index_select (static shapes, XLA-friendly) instead of Python indexing.
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.cat(list(self.day_biases), dim=0)      # [n_days, neural_dim]
        day_weights = torch.index_select(all_day_weights, 0, day_idx)            # [batch_size, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]

        # Cast the day parameters to the input dtype for mixed precision training
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x = self.day_layer_activation(x)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._extract_patches(x)

        # The GRU requires input and hidden state in its own parameter dtype.
        # Tensor.to() returns the tensor unchanged when the dtype already matches.
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            # Broadcast the single learned state to both layers and the whole batch.
            states = self.h0.expand(2, batch_size, self.input_size).contiguous()
        states = states.to(gru_dtype)

        # GRU forward pass
        output, hidden_states = self.gru(x, states)
        return output, hidden_states
class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition

    The input goes through a per-day affine layer (identity-initialised weight,
    zero-initialised bias), a Softsign activation and optional dropout. It is
    then optionally grouped into patches of consecutive timesteps, fed to a
    3-layer GRU with n_units hidden units, and projected to n_classes logits
    per output step.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim    (int)   - number of features in a single timestep
        n_units       (int)   - number of hidden units in each GRU layer
        n_days        (int)   - number of day-specific input layers to create
        n_classes     (int)   - size of the output classification layer
        rnn_dropout   (float) - dropout between GRU layers
        input_dropout (float) - dropout applied after the day-specific layer
        patch_size    (int)   - timesteps concatenated per patch (0 disables patching)
        patch_stride  (int)   - timesteps to advance between consecutive patches

        Raises ValueError if patching is enabled with a non-positive stride.
        '''
        super(CleanSpeechModel, self).__init__()

        # Fail early with a clear message; otherwise the bad stride only
        # surfaces as an error inside Tensor.unfold during forward().
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be > 0 when patch_size > 0 (got patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = nn.Softsign()
        # Let Accelerator handle dtype automatically for TPU compatibility
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=3,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def _extract_patches(self, x):
        '''
        Group consecutive timesteps into flattened patches.

        x is (batch, time, features). Returns (batch, n_patches,
        patch_size * features), where each patch lists its timesteps in order
        and, within a timestep, its features in order.
        '''
        windows = x.unfold(1, self.patch_size, self.patch_stride)  # (batch, n_patches, features, patch_size)
        windows = windows.permute(0, 1, 3, 2)                      # (batch, n_patches, patch_size, features)
        return windows.reshape(x.size(0), windows.size(1), -1)

    def forward(self, x, day_idx, states=None, return_state=False):
        '''
        x            (tensor) - (batch_size, time, neural_dim)
        day_idx      (tensor) - index tensor with one day index per example (used with index_select)
        states       (tensor) - optional initial GRU hidden state; the learned h0 is used when None
        return_state (bool)   - when True, also return the final GRU hidden states

        Returns logits, or (logits, hidden_states) when return_state is True.
        '''
        batch_size = x.size(0)

        # Stack all day parameters once and gather per-example rows with
        # index_select (static shapes, XLA-friendly) instead of Python indexing.
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.cat(list(self.day_biases), dim=0)      # [n_days, neural_dim]
        day_weights = torch.index_select(all_day_weights, 0, day_idx)            # [batch_size, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]

        # Cast the day parameters to the input dtype for mixed precision training
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x = self.day_layer_activation(x)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._extract_patches(x)

        # The GRU requires input and hidden state in its own parameter dtype.
        # Tensor.to() returns the tensor unchanged when the dtype already matches.
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            # Broadcast the single learned state to all three layers and the whole batch.
            states = self.h0.expand(3, batch_size, self.n_units).contiguous()
        states = states.to(gru_dtype)

        # GRU forward pass
        output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
class NoisySpeechModel(nn.Module):
    '''
    Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition

    Unlike the other models in this file it has no day-specific input layer and
    does no patching of its own: forward() feeds its input straight to the GRU,
    so the input's last dimension must already equal input_size
    (neural_dim, multiplied by patch_size when patch_size > 0).
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim    (int)   - number of features in a single (unpatched) timestep
        n_units       (int)   - number of hidden units in each GRU layer
        n_days        (int)   - stored only; not used by this model
        n_classes     (int)   - size of the output classification layer
        rnn_dropout   (float) - dropout between GRU layers
        input_dropout (float) - stored only; not used by this model
        patch_size    (int)   - only used to compute the expected GRU input width
        patch_stride  (int)   - stored only; not used by this model
        '''
        super(NoisySpeechModel, self).__init__()
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noisy speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, states=None, return_state=False):
        '''
        x            (tensor) - (batch_size, time, input_size), already in the GRU's input space
        states       (tensor) - optional initial GRU hidden state; the learned h0 is used when None
        return_state (bool)   - when True, also return the final GRU hidden states

        Returns logits, or (logits, hidden_states) when return_state is True.
        '''
        # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
        batch_size = x.size(0)

        # The GRU requires input and hidden state in its own parameter dtype.
        # Tensor.to() returns the tensor unchanged when the dtype already matches,
        # so a single unconditional cast each is enough.
        gru_dtype = next(self.gru.parameters()).dtype
        x = x.to(gru_dtype)

        if states is None:
            # Broadcast the single learned state to both layers and the whole batch.
            states = self.h0.expand(2, batch_size, self.n_units).contiguous()
        states = states.to(gru_dtype)

        # GRU forward pass
        output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding

    Combines:
    - NoiseModel: estimates noise in neural data
    - CleanSpeechModel: processes denoised signal for recognition
    - NoisySpeechModel: processes noise signal for recognition
    '''

    _VALID_MODES = ('full', 'inference')

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim  (int)   - number of channels in a single timestep (e.g. 512)
        n_units     (int)   - number of hidden units in each recurrent layer
        n_days      (int)   - number of days in the dataset
        n_classes   (int)   - number of classes (phonemes)
        rnn_dropout (float) - percentage of units to dropout during training
        input_dropout (float) - percentage of input units to dropout during training
        patch_size  (int)   - number of timesteps to concat on initial input layer
        patch_stride(int)   - number of timesteps to stride over when concatenating initial input

        Raises ValueError if patching is enabled with a non-positive stride.
        '''
        super(TripleGRUDecoder, self).__init__()

        # Fail early with a clear message; otherwise the bad stride only
        # surfaces as an error inside Tensor.unfold during forward().
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be > 0 when patch_size > 0 (got patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three models
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )
        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )
        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Training mode flag: 'full' or 'inference'.
        # NOTE(review): forward() selects its branch from its own `mode`
        # argument and never reads this attribute.
        self.training_mode = 'full'

    def _apply_preprocessing(self, x, day_idx):
        '''
        XLA-friendly preprocessing with static operations.

        Applies the clean speech model's day-specific affine layer and Softsign
        to x, then groups timesteps into patches when patch_size > 0. No input
        dropout is applied here.

        NOTE(review): this uses the clean model's day layers, whereas the noise
        model's output (which is subtracted from this result in forward()) is
        produced with the noise model's own day layers and input dropout —
        confirm that is intended.
        '''
        clean = self.clean_speech_model

        # Stack all day parameters once and gather per-example rows with
        # index_select (static shapes) instead of Python indexing.
        all_day_weights = torch.stack(list(clean.day_weights), dim=0)
        all_day_biases = torch.cat(list(clean.day_biases), dim=0)
        day_weights = torch.index_select(all_day_weights, 0, day_idx)
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

        # Cast the day parameters to the input dtype for mixed precision training
        x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x_processed = clean.day_layer_activation(x_processed)

        if self.patch_size > 0:
            # (batch, time, features) -> (batch, n_patches, patch_size * features)
            windows = x_processed.unfold(1, self.patch_size, self.patch_stride)
            windows = windows.permute(0, 1, 3, 2)
            x_processed = windows.reshape(x.size(0), windows.size(1), -1)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
        '''
        Forward pass for CleanSpeechModel with already processed input
        (bypasses day layers and patching).

        day_idx is accepted for signature compatibility but is not used.
        Returns the logits only; the GRU's final hidden states are discarded.
        '''
        clean = self.clean_speech_model
        batch_size = x_processed.size(0)

        # The GRU requires input and hidden state in its own parameter dtype.
        clean_gru_dtype = next(clean.gru.parameters()).dtype
        x_processed = x_processed.to(clean_gru_dtype)

        if states is None:
            states = clean.h0.expand(3, batch_size, clean.n_units).contiguous()
        states = states.to(clean_gru_dtype)

        # GRU forward pass (skip preprocessing since input is already processed)
        output, _ = clean.gru(x_processed, states)

        # Classification
        return clean.out(output)

    def _noisy_forward_with_processed_input(self, x_processed, states=None):
        '''
        Forward pass for NoisySpeechModel with already processed input.

        Returns the logits only; the GRU's final hidden states are discarded.
        '''
        noisy = self.noisy_speech_model
        batch_size = x_processed.size(0)

        # The GRU requires input and hidden state in its own parameter dtype.
        noisy_gru_dtype = next(noisy.gru.parameters()).dtype
        x_processed = x_processed.to(noisy_gru_dtype)

        if states is None:
            states = noisy.h0.expand(2, batch_size, noisy.n_units).contiguous()
        states = states.to(noisy_gru_dtype)

        # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
        output, _ = noisy.gru(x_processed, states)

        # Classification
        return noisy.out(output)

    def _denoise(self, x, day_idx, states):
        '''
        Shared first stage of both modes: estimate the noise and subtract it
        from the preprocessed input.

        Returns (noise_output, noise_hidden, denoised_input), where noise_output
        has been cast to the clean speech model's parameter dtype.
        '''
        # 1. Noise model estimates noise in the data
        noise_output, noise_hidden = self.noise_model(
            x, day_idx, states['noise'] if states else None
        )

        # 2. For the residual connection, x must be in the same space as
        #    noise_output, so apply the preprocessing the models use internally.
        clean_dtype = next(self.clean_speech_model.parameters()).dtype
        x_processed = self._apply_preprocessing(x, day_idx).to(clean_dtype)
        noise_output = noise_output.to(clean_dtype)

        # 3. Residual connection in processed space
        denoised_input = x_processed - noise_output
        return noise_output, noise_hidden, denoised_input

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
        '''
        Three-model adversarial forward pass

        x          (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx    (tensor) - tensor of day indices for each example in the batch
        states     (dict)   - dictionary with 'noise', 'clean', 'noisy' states or None
        return_state (bool) - when True, additionally return the noise model's hidden states
                              (the clean/noisy GRU hidden states are not returned)
        mode       (str)    - 'full' for training (all three models), 'inference' for inference (noise + clean only)
        grl_lambda (float)  - when non-zero and mode='full', applies Gradient Reversal to the
                              noise output before it is fed to the noisy speech model

        Returns
        - mode='inference': clean_logits, or (clean_logits, noise_hidden) if return_state
        - mode='full': (clean_logits, noisy_logits, noise_output), or
          ((clean_logits, noisy_logits, noise_output), noise_hidden) if return_state

        Raises ValueError for any other mode.
        '''
        # Reject unknown modes before doing any computation.
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        noise_output, noise_hidden, denoised_input = self._denoise(x, day_idx, states)

        # Clean speech model processes the denoised signal; its own day layers
        # and patching are bypassed because the input is already processed.
        clean_logits = self._clean_forward_with_processed_input(
            denoised_input, day_idx, states['clean'] if states else None
        )

        if mode == 'inference':
            # Inference mode: only noise model + clean speech model
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        # Training mode: the noisy speech model processes the noise signal
        # directly (no day layers needed). Optionally apply Gradient Reversal
        # to enforce adversarial training on the noise output.
        noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
        noisy_input = cast(torch.Tensor, noisy_input)
        noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
        noisy_input = noisy_input.to(noisy_dtype)

        noisy_logits = self._noisy_forward_with_processed_input(
            noisy_input, states['noisy'] if states else None
        )

        # XLA-friendly return - use tuple instead of dict for better compilation
        if return_state:
            return (clean_logits, noisy_logits, noise_output), noise_hidden
        return clean_logits, noisy_logits, noise_output

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Apply combined gradients to noise model parameters

        clean_grad    (tensor) - gradients from clean speech model output layer
        noisy_grad    (tensor) - gradients from noisy speech model output layer
        learning_rate (float)  - scale applied to the combined gradient

        This is a simplified placeholder: every noise-model parameter that
        currently has a .grad is shifted by the same scalar,
        learning_rate * mean(-clean_grad + noisy_grad). It is not a proper
        mapping of the gradients onto the parameters.
        '''
        with torch.no_grad():
            # Combine gradients: negative from clean model, positive from noisy model.
            # The step is the same scalar for every parameter, so compute it once.
            combined_grad = -clean_grad + noisy_grad
            step = learning_rate * combined_grad.mean()

            for param in self.noise_model.parameters():
                if param.grad is not None:
                    param.sub_(step)

    def set_mode(self, mode):
        '''Set the operating mode (stored in training_mode; not validated)'''
        self.training_mode = mode