518 lines
21 KiB
Python
518 lines
21 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
class GradientReversalFn(torch.autograd.Function):
    """
    Gradient Reversal Layer (GRL) implemented as a custom autograd function.

    Forward:  returns the input unchanged (as a view).
    Backward: returns the incoming gradient scaled by ``-lambd``; the scale
              itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambd: float):
        # Stash the scale so backward() can use it.
        ctx.lambd = lambd
        # Hand back a view rather than ``x`` itself so the output is a
        # distinct tensor object produced by this Function.
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign (and scale) of the gradient flowing back into x.
        grad_x = grad_output * (-ctx.lambd)
        # Second return slot corresponds to ``lambd``, which is not differentiable.
        return grad_x, None
|
|
|
|
def gradient_reverse(x, lambd: float = 1.0):
    """
    Route ``x`` through the Gradient Reversal Layer.

    Args:
        x: tensor passed through unchanged in the forward pass.
        lambd: scale applied (with a sign flip) to the gradient in backward.

    Returns:
        The result of ``GradientReversalFn.apply(x, lambd)``.
    """
    reverse_fn = GradientReversalFn.apply
    return reverse_fn(x, lambd)
|
|
|
|
class NoiseModel(nn.Module):
    '''
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data.

    forward() applies, in order: a per-trial day-specific affine layer selected
    by ``day_idx``, Softsign, optional input dropout, optional temporal
    patching, and a 2-layer GRU whose hidden size equals its input size, so the
    output has the same feature width as the (patched) GRU input.

    Args:
        neural_dim (int): number of channels in a single timestep.
        n_units (int): stored as an attribute; not used to size any layer here
            (the GRU hidden size is ``input_size``).
        n_days (int): number of day-specific input layers.
        rnn_dropout (float): dropout between the two GRU layers.
        input_dropout (float): dropout applied after the day layer when > 0.
        patch_size (int): timesteps concatenated per patch; 0 disables patching.
        patch_stride (int): step between patch starts; must be > 0 when patching.

    Raises:
        ValueError: if ``patch_size > 0`` and ``patch_stride <= 0``.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(NoiseModel, self).__init__()

        # Fail fast: Tensor.unfold requires a positive step, so this
        # configuration would otherwise only fail on the first forward pass.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be > 0 when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers: identity weights / zero biases at init.
        # Parameters use the default dtype; forward() casts them to x.dtype.
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # GRU input width: one timestep of channels, or patch_size of them concatenated.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation; hidden size == input size so the
        # output can be subtracted from a same-width signal.
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.input_size,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Orthogonal recurrent weights, Xavier-uniform input weights; biases
        # keep nn.GRU's default initialisation.
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Learnable initial hidden state, broadcast over layers and batch in forward().
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))

    def _apply_day_layers(self, x, day_idx):
        '''Apply each trial's day-specific affine map followed by Softsign.'''
        # Stack per-day parameters so a single index_select can gather them.
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]

        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch, 1, neural_dim]

        # Cast parameters to the input dtype so mixed-precision inputs work.
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        return self.day_layer_activation(x)

    def _patchify(self, x, batch_size):
        '''Concatenate sliding windows of ``patch_size`` timesteps along the feature axis.'''
        x = x.unsqueeze(1)                                            # [B, 1, T, D]
        x = x.permute(0, 3, 1, 2)                                     # [B, D, 1, T]
        x_unfold = x.unfold(3, self.patch_size, self.patch_stride)    # [B, D, 1, P, patch_size]
        x_unfold = x_unfold.squeeze(2)                                # [B, D, P, patch_size]
        x_unfold = x_unfold.permute(0, 2, 3, 1)                       # [B, P, patch_size, D]
        return x_unfold.reshape(batch_size, x_unfold.size(1), -1)     # [B, P, patch_size * D]

    def forward(self, x, day_idx, states=None):
        '''
        Estimate noise for a batch of trials.

        Args:
            x: tensor of shape [batch, time, neural_dim] (rank 3 is required by torch.bmm).
            day_idx: 1-D integer index tensor of length ``batch`` selecting each trial's day layer.
            states: optional initial GRU hidden state of shape [2, batch, input_size];
                defaults to the learned ``h0`` broadcast over layers and batch.

        Returns:
            (output, hidden_states) from the GRU: output is [batch, time', input_size]
            (time' is the number of patches when patching is enabled) and
            hidden_states is [2, batch, input_size].
        '''
        batch_size = x.size(0)

        x = self._apply_day_layers(x, day_idx)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._patchify(x, batch_size)

        if states is None:
            states = self.h0.expand(2, batch_size, self.input_size).contiguous()

        output, hidden_states = self.gru(x, states)
        return output, hidden_states
|
|
|
|
|
|
class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: 3-layer GRU that processes the denoised signal for speech recognition.

    forward() applies, in order: a per-trial day-specific affine layer selected
    by ``day_idx``, Softsign, optional input dropout, optional temporal
    patching, a 3-layer GRU with ``n_units`` hidden units, and a linear
    classification head producing ``n_classes`` logits per output step.

    Args:
        neural_dim (int): number of channels in a single timestep.
        n_units (int): GRU hidden size.
        n_days (int): number of day-specific input layers.
        n_classes (int): number of output classes.
        rnn_dropout (float): dropout between the GRU layers.
        input_dropout (float): dropout applied after the day layer when > 0.
        patch_size (int): timesteps concatenated per patch; 0 disables patching.
        patch_stride (int): step between patch starts; must be > 0 when patching.

    Raises:
        ValueError: if ``patch_size > 0`` and ``patch_stride <= 0``.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(CleanSpeechModel, self).__init__()

        # Fail fast: Tensor.unfold requires a positive step, so this
        # configuration would otherwise only fail on the first forward pass.
        if patch_size > 0 and patch_stride <= 0:
            raise ValueError(
                f"patch_stride must be > 0 when patch_size > 0 "
                f"(got patch_size={patch_size}, patch_stride={patch_stride})"
            )

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers: identity weights / zero biases at init.
        # Parameters use the default dtype; forward() casts them to x.dtype.
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # GRU input width: one timestep of channels, or patch_size of them concatenated.
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition.
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=3,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Orthogonal recurrent weights, Xavier-uniform input weights; biases
        # keep nn.GRU's default initialisation.
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer.
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state, broadcast over layers and batch in forward().
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def _apply_day_layers(self, x, day_idx):
        '''Apply each trial's day-specific affine map followed by Softsign.'''
        # Stack per-day parameters so a single index_select can gather them.
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]

        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch, 1, neural_dim]

        # Cast parameters to the input dtype so mixed-precision inputs work.
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        return self.day_layer_activation(x)

    def _patchify(self, x, batch_size):
        '''Concatenate sliding windows of ``patch_size`` timesteps along the feature axis.'''
        x = x.unsqueeze(1)                                            # [B, 1, T, D]
        x = x.permute(0, 3, 1, 2)                                     # [B, D, 1, T]
        x_unfold = x.unfold(3, self.patch_size, self.patch_stride)    # [B, D, 1, P, patch_size]
        x_unfold = x_unfold.squeeze(2)                                # [B, D, P, patch_size]
        x_unfold = x_unfold.permute(0, 2, 3, 1)                       # [B, P, patch_size, D]
        return x_unfold.reshape(batch_size, x_unfold.size(1), -1)     # [B, P, patch_size * D]

    def forward(self, x, day_idx, states=None, return_state=False):
        '''
        Classify a batch of trials.

        Args:
            x: tensor of shape [batch, time, neural_dim] (rank 3 is required by torch.bmm).
            day_idx: 1-D integer index tensor of length ``batch`` selecting each trial's day layer.
            states: optional initial GRU hidden state of shape [3, batch, n_units];
                defaults to the learned ``h0`` broadcast over layers and batch.
            return_state (bool): if True, also return the final GRU hidden state.

        Returns:
            logits of shape [batch, time', n_classes] (time' is the number of
            patches when patching is enabled), or ``(logits, hidden_states)``
            when ``return_state`` is True.
        '''
        batch_size = x.size(0)

        x = self._apply_day_layers(x, day_idx)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = self._patchify(x, batch_size)

        if states is None:
            states = self.h0.expand(3, batch_size, self.n_units).contiguous()

        output, hidden_states = self.gru(x, states)
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
|
|
|
|
|
|
class NoisySpeechModel(nn.Module):
    '''
    Noisy Speech Model: 2-layer GRU that classifies the noise signal.

    Unlike its sibling models it has no day-specific input layers and performs
    no patching itself; its input is fed straight into the GRU, so it must
    already have ``input_size`` features.

    Args:
        neural_dim (int): channels per timestep (used only to size the GRU input).
        n_units (int): GRU hidden size.
        n_days (int): stored as an attribute; unused here.
        n_classes (int): number of output classes.
        rnn_dropout (float): dropout between the GRU layers.
        input_dropout (float): stored as an attribute; unused here.
        patch_size (int): when > 0, the GRU input width is neural_dim * patch_size.
        patch_stride (int): stored as an attribute; unused here.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(NoisySpeechModel, self).__init__()

        # Keep the full hyperparameter set as attributes, mirroring the sibling models.
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # GRU input width: a single timestep, or patch_size timesteps concatenated.
        self.input_size = neural_dim * patch_size if patch_size > 0 else neural_dim

        # 2-layer GRU for noisy speech recognition.
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Recurrent weights -> orthogonal, input weights -> Xavier uniform.
        for param_name, weight in self.gru.named_parameters():
            if "weight_hh" in param_name:
                nn.init.orthogonal_(weight)
            elif "weight_ih" in param_name:
                nn.init.xavier_uniform_(weight)

        # Linear classification head.
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state shared across layers and batch.
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, states=None, return_state=False):
        '''
        Classify a batch of (already-prepared) noise sequences.

        Args:
            x: GRU input, batch-first, whose last dimension is ``input_size``.
            states: optional initial hidden state of shape [2, batch, n_units];
                defaults to the learned ``h0`` broadcast over layers and batch.
            return_state (bool): if True, also return the final GRU hidden state.

        Returns:
            logits, or ``(logits, hidden_states)`` when ``return_state`` is True.
        '''
        if states is None:
            states = self.h0.expand(2, x.size(0), self.n_units).contiguous()

        seq_out, final_hidden = self.gru(x, states)
        class_scores = self.out(seq_out)

        return (class_scores, final_hidden) if return_state else class_scores
|
|
|
|
|
|
class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding.

    Combines:
    - NoiseModel: estimates noise in neural data
    - CleanSpeechModel: processes the denoised signal for recognition
    - NoisySpeechModel: processes the noise signal for recognition

    The denoised signal is ``x_processed - noise_output``. ``x_processed`` is
    ``x`` passed through the clean model's day layers, Softsign and optional
    patching. ``noise_output`` is the noise model's GRU output.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim (int) - number of channels in a single timestep (e.g. 512)
        n_units (int) - number of hidden units in each recurrent layer
        n_days (int) - number of days in the dataset
        n_classes (int) - number of classes (phonemes)
        rnn_dropout (float) - fraction of units to drop out between recurrent layers during training
        input_dropout (float) - fraction of input units to drop out during training
        patch_size (int) - number of timesteps to concatenate in the initial input layer
        patch_stride (int) - number of timesteps to stride over when concatenating the initial input
        '''
        super(TripleGRUDecoder, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three sub-models (registration order fixes state_dict key order).
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Mode flag written by set_mode().
        # NOTE(review): forward() takes its mode from its own ``mode`` argument
        # and never reads this attribute.
        self.training_mode = 'full'  # 'full', 'inference'

    @staticmethod
    def _state_for(states, key):
        '''Return ``states[key]`` when a (non-empty) states dict is given, otherwise None.'''
        return states[key] if states else None

    def _apply_preprocessing(self, x, day_idx):
        '''
        Map raw input into the clean model's GRU input space.

        Applies the clean model's day-specific affine layer and Softsign, then
        patching when ``patch_size > 0``.

        NOTE(review): unlike CleanSpeechModel.forward, this path does not apply
        the clean model's input dropout. Confirm whether that is intended.
        '''
        batch_size = x.size(0)

        # Stack per-day parameters so a single index_select can gather them.
        all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)

        day_weights = torch.index_select(all_day_weights, 0, day_idx)
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

        # Cast parameters to the input dtype so mixed-precision inputs work.
        x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x_processed = self.clean_speech_model.day_layer_activation(x_processed)

        # Concatenate sliding windows of patch_size timesteps along the feature axis.
        if self.patch_size > 0:
            x_processed = x_processed.unsqueeze(1)                   # [B, 1, T, D]
            x_processed = x_processed.permute(0, 3, 1, 2)            # [B, D, 1, T]
            x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)                           # [B, D, P, patch_size]
            x_unfold = x_unfold.permute(0, 2, 3, 1)                  # [B, P, patch_size, D]
            x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
        '''
        Run the clean model's GRU and classifier on already-preprocessed input.

        Bypasses the day layers and patching. ``day_idx`` is unused and kept
        only for signature compatibility. Returns logits; the final hidden
        state is discarded.
        '''
        batch_size = x_processed.size(0)

        if states is None:
            states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()

        output, _ = self.clean_speech_model.gru(x_processed, states)
        return self.clean_speech_model.out(output)

    def _noisy_forward_with_processed_input(self, x_processed, states=None):
        '''Run the noisy model's GRU and classifier; returns logits and discards the final hidden state.'''
        batch_size = x_processed.size(0)

        if states is None:
            states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()

        output, _ = self.noisy_speech_model.gru(x_processed, states)
        return self.noisy_speech_model.out(output)

    def _denoise_and_classify(self, x, day_idx, states):
        '''
        Shared path of both modes.

        Runs the noise model, subtracts its output from the preprocessed
        input, and classifies the residual with the clean model.

        Returns:
            (clean_logits, noise_output, noise_hidden)
        '''
        # 1. Noise model estimates noise in the data.
        noise_output, noise_hidden = self.noise_model(x, day_idx, self._state_for(states, 'noise'))

        # 2. Bring x into the same (processed) space as noise_output.
        x_processed = self._apply_preprocessing(x, day_idx)

        # 3. Residual denoising; cast keeps dtypes aligned under mixed precision.
        denoised_input = x_processed - noise_output.to(x_processed.dtype)
        clean_logits = self._clean_forward_with_processed_input(
            denoised_input, day_idx, self._state_for(states, 'clean'))

        return clean_logits, noise_output, noise_hidden

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
        '''
        Three-model adversarial forward pass.

        x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx (tensor) - tensor of day indices for each example in the batch
        states (dict) - dictionary with 'noise', 'clean', 'noisy' initial states, or None
        return_state (bool) - also return the noise model's final hidden state
            (only the noise state is returned; clean/noisy final states are not)
        mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
        grl_lambda (float) - when truthy (non-zero) and mode='full', applies Gradient Reversal to the noise branch input

        Returns:
            'full': (clean_logits, noisy_logits, noise_output), or
                    ((clean_logits, noisy_logits, noise_output), noise_hidden) if return_state.
            'inference': clean_logits, or (clean_logits, noise_hidden) if return_state.

        Raises:
            ValueError: if ``mode`` is not 'full' or 'inference'.
        '''
        if mode == 'full':
            clean_logits, noise_output, noise_hidden = self._denoise_and_classify(x, day_idx, states)

            # 4. Noisy model classifies the noise estimate directly (no day layers).
            # Optional gradient reversal makes this branch adversarial for the noise model.
            noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
            noisy_logits = self._noisy_forward_with_processed_input(
                noisy_input, self._state_for(states, 'noisy'))

            outputs = (clean_logits, noisy_logits, noise_output)
            if return_state:
                return outputs, noise_hidden
            return outputs

        if mode == 'inference':
            clean_logits, _, noise_hidden = self._denoise_and_classify(x, day_idx, states)
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Placeholder manual update of the noise model parameters.

        clean_grad (tensor) - gradients from clean speech model output layer
        noisy_grad (tensor) - gradients from noisy speech model output layer
        learning_rate (float) - learning rate for gradient update

        WARNING(review): this is not a gradient step. For each noise-model
        parameter whose ``.grad`` is not None, every element is shifted by the
        same scalar ``learning_rate * mean(-clean_grad + noisy_grad)``. The
        parameter's own ``.grad`` is only used as a presence check. Replace
        this with a real update before relying on it.
        '''
        # Combine gradients: negative from clean model, positive from noisy model.
        combined_grad = -clean_grad + noisy_grad

        with torch.no_grad():
            for param in self.noise_model.parameters():
                if param.grad is not None:
                    # Uniform scalar shift (placeholder for a proper gradient mapping).
                    param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)

    def set_mode(self, mode):
        '''Store ``mode`` in ``self.training_mode`` (not read by forward()).'''
        self.training_mode = mode
|
|
|
|
|