修复B模型未启用的错误
This commit is contained in:
18
CLAUDE.md
18
CLAUDE.md
@@ -162,7 +162,17 @@ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
|
|||||||
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
|
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
|
||||||
|
|
||||||
# After (XLA-optimized):
|
# After (XLA-optimized):
|
||||||
x = torch.bmm(x, day_weights) + day_biases # bmm is highly optimized in XLA
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Mixed Precision Dtype Consistency
|
||||||
|
**Problem**: Mixed precision training causes dtype mismatches in bmm operations
|
||||||
|
**Solution**: Ensure all operands match input tensor dtype
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
|
||||||
|
# Fix: Add dtype conversions for all bmm operands
|
||||||
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 3. Hidden State Initialization
|
#### 3. Hidden State Initialization
|
||||||
@@ -199,11 +209,11 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
|
|||||||
### Files Modified for XLA Optimization
|
### Files Modified for XLA Optimization
|
||||||
|
|
||||||
- **`model_training_nnn/rnn_model.py`**: All three models optimized
|
- **`model_training_nnn/rnn_model.py`**: All three models optimized
|
||||||
- `NoiseModel.forward()`: Dynamic indexing → static gather operations
|
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
|
||||||
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops
|
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
|
||||||
- `NoisySpeechModel.forward()`: Hidden state optimization
|
- `NoisySpeechModel.forward()`: Hidden state optimization
|
||||||
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns
|
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns
|
||||||
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations
|
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
|
||||||
|
|
||||||
### Benefits of XLA Optimizations
|
### Benefits of XLA Optimizations
|
||||||
|
|
||||||
|
@@ -1,6 +1,24 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
class GradientReversalFn(torch.autograd.Function):
    """
    Gradient Reversal Layer (GRL).

    Forward: identity.
    Backward: multiply the incoming gradient by -lambda.
    """

    @staticmethod
    def forward(ctx, x, lambd: float):
        """Return ``x`` unchanged and remember ``lambd`` for the backward pass."""
        # Stash the scaling factor on the context so backward() can read it.
        ctx.lambd = lambd
        # Return a view of x rather than x itself.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the reversed, ``lambd``-scaled gradient for ``x`` and ``None`` for ``lambd``."""
        # One entry per forward() input: a gradient for x, None for the
        # plain-float lambd.
        return -ctx.lambd * grad_output, None
|
||||||
|
|
||||||
|
def gradient_reverse(x, lambd: float = 1.0):
    """
    Pass ``x`` through the Gradient Reversal Layer.

    Thin functional wrapper that forwards ``x`` and ``lambd`` to
    ``GradientReversalFn.apply`` and returns its result.

    Args:
        x: input forwarded to ``GradientReversalFn``.
        lambd (float): scaling factor forwarded to ``GradientReversalFn``
            (default 1.0).
    """
    return GradientReversalFn.apply(x, lambd)
