import torch
from torch import nn


class GradientReversalFn(torch.autograd.Function):
    """
    Gradient Reversal Layer (GRL)
    Forward: identity
    Backward: multiply incoming gradient by -lambda
    """
    @staticmethod
    def forward(ctx, x, lambd: float):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None


def gradient_reverse(x, lambd: float = 1.0):
    return GradientReversalFn.apply(x, lambd)
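

# A minimal sketch (not used by the models below) illustrating what the GRL does:
# the forward pass returns x unchanged, while backward() multiplies the incoming
# gradient by -lambda. The helper name and the toy tensor are illustrative
# assumptions, not something the rest of this module relies on.
def _demo_gradient_reversal(lambd: float = 0.5):
    x = torch.ones(3, requires_grad=True)
    y = gradient_reverse(x, lambd).sum()
    y.backward()
    # d(sum)/dx would normally be all ones; the GRL flips and scales it to
    # -lambd * ones, i.e. tensor([-0.5, -0.5, -0.5]) for lambd=0.5.
    return x.grad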


class NoiseModel(nn.Module):
    '''
    Noise Model: 2-layer GRU that learns to estimate noise in the neural data
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(NoiseModel, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = nn.Softsign()
        # Let Accelerator handle dtype automatically for TPU compatibility
        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noise estimation
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.input_size,  # Output same dimension as input
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Learnable initial hidden state - let Accelerator handle dtype
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))

    def forward(self, x, day_idx, states=None):
        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
        batch_size = x.size(0)

        # Stack all day weights and biases upfront for static indexing
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]

        # XLA-friendly gather operation
        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]

        # Use bmm (batch matrix multiply), which is highly optimized in XLA
        # Ensure dtype consistency for mixed precision training
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x = self.day_layer_activation(x)

        # XLA-friendly conditional dropout
        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        # Apply patch processing if enabled (keep conditional for now, optimize later)
        if self.patch_size > 0:
            x = x.unsqueeze(1)
            x = x.permute(0, 3, 1, 2)
            x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)

        # XLA-friendly hidden state initialization - avoid dynamic allocation
        if states is None:
            states = self.h0.expand(2, batch_size, self.input_size).contiguous()

        # GRU forward pass
        output, hidden_states = self.gru(x, states)

        return output, hidden_states
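

# A minimal shape-check sketch (assumed dimensions, not used elsewhere): the
# NoiseModel maps (batch, time, neural_dim) to a noise estimate whose width
# matches its GRU input (neural_dim * patch_size when patching is enabled,
# otherwise neural_dim). The concrete sizes below are illustrative only.
def _demo_noise_model_shapes():
    model = NoiseModel(neural_dim=512, n_units=256, n_days=4,
                       patch_size=4, patch_stride=2)
    x = torch.randn(2, 20, 512)         # (batch, time, neural_dim)
    day_idx = torch.tensor([0, 3])      # one day index per trial
    output, hidden = model(x, day_idx)
    # With patch_size=4 and patch_stride=2, 20 timesteps unfold into 9 patches,
    # so output is (2, 9, 2048) and hidden is (2, 2, 2048) for the 2-layer GRU.
    return output.shape, hidden.shape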
					
						
class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: 3-layer GRU that processes the denoised signal for speech recognition
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(CleanSpeechModel, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Day-specific input layers
        self.day_layer_activation = nn.Softsign()
        # Let Accelerator handle dtype automatically for TPU compatibility
        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 3-layer GRU for clean speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=3,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, day_idx, states=None, return_state=False):
        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
        batch_size = x.size(0)

        # Stack all day weights and biases upfront for static indexing
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]

        # XLA-friendly gather operation
        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]

        # Use bmm (batch matrix multiply), which is highly optimized in XLA
        # Ensure dtype consistency for mixed precision training
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x = self.day_layer_activation(x)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x = x.unsqueeze(1)
            x = x.permute(0, 3, 1, 2)
            x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)

        # XLA-friendly hidden state initialization
        if states is None:
            states = self.h0.expand(3, batch_size, self.n_units).contiguous()

        # GRU forward pass
        output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
					
						
class NoisySpeechModel(nn.Module):
    '''
    Noisy Speech Model: 2-layer GRU that processes the noise signal for speech recognition
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0):
        super(NoisySpeechModel, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # 2-layer GRU for noisy speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=2,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, states=None, return_state=False):
        # Note: NoisySpeechModel doesn't need day-specific layers since it processes noise
        batch_size = x.size(0)

        # XLA-friendly hidden state initialization
        if states is None:
            states = self.h0.expand(2, batch_size, self.n_units).contiguous()

        # GRU forward pass
        output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
					
						
class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding

    Combines:
    - NoiseModel: estimates noise in the neural data
    - CleanSpeechModel: processes the denoised signal for recognition
    - NoisySpeechModel: processes the noise signal for recognition
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim  (int)      - number of channels in a single timestep (e.g. 512)
        n_units     (int)      - number of hidden units in each recurrent layer
        n_days      (int)      - number of days in the dataset
        n_classes   (int)      - number of classes (phonemes)
        rnn_dropout   (float)  - fraction of recurrent units to drop out during training
        input_dropout (float)  - fraction of input units to drop out during training
        patch_size  (int)      - number of timesteps to concatenate at the initial input layer
        patch_stride (int)     - number of timesteps to stride over when concatenating the initial input
        '''
        super(TripleGRUDecoder, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three models
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Training mode flag
        self.training_mode = 'full'  # 'full', 'inference'
					
						
    def _apply_preprocessing(self, x, day_idx):
        '''XLA-friendly preprocessing with static operations'''
        batch_size = x.size(0)

        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
        all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)

        # XLA-friendly gather operation
        day_weights = torch.index_select(all_day_weights, 0, day_idx)
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

        # Use bmm (batch matrix multiply), which is highly optimized in XLA
        # Ensure dtype consistency for mixed precision training
        x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x_processed = self.clean_speech_model.day_layer_activation(x_processed)

        # Apply patch processing if enabled
        if self.patch_size > 0:
            x_processed = x_processed.unsqueeze(1)
            x_processed = x_processed.permute(0, 3, 1, 2)
            x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)

        return x_processed

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
        '''Forward pass for CleanSpeechModel with already-processed input (bypasses day layers and patching)'''
        batch_size = x_processed.size(0)

        # XLA-friendly hidden state initialization
        if states is None:
            states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()

        # GRU forward pass (skip preprocessing since the input is already processed)
        output, hidden_states = self.clean_speech_model.gru(x_processed, states)

        # Classification
        logits = self.clean_speech_model.out(output)
        return logits

    def _noisy_forward_with_processed_input(self, x_processed, states=None):
        '''Forward pass for NoisySpeechModel with already-processed input'''
        batch_size = x_processed.size(0)

        # XLA-friendly hidden state initialization
        if states is None:
            states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()

        # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
        output, hidden_states = self.noisy_speech_model.gru(x_processed, states)

        # Classification
        logits = self.noisy_speech_model.out(output)
        return logits

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
        '''
        Three-model adversarial forward pass

        x        (tensor)  - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx  (tensor)  - tensor of day indices for each example in the batch
        states   (dict)    - dictionary with 'noise', 'clean', 'noisy' states, or None
        mode     (str)     - 'full' for training (all three models), 'inference' for inference (noise + clean only)
        grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
        '''

        if mode == 'full':
            # Training mode: run all three models

            # 1. Noise model estimates noise in the data
            noise_output, noise_hidden = self.noise_model(x, day_idx,
                                                          states['noise'] if states else None)

            # 2. For the residual connection we need x in the same space as noise_output,
            # so apply the same preprocessing that the models use internally
            x_processed = self._apply_preprocessing(x, day_idx)

            # 3. Clean speech model processes the denoised signal
            # Ensure dtype consistency for mixed precision training in the residual connection
            denoised_input = x_processed - noise_output.to(x_processed.dtype)  # Residual connection in processed space
            # The residual already lives in the clean model's preprocessed space, so bypass that
            # model's day layers and patching and feed it to its GRU directly
            clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                    states['clean'] if states else None)

            # 4. Noisy speech model processes the noise signal directly (no day layers needed)
            # Optionally apply Gradient Reversal to enforce adversarial training on the noise output
            noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
            noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
                                                                    states['noisy'] if states else None)

            # XLA-friendly return - use a tuple instead of a dict for better compilation
            if return_state:
                return (clean_logits, noisy_logits, noise_output), noise_hidden
            return clean_logits, noisy_logits, noise_output
					
						
        elif mode == 'inference':
            # Inference mode: only noise model + clean speech model

            # 1. Estimate noise
            noise_output, noise_hidden = self.noise_model(x, day_idx,
                                                          states['noise'] if states else None)

            # 2. For the residual connection we need x in the same space as noise_output
            x_processed = self._apply_preprocessing(x, day_idx)

            # 3. Process the denoised signal
            # Ensure dtype consistency for mixed precision training in the residual connection
            denoised_input = x_processed - noise_output.to(x_processed.dtype)
            clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                    states['clean'] if states else None)

            # XLA-friendly return - use a tuple for consistency
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        else:
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Apply combined gradients to the noise model parameters

        clean_grad (tensor)   - gradients from the clean speech model output layer
        noisy_grad (tensor)   - gradients from the noisy speech model output layer
        learning_rate (float) - learning rate for the gradient update
        '''
        # Combine gradients: negative from the clean model, positive from the noisy model
        combined_grad = -clean_grad + noisy_grad

        # Apply gradients to noise model parameters
        # This is a simplified implementation - in practice you'd want more sophisticated update rules
        with torch.no_grad():
            for param in self.noise_model.parameters():
                if param.grad is not None:
                    # Scale the combined gradient appropriately
                    # This is a placeholder - you'd need to implement proper gradient mapping
                    param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)

    def set_mode(self, mode):
        '''Set the operating mode'''
        self.training_mode = mode
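

# A minimal usage sketch with assumed dimensions (the sizes and batch below are
# illustrative, not values this module prescribes): 'full' mode returns logits
# from the clean and noisy branches plus the noise estimate, while 'inference'
# mode returns only the clean-branch logits.
def _demo_triple_gru_decoder():
    decoder = TripleGRUDecoder(neural_dim=512, n_units=768, n_days=4, n_classes=41,
                               patch_size=14, patch_stride=4)
    x = torch.randn(2, 100, 512)        # (batch, time, neural_dim)
    day_idx = torch.tensor([1, 2])      # one day index per trial

    # Training-style pass: all three models, with gradient reversal on the noise branch
    clean_logits, noisy_logits, noise_output = decoder(x, day_idx, mode='full', grl_lambda=0.5)

    # Inference-style pass: noise model + clean speech model only
    clean_only = decoder(x, day_idx, mode='inference')
    return clean_logits.shape, noisy_logits.shape, noise_output.shape, clean_only.shape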