import torch
from torch import nn


def _apply_day_layer(x, day_idx, day_weights, day_biases, activation):
    '''
    Apply the per-day affine transform followed by `activation`.

    x           (tensor)        - batch of trials, indexed as (batch, time, neural_dim) by bmm
    day_idx     (tensor)        - 1-D index tensor selecting one day per batch element
    day_weights (ParameterList) - one (neural_dim, neural_dim) matrix per day
    day_biases  (ParameterList) - one (1, neural_dim) bias per day
    activation  (callable)      - non-linearity applied to the transformed input

    Returns a tensor with the same shape as `x`.
    '''
    # Stack every day's parameters so the per-sample selection is a static
    # gather (index_select) rather than Python-level dynamic indexing.
    stacked_weights = torch.stack(list(day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
    stacked_biases = torch.stack(
        [bias.squeeze(0) for bias in day_biases], dim=0
    )  # [n_days, neural_dim]

    weights = torch.index_select(stacked_weights, 0, day_idx)  # [batch, neural_dim, neural_dim]
    biases = torch.index_select(stacked_biases, 0, day_idx).unsqueeze(1)  # [batch, 1, neural_dim]

    # Cast the parameters to the input dtype so the batched matmul does not
    # fail on a dtype mismatch under mixed precision.
    x = torch.bmm(x, weights.to(x.dtype)) + biases.to(x.dtype)
    return activation(x)


def _extract_patches(x, patch_size, patch_stride):
    '''
    Concatenate `patch_size` consecutive timesteps into one feature vector.

    x (tensor) - (batch, time, features)

    Returns a tensor of shape (batch, n_patches, patch_size * features), where
    windows are taken along the time axis every `patch_stride` steps.
    '''
    batch_size = x.size(0)
    x = x.unsqueeze(1).permute(0, 3, 1, 2)  # [batch, features, 1, time]
    patches = x.unfold(3, patch_size, patch_stride)  # [batch, features, 1, n_patches, patch_size]
    patches = patches.squeeze(2).permute(0, 2, 3, 1)  # [batch, n_patches, patch_size, features]
    return patches.reshape(batch_size, patches.size(1), -1)


def _init_gru_weights(gru):
    '''Orthogonal init for recurrent weights, Xavier-uniform for input weights.'''
    for name, param in gru.named_parameters():
        if "weight_hh" in name:
            nn.init.orthogonal_(param)
        if "weight_ih" in name:
            nn.init.xavier_uniform_(param)


class NoiseModel(nn.Module):
    '''
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data.

    The GRU hidden size equals its input size, so the output has the same
    feature dimension as the (day-transformed, optionally patched) input.
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim    (int)   - number of channels in a single timestep
        n_units       (int)   - stored for reference; the GRU width is the input size
        n_days        (int)   - number of day-specific input layers
        rnn_dropout   (float) - dropout between GRU layers
        input_dropout (float) - dropout applied after the day layer
        patch_size    (int)   - timesteps concatenated per GRU input (0 disables patching)
        patch_stride  (int)   - stride between consecutive patches
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers, initialised to the identity transform.
        # No explicit dtype: the training framework decides (TPU compatibility).
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.input_size,  # Output same dimension as input
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )
        _init_gru_weights(self.gru)

        # Learnable initial hidden state, shared by both layers and all samples
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))

    def forward(self, x, day_idx, states=None):
        '''
        x       (tensor) - (batch, time, neural_dim)
        day_idx (tensor) - day index for each example in the batch
        states  (tensor) - optional initial GRU hidden state; defaults to the learned h0

        Returns (output, hidden_states) exactly as produced by the GRU.
        '''
        batch_size = x.size(0)

        x = _apply_day_layer(
            x, day_idx, self.day_weights, self.day_biases, self.day_layer_activation
        )

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = _extract_patches(x, self.patch_size, self.patch_stride)

        # Expand the single learned state over layers and batch.
        if states is None:
            states = self.h0.expand(2, batch_size, self.input_size).contiguous()

        output, hidden_states = self.gru(x, states)
        return output, hidden_states


class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        neural_dim    (int)   - number of channels in a single timestep
        n_units       (int)   - number of hidden units in each recurrent layer
        n_days        (int)   - number of day-specific input layers
        n_classes     (int)   - size of the output classification layer
        rnn_dropout   (float) - dropout between GRU layers
        input_dropout (float) - dropout applied after the day layer
        patch_size    (int)   - timesteps concatenated per GRU input (0 disables patching)
        patch_stride  (int)   - stride between consecutive patches
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers, initialised to the identity transform.
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=3,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )
        _init_gru_weights(self.gru)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, day_idx, states=None, return_state=False):
        '''
        x            (tensor) - (batch, time, neural_dim)
        day_idx      (tensor) - day index for each example in the batch
        states       (tensor) - optional initial GRU hidden state; defaults to the learned h0
        return_state (bool)   - also return the final GRU hidden state

        Returns logits, or (logits, hidden_states) when `return_state` is set.
        '''
        batch_size = x.size(0)

        x = _apply_day_layer(
            x, day_idx, self.day_weights, self.day_biases, self.day_layer_activation
        )

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        if self.patch_size > 0:
            x = _extract_patches(x, self.patch_size, self.patch_stride)

        if states is None:
            states = self.h0.expand(3, batch_size, self.n_units).contiguous()

        output, hidden_states = self.gru(x, states)
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits


class NoisySpeechModel(nn.Module):
    '''
    Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        '''
        Same arguments as CleanSpeechModel. This model has no day layers,
        dropout layer or patching step of its own: `n_days`, `input_dropout`
        and `patch_stride` are only stored, and `patch_size` is used solely to
        size the GRU input.
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noisy speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )
        _init_gru_weights(self.gru)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, states=None, return_state=False):
        '''
        x            (tensor) - (batch, time, input_size); fed to the GRU as-is
        states       (tensor) - optional initial GRU hidden state; defaults to the learned h0
        return_state (bool)   - also return the final GRU hidden state

        Returns logits, or (logits, hidden_states) when `return_state` is set.
        '''
        batch_size = x.size(0)

        if states is None:
            states = self.h0.expand(2, batch_size, self.n_units).contiguous()

        output, hidden_states = self.gru(x, states)
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits


class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding

    Combines:
    - NoiseModel: estimates noise in neural data
    - CleanSpeechModel: processes denoised signal for recognition
    - NoisySpeechModel: processes noise signal for recognition
    '''

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim  (int)      - number of channels in a single timestep (e.g. 512)
        n_units     (int)      - number of hidden units in each recurrent layer
        n_days      (int)      - number of days in the dataset
        n_classes   (int)      - number of classes (phonemes)
        rnn_dropout    (float) - percentage of units to dropout during training
        input_dropout (float)  - percentage of input units to dropout during training
        patch_size  (int)      - number of timesteps to concat on initial input layer
        patch_stride(int)      - number of timesteps to stride over when concatenating initial input
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three models
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Mode flag: 'full' or 'inference'.
        # NOTE: forward() does not read this attribute; it uses its own `mode`
        # argument, so callers must pass the mode explicitly.
        self.training_mode = 'full'

    def _apply_preprocessing(self, x, day_idx):
        '''
        Map `x` into the space of the noise model's output.

        Uses the CleanSpeechModel's day layers (weights, biases, activation)
        followed by patching when enabled. No input dropout is applied here.
        '''
        x_processed = _apply_day_layer(
            x,
            day_idx,
            self.clean_speech_model.day_weights,
            self.clean_speech_model.day_biases,
            self.clean_speech_model.day_layer_activation,
        )

        if self.patch_size > 0:
            x_processed = _extract_patches(x_processed, self.patch_size, self.patch_stride)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
        '''
        Run the CleanSpeechModel GRU and classifier on already processed
        input, bypassing its day layers and patching.

        `day_idx` is accepted for signature compatibility and is not used.
        Only the logits are returned; the GRU hidden state is discarded.
        '''
        clean = self.clean_speech_model
        batch_size = x_processed.size(0)

        if states is None:
            states = clean.h0.expand(3, batch_size, clean.n_units).contiguous()

        output, _ = clean.gru(x_processed, states)
        return clean.out(output)

    def _noisy_forward_with_processed_input(self, x_processed, states=None):
        '''
        Run the NoisySpeechModel GRU and classifier on already processed input.

        Only the logits are returned; the GRU hidden state is discarded.
        '''
        noisy = self.noisy_speech_model
        batch_size = x_processed.size(0)

        if states is None:
            states = noisy.h0.expand(2, batch_size, noisy.n_units).contiguous()

        output, _ = noisy.gru(x_processed, states)
        return noisy.out(output)

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference'):
        '''
        Three-model adversarial forward pass

        x        (tensor)  - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx  (tensor)  - tensor of day indices for each example in the batch
        states   (dict)    - dictionary with any of the 'noise', 'clean', 'noisy' states, or None.
                             Missing keys fall back to that model's learned initial state.
        return_state (bool) - also return the noise model's final hidden state
        mode     (str)     - 'full' for training (all three models), 'inference' for inference (noise + clean only)

        Returns
            mode='inference': clean_logits, or (clean_logits, noise_hidden) with return_state
            mode='full':      (clean_logits, noisy_logits, noise_output), or
                              ((clean_logits, noisy_logits, noise_output), noise_hidden) with return_state

        Raises ValueError for any other mode.
        '''
        if mode not in ('full', 'inference'):
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        # Only the noise hidden state is ever returned, so a caller feeding it
        # back can supply just {'noise': ...}; tolerate the absent keys.
        states = states if states else {}

        # 1. Noise model estimates noise in the data
        noise_output, noise_hidden = self.noise_model(x, day_idx, states.get('noise'))

        # 2. Bring x into the same space as noise_output for the residual
        x_processed = self._apply_preprocessing(x, day_idx)

        # 3. Clean speech model processes the denoised signal, bypassing its
        #    own preprocessing since the residual is already in processed space
        denoised_input = x_processed - noise_output
        clean_logits = self._clean_forward_with_processed_input(
            denoised_input, day_idx, states.get('clean')
        )

        if mode == 'inference':
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        # 4. Noisy speech model processes the noise signal directly
        noisy_logits = self._noisy_forward_with_processed_input(
            noise_output, states.get('noisy')
        )

        # Tuples (not dicts) are returned for XLA-friendly compilation
        if return_state:
            return (clean_logits, noisy_logits, noise_output), noise_hidden
        return clean_logits, noisy_logits, noise_output

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Apply combined gradients to noise model parameters

        clean_grad (tensor)   - gradients from clean speech model output layer
        noisy_grad (tensor)   - gradients from noisy speech model output layer
        learning_rate (float) - learning rate for gradient update

        Placeholder update rule: every noise-model parameter that currently
        has a `.grad` is shifted uniformly by
        `learning_rate * mean(-clean_grad + noisy_grad)`. No per-parameter
        gradient mapping is performed.
        '''
        # Combine gradients: negative from clean model, positive from noisy model
        combined_grad = -clean_grad + noisy_grad

        with torch.no_grad():
            # The step is a single scalar, identical for every parameter.
            step = learning_rate * combined_grad.mean()
            for param in self.noise_model.parameters():
                if param.grad is not None:
                    # Write through .data so autograd version counters are not bumped.
                    param.data -= step * torch.ones_like(param.data)

    def set_mode(self, mode):
        '''Set the operating mode flag (stored only; forward() takes `mode` explicitly).'''
        self.training_mode = mode