Fix the bug where the B model was not enabled
@@ -1,6 +1,24 @@
 import torch
 from torch import nn
 
+class GradientReversalFn(torch.autograd.Function):
+    """
+    Gradient Reversal Layer (GRL)
+    Forward: identity
+    Backward: multiply incoming gradient by -lambda
+    """
+    @staticmethod
+    def forward(ctx, x, lambd: float):
+        ctx.lambd = lambd
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return -ctx.lambd * grad_output, None
+
+def gradient_reverse(x, lambd: float = 1.0):
+    return GradientReversalFn.apply(x, lambd)
+
 class NoiseModel(nn.Module):
     '''
     Noise Model: 2-layer GRU that learns to estimate noise in the neural data
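For reference, a minimal standalone check of what the new layer does: the forward pass is an identity and the backward pass multiplies the incoming gradient by -lambda. This sketch assumes the GradientReversalFn / gradient_reverse definitions above are in scope; the tensor size and lambda value are illustrative only.

import torch

x = torch.ones(3, requires_grad=True)
y = gradient_reverse(x, lambd=0.5)   # forward: y has the same values as x (identity)
y.sum().backward()
print(y)        # tensor([1., 1., 1.], ...)
print(x.grad)   # tensor([-0.5000, -0.5000, -0.5000]), i.e. a gradient of ones scaled by -lambda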
@@ -361,7 +379,8 @@ class TripleGRUDecoder(nn.Module):
         day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
 
         # Use bmm (batch matrix multiply) which is highly optimized in XLA
-        x_processed = torch.bmm(x, day_weights) + day_biases
+        # Ensure dtype consistency for mixed precision training
+        x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
         x_processed = self.clean_speech_model.day_layer_activation(x_processed)
 
         # Apply patch processing if enabled
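A small illustration of why the added casts matter: under mixed precision the activations may arrive in bfloat16 while the day-layer parameters stay in float32, and torch.bmm requires both operands to share a dtype. The shapes below are placeholders, not the model's real dimensions.

import torch

x = torch.randn(4, 10, 512, dtype=torch.bfloat16)   # activations in bf16
day_weights = torch.randn(4, 512, 512)               # parameters kept in fp32
day_biases = torch.randn(4, 1, 512)

# torch.bmm(x, day_weights) would raise a dtype-mismatch RuntimeError here;
# casting the parameters to x.dtype keeps the batched matmul consistent:
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
print(x_processed.dtype)   # torch.bfloat16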
@@ -405,7 +424,7 @@ class TripleGRUDecoder(nn.Module):
         logits = self.noisy_speech_model.out(output)
         return logits
 
-    def forward(self, x, day_idx, states=None, return_state=False, mode='inference'):
+    def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
         '''
         Three-model adversarial forward pass
 
@@ -413,6 +432,7 @@ class TripleGRUDecoder(nn.Module):
             day_idx (tensor) - tensor of day indices for each example in the batch
             states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
             mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
+            grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
         '''
 
         if mode == 'full':
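In practice grl_lambda is usually not held constant but ramped up during training so the adversarial pressure starts weak; the DANN-style schedule below is a common choice and only a suggestion, not something taken from this repository. The call in the final comment mirrors the new signature; model, x, day_idx and the step counts are placeholders.

import math

def grl_lambda_schedule(step: int, total_steps: int, max_lambda: float = 1.0, gamma: float = 10.0) -> float:
    """Warm up from 0 toward max_lambda as training progresses."""
    p = step / max(total_steps, 1)
    return max_lambda * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)

for step in (0, 1000, 5000, 10000):
    print(step, round(grl_lambda_schedule(step, 10000), 3))   # rises from 0.0 toward 1.0

# During training the scheduled value would be passed through the new argument, e.g.:
#   logits = model(x, day_idx, mode='full', grl_lambda=grl_lambda_schedule(step, total_steps))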
@@ -435,7 +455,9 @@ class TripleGRUDecoder(nn.Module):
                                                                     states['clean'] if states else None)
 
             # 4. Noisy speech model processes noise signal directly (no day layers needed)
-            noisy_logits = self._noisy_forward_with_processed_input(noise_output,
+            # Optionally apply Gradient Reversal to enforce adversarial training on noise output
+            noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
+            noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
                                                                     states['noisy'] if states else None)
 
             # XLA-friendly return - use tuple instead of dict for better compilation
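To make the effect of this wiring concrete, here is a self-contained toy version of the adversarial hookup. The module names and shapes are hypothetical, not the repository's real branches; it only assumes the gradient_reverse helper added in this commit. The adversarial head is trained normally on its own loss, while the reversed gradient reaching the noise estimator pushes it to make its output uninformative for that head.

import torch
from torch import nn

noise_model = nn.Linear(8, 8)   # stands in for the NoiseModel branch
adversary = nn.Linear(8, 4)     # stands in for the noisy speech branch

x = torch.randn(2, 8)
noise = noise_model(x)
adv_logits = adversary(gradient_reverse(noise, 1.0))   # GRL sits between the two branches
adv_loss = adv_logits.pow(2).mean()
adv_loss.backward()

# adversary receives the ordinary gradient of adv_loss; noise_model receives the
# negated gradient, so one backward pass trains the two modules against each other.
print(adversary.weight.grad.norm(), noise_model.weight.grad.norm())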