Files
b2txt25/model_training_nnn/rnn_model.py
Zchen 0cbb83e052 tpu
2025-10-12 21:56:34 +08:00

478 lines
18 KiB
Python

import torch
from torch import nn
class NoiseModel(nn.Module):
'''
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
'''
def __init__(self,
neural_dim,
n_units,
n_days,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(NoiseModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 2-layer GRU for noise estimation
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.input_size, # Output same dimension as input
num_layers=2,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16)))
def forward(self, x, day_idx, states=None):
# Apply day-specific transformation
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
x = self.day_layer_activation(x)
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled
if self.patch_size > 0:
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
# Initialize hidden states
if states is None:
states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()
# GRU forward pass
output, hidden_states = self.gru(x, states)
return output, hidden_states
class CleanSpeechModel(nn.Module):
'''
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(CleanSpeechModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.n_classes = n_classes
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 3-layer GRU for clean speech recognition
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.n_units,
num_layers=3,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Output classification layer
self.out = nn.Linear(self.n_units, self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
def forward(self, x, day_idx, states=None, return_state=False):
# Apply day-specific transformation
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
x = self.day_layer_activation(x)
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled
if self.patch_size > 0:
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
# Initialize hidden states
if states is None:
states = self.h0.expand(3, x.shape[0], self.n_units).contiguous()
# GRU forward pass
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
class NoisySpeechModel(nn.Module):
    """Two-layer GRU classifier that runs directly on the estimated noise signal.

    It has no day-specific input layers and does no patching of its own. Its
    input size is ``neural_dim * patch_size`` when patching is enabled,
    otherwise ``neural_dim``. ``n_days`` and ``input_dropout`` are stored but
    not used by this class.

    Args:
        neural_dim: Number of features per timestep before patching.
        n_units: Hidden size of each GRU layer.
        n_days: Stored only.
        n_classes: Number of output classes per timestep.
        rnn_dropout: Dropout between GRU layers.
        input_dropout: Stored only.
        patch_size: Used only to compute the expected input size.
        patch_stride: Stored only.

    NOTE(review): ``h0`` is created in bfloat16 while the GRU and output layer
    use the default dtype; callers presumably cast or use autocast - confirm.
    """

    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super().__init__()
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Input arrives already in the patched feature space.
        self.input_size = neural_dim * patch_size if patch_size > 0 else neural_dim

        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )
        for pname, weight in self.gru.named_parameters():
            if "weight_hh" in pname:
                nn.init.orthogonal_(weight)
            elif "weight_ih" in pname:
                nn.init.xavier_uniform_(weight)

        # Per-timestep classification head.
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state, broadcast across layers and batch in forward().
        initial_hidden = torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(initial_hidden))

    def forward(self, x, states=None, return_state=False):
        """Return per-timestep logits, plus the final GRU state if ``return_state``.

        Args:
            x: Noise signal, fed straight to the batch-first GRU.
            states: Optional initial GRU hidden state; defaults to the learned ``h0``.
            return_state: When True, return ``(logits, hidden_states)``.
        """
        initial = (
            self.h0.expand(2, x.shape[0], self.n_units).contiguous()
            if states is None
            else states
        )
        rnn_out, final_hidden = self.gru(x, initial)
        logits = self.out(rnn_out)
        if not return_state:
            return logits
        return logits, final_hidden
class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding.

    Combines:
    - NoiseModel: estimates noise in the neural data
    - CleanSpeechModel: classifies the denoised signal
      (preprocessed input minus the estimated noise)
    - NoisySpeechModel: classifies the estimated noise signal itself
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim (int) - number of channels in a single timestep (e.g. 512)
        n_units (int) - number of hidden units in each recurrent layer
        n_days (int) - number of days in the dataset
        n_classes (int) - number of classes (phonemes)
        rnn_dropout (float) - fraction of units to drop out between GRU layers during training
        input_dropout (float) - fraction of input units to drop out during training
        patch_size (int) - number of timesteps to concat on initial input layer
        patch_stride (int) - number of timesteps to stride over when concatenating initial input
        '''
        super(TripleGRUDecoder, self).__init__()
        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Construction order is kept fixed: it determines the random parameter init.
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )
        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )
        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Informational flag set via set_mode(); forward() selects its behaviour
        # from its own `mode` argument and does not read this attribute.
        self.training_mode = 'full'  # 'full', 'inference'

    def _apply_preprocessing(self, x, day_idx):
        '''Map raw input into the (patched) feature space the noise GRU outputs in.

        Uses the CleanSpeechModel's day-specific weights/biases and Softsign.
        Unlike CleanSpeechModel.forward, no input dropout is applied here.
        NOTE(review): the NoiseModel has its own, separately trained day
        layers; only the clean model's are used for the residual - confirm
        this is intended.
        '''
        day_weights = torch.stack([self.clean_speech_model.day_weights[i] for i in day_idx], dim=0)
        day_biases = torch.cat([self.clean_speech_model.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
        x_processed = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
        x_processed = self.clean_speech_model.day_layer_activation(x_processed)

        # Same patching as the sub-models: (B, T, D) -> (B, T', patch_size * D).
        if self.patch_size > 0:
            x_processed = x_processed.unsqueeze(1)
            x_processed = x_processed.permute(0, 3, 1, 2)
            x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x_processed = x_unfold.reshape(x_processed.size(0), x_unfold.size(1), -1)
        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, return_state=False):
        '''Run the CleanSpeechModel GRU + output layer on already-preprocessed input.

        Bypasses the clean model's day layers and patching. `day_idx` is
        accepted for interface compatibility and is unused.
        Returns logits, or (logits, hidden_states) when `return_state` is True.
        '''
        model = self.clean_speech_model
        if states is None:
            states = model.h0.expand(3, x_processed.shape[0], model.n_units).contiguous()
        output, hidden_states = model.gru(x_processed, states)
        logits = model.out(output)
        if return_state:
            return logits, hidden_states
        return logits

    def _noisy_forward_with_processed_input(self, x_processed, states=None, return_state=False):
        '''Run the NoisySpeechModel GRU + output layer on the estimated noise.

        Returns logits, or (logits, hidden_states) when `return_state` is True.
        '''
        model = self.noisy_speech_model
        if states is None:
            states = model.h0.expand(2, x_processed.shape[0], model.n_units).contiguous()
        output, hidden_states = model.gru(x_processed, states)
        logits = model.out(output)
        if return_state:
            return logits, hidden_states
        return logits

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference'):
        '''
        Three-model adversarial forward pass.

        x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx (tensor) - tensor of day indices for each example in the batch
        states (dict) - optional dict with 'noise', 'clean', 'noisy' GRU states;
                        missing keys (or None) fall back to each model's learned h0
        return_state (bool) - also return a dict of final GRU hidden states
        mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)

        Returns:
            'full':      dict with 'clean_logits', 'noisy_logits', 'noise_output'
                         (+ states dict with 'noise', 'clean', 'noisy' if return_state)
            'inference': clean logits tensor
                         (+ states dict with 'noise', 'clean' if return_state)
        Raises:
            ValueError: if `mode` is not 'full' or 'inference'.
        '''
        if mode not in ('full', 'inference'):
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
        # .get() so a states dict produced in one mode can be fed to the other.
        states = states or {}

        # 1. Noise model estimates noise in the (patched) feature space.
        noise_output, noise_hidden = self.noise_model(x, day_idx, states.get('noise'))

        # 2. Residual connection: subtract the noise estimate from the input
        #    after the same day-layer + patching preprocessing.
        x_processed = self._apply_preprocessing(x, day_idx)
        denoised_input = x_processed - noise_output

        # 3. Clean speech model classifies the denoised signal. Its final state is
        #    kept so chunked/streaming callers can carry it to the next call.
        clean_logits, clean_hidden = self._clean_forward_with_processed_input(
            denoised_input, day_idx, states.get('clean'), return_state=True)

        if mode == 'inference':
            if return_state:
                return clean_logits, {'noise': noise_hidden, 'clean': clean_hidden}
            return clean_logits

        # 4. Training only: noisy speech model classifies the noise signal directly.
        noisy_logits, noisy_hidden = self._noisy_forward_with_processed_input(
            noise_output, states.get('noisy'), return_state=True)

        outputs = {
            'clean_logits': clean_logits,
            'noisy_logits': noisy_logits,
            'noise_output': noise_output
        }
        if return_state:
            return outputs, {'noise': noise_hidden, 'clean': clean_hidden, 'noisy': noisy_hidden}
        return outputs

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Apply combined gradients to noise model parameters.

        clean_grad (tensor) - gradients from clean speech model output layer
        noisy_grad (tensor) - gradients from noisy speech model output layer
        learning_rate (float) - learning rate for gradient update

        NOTE(review): placeholder implementation. For every noise-model parameter
        whose .grad is set, it shifts ALL elements by the same scalar
        learning_rate * mean(noisy_grad - clean_grad); the parameter's own
        .grad values are ignored. This is not a real gradient step - an
        adversarial setup normally uses a gradient-reversal layer or autograd.
        '''
        # Combine gradients: negative from clean model, positive from noisy model
        combined_grad = -clean_grad + noisy_grad
        with torch.no_grad():
            for param in self.noise_model.parameters():
                if param.grad is not None:
                    param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)

    def set_mode(self, mode):
        '''Set the informational `training_mode` flag (forward() takes `mode` explicitly).'''
        self.training_mode = mode