|
||||||
|
|
||||||
class NoiseModel(nn.Module):
|
class NoiseModel(nn.Module):
|
||||||
'''
|
'''
|
||||||
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
|
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
|
||||||
@@ -361,7 +379,8 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
|
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
|
||||||
|
|
||||||
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||||
x_processed = torch.bmm(x, day_weights) + day_biases
|
# Ensure dtype consistency for mixed precision training
|
||||||
|
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
|
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
|
||||||
|
|
||||||
# Apply patch processing if enabled
|
# Apply patch processing if enabled
|
||||||
@@ -405,7 +424,7 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
logits = self.noisy_speech_model.out(output)
|
logits = self.noisy_speech_model.out(output)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def forward(self, x, day_idx, states=None, return_state=False, mode='inference'):
|
def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
|
||||||
'''
|
'''
|
||||||
Three-model adversarial forward pass
|
Three-model adversarial forward pass
|
||||||
|
|
||||||
@@ -413,6 +432,7 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
day_idx (tensor) - tensor of day indices for each example in the batch
|
day_idx (tensor) - tensor of day indices for each example in the batch
|
||||||
states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
|
states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
|
||||||
mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
|
mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
|
||||||
|
grl_lambda (float) - when non-zero and mode='full', applies Gradient Reversal to the noise model output before it is fed to the noisy speech model
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if mode == 'full':
|
if mode == 'full':
|
||||||
@@ -435,7 +455,9 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
states['clean'] if states else None)
|
states['clean'] if states else None)
|
||||||
|
|
||||||
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
||||||
noisy_logits = self._noisy_forward_with_processed_input(noise_output,
|
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
|
||||||
|
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
|
||||||
|
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
|
||||||
states['noisy'] if states else None)
|
states['noisy'] if states else None)
|
||||||
|
|
||||||
# XLA-friendly return - use tuple instead of dict for better compilation
|
# XLA-friendly return - use tuple instead of dict for better compilation
|
||||||
|
@@ -86,6 +86,14 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
self.transform_args = self.args['dataset']['data_transforms']
|
self.transform_args = self.args['dataset']['data_transforms']
|
||||||
|
|
||||||
|
# Adversarial training config (safe defaults if not provided)
|
||||||
|
adv_cfg = self.args.get('adversarial', {})
|
||||||
|
self.adv_enabled = adv_cfg.get('enabled', False)
|
||||||
|
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
|
||||||
|
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
|
||||||
|
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
|
||||||
|
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
|
||||||
|
|
||||||
# Create output directory
|
# Create output directory
|
||||||
if args['mode'] == 'train':
|
if args['mode'] == 'train':
|
||||||
os.makedirs(self.args['output_dir'], exist_ok=True)
|
os.makedirs(self.args['output_dir'], exist_ok=True)
|
||||||
@@ -291,6 +299,8 @@ class BrainToTextDecoder_Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.logger.info("Prepared model and dataloaders with Accelerator")
|
self.logger.info("Prepared model and dataloaders with Accelerator")
|
||||||
|
if self.adv_enabled:
|
||||||
|
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
'''
|
'''
|
||||||
@@ -583,17 +593,47 @@ class BrainToTextDecoder_Trainer:
|
|||||||
# In mixed precision mode, ensure features match the expected precision
|
# In mixed precision mode, ensure features match the expected precision
|
||||||
features = features.to(torch.float32)
|
features = features.to(torch.float32)
|
||||||
|
|
||||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
# Forward pass: enable full adversarial mode if configured and past warmup
|
||||||
|
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
|
||||||
|
if use_full:
|
||||||
|
clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
|
||||||
|
else:
|
||||||
|
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||||
|
|
||||||
# Calculate CTC Loss
|
# Calculate CTC Loss
|
||||||
loss = self.ctc_loss(
|
if use_full:
|
||||||
log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
|
# Clean CTC loss
|
||||||
targets = labels,
|
clean_loss = self.ctc_loss(
|
||||||
input_lengths = adjusted_lens,
|
torch.permute(clean_logits.log_softmax(2), [1, 0, 2]),
|
||||||
target_lengths = phone_seq_lens
|
labels,
|
||||||
|
adjusted_lens,
|
||||||
|
phone_seq_lens
|
||||||
)
|
)
|
||||||
|
clean_loss = torch.mean(clean_loss)
|
||||||
|
|
||||||
loss = torch.mean(loss) # take mean loss over batches
|
# Noisy branch CTC loss (makes the noisy branch more recognizable; through the GRL this becomes adversarial for NoiseModel)
|
||||||
|
noisy_loss = self.ctc_loss(
|
||||||
|
torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]),
|
||||||
|
labels,
|
||||||
|
adjusted_lens,
|
||||||
|
phone_seq_lens
|
||||||
|
)
|
||||||
|
noisy_loss = torch.mean(noisy_loss)
|
||||||
|
|
||||||
|
# Optional noise energy regularization
|
||||||
|
noise_l2 = torch.tensor(0.0, device=self.device)
|
||||||
|
if self.adv_noise_l2_weight > 0.0:
|
||||||
|
noise_l2 = torch.mean(noise_output.pow(2))
|
||||||
|
|
||||||
|
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
||||||
|
else:
|
||||||
|
loss = self.ctc_loss(
|
||||||
|
log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
|
||||||
|
targets = labels,
|
||||||
|
input_lengths = adjusted_lens,
|
||||||
|
target_lengths = phone_seq_lens
|
||||||
|
)
|
||||||
|
loss = torch.mean(loss) # take mean loss over batches
|
||||||
|
|
||||||
# Use Accelerator's backward for distributed training
|
# Use Accelerator's backward for distributed training
|
||||||
self.accelerator.backward(loss)
|
self.accelerator.backward(loss)
|
||||||
@@ -673,7 +713,7 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
# Optionally save this validation checkpoint, regardless of performance
|
# Optionally save this validation checkpoint, regardless of performance
|
||||||
if self.args['save_all_val_steps']:
|
if self.args['save_all_val_steps']:
|
||||||
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'])
|
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
|
||||||
|
|
||||||
# Early stopping
|
# Early stopping
|
||||||
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
|
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
|
||||||
@@ -689,7 +729,8 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
# Save final model
|
# Save final model
|
||||||
if self.args['save_final_model']:
|
if self.args['save_final_model']:
|
||||||
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1])
|
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
|
||||||
|
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
|
||||||
|
|
||||||
train_stats = {}
|
train_stats = {}
|
||||||
train_stats['train_losses'] = train_losses
|
train_stats['train_losses'] = train_losses
|
||||||
|
Reference in New Issue
Block a user