Fix the bug where the B model was not enabled
@@ -86,6 +86,14 @@ class BrainToTextDecoder_Trainer:
         self.transform_args = self.args['dataset']['data_transforms']

+        # Adversarial training config (safe defaults if not provided)
+        adv_cfg = self.args.get('adversarial', {})
+        self.adv_enabled = adv_cfg.get('enabled', False)
+        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))  # GRL strength
+        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))  # weight for noisy branch CTC
+        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))  # optional L2 on noise output
+        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))  # delay enabling adversarial after N steps
+
         # Create output directory
         if args['mode'] == 'train':
             os.makedirs(self.args['output_dir'], exist_ok=True)
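For reference, the keys read above correspond to a config section shaped like the sketch below. This is illustrative only: the key names are taken from the `adv_cfg.get(...)` calls in the hunk, while the values and the assumption that `self.args` is a plain dict are hypothetical.

# Hypothetical example of the 'adversarial' section consumed above;
# the values shown (other than 'enabled') are the code's own fallback defaults.
args = {
    'adversarial': {
        'enabled': True,           # turn the adversarial branch on
        'grl_lambda': 0.5,         # gradient-reversal strength
        'noisy_loss_weight': 0.2,  # weight of the noisy-branch CTC term
        'noise_l2_weight': 0.0,    # optional L2 penalty on the noise output
        'warmup_steps': 0,         # steps to wait before enabling adversarial mode
    },
}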
@@ -291,6 +299,8 @@ class BrainToTextDecoder_Trainer:
         )

         self.logger.info("Prepared model and dataloaders with Accelerator")
+        if self.adv_enabled:
+            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")

     def create_optimizer(self):
         '''
@@ -583,17 +593,47 @@ class BrainToTextDecoder_Trainer:
             # In mixed precision mode, ensure features match the expected precision
             features = features.to(torch.float32)

-            logits = self.model(features, day_indicies, None, False, 'inference')
+            # Forward pass: enable full adversarial mode if configured and past warmup
+            use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
+            if use_full:
+                clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
+            else:
+                logits = self.model(features, day_indicies, None, False, 'inference')

-            # Calculate CTC Loss
-            loss = self.ctc_loss(
-                log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
-                targets = labels,
-                input_lengths = adjusted_lens,
-                target_lengths = phone_seq_lens
-            )
-
-            loss = torch.mean(loss) # take mean loss over batches
+            if use_full:
+                # Clean CTC loss
+                clean_loss = self.ctc_loss(
+                    torch.permute(clean_logits.log_softmax(2), [1, 0, 2]),
+                    labels,
+                    adjusted_lens,
+                    phone_seq_lens
+                )
+                clean_loss = torch.mean(clean_loss)
+
+                # Noisy branch CTC loss (keeps the noisy branch decodable, but through the GRL it acts adversarially on the NoiseModel)
+                noisy_loss = self.ctc_loss(
+                    torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]),
+                    labels,
+                    adjusted_lens,
+                    phone_seq_lens
+                )
+                noisy_loss = torch.mean(noisy_loss)
+
+                # Optional noise energy regularization
+                noise_l2 = torch.tensor(0.0, device=self.device)
+                if self.adv_noise_l2_weight > 0.0:
+                    noise_l2 = torch.mean(noise_output.pow(2))
+
+                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
+            else:
+                loss = self.ctc_loss(
+                    log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
+                    targets = labels,
+                    input_lengths = adjusted_lens,
+                    target_lengths = phone_seq_lens
+                )
+                loss = torch.mean(loss) # take mean loss over batches

             # Use Accelerator's backward for distributed training
             self.accelerator.backward(loss)
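The `'full'` mode and the `grl_lambda` argument refer to a gradient reversal layer (GRL) inside the model, which is not part of this diff. For readers unfamiliar with the technique: a GRL is the identity in the forward pass and negates (and scales) gradients in the backward pass, so a branch trained to minimize a loss simultaneously trains the module behind the GRL to maximize it. The snippet below is a generic PyTorch sketch of a GRL under that assumption, not the repository's actual implementation; the names `GradientReversal` and `grad_reverse` are illustrative.

import torch

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by -lambda in backward."""

    @staticmethod
    def forward(ctx, x, grl_lambda):
        ctx.grl_lambda = grl_lambda
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient so that minimizing the noisy-branch
        # CTC loss downstream *maximizes* it with respect to the parameters
        # upstream of the GRL (the adversarial objective).
        return -ctx.grl_lambda * grad_output, None

def grad_reverse(x, grl_lambda=0.5):
    return GradientReversal.apply(x, grl_lambda)

In that reading, the model would route the NoiseModel's output through `grad_reverse(...)` before it reaches the noisy recognition branch: `clean_loss` trains the recognizer normally, while `noisy_loss` pushes the NoiseModel toward perturbations the recognizer cannot exploit, which is consistent with the comment on the noisy-branch CTC loss above.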
@@ -673,7 +713,7 @@ class BrainToTextDecoder_Trainer:

             # Optionally save this validation checkpoint, regardless of performance
             if self.args['save_all_val_steps']:
-                self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'])
+                self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])

             # Early stopping
             if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
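Both checkpoint call sites now pass the validation loss alongside the PER. `save_model_checkpoint` itself is outside this diff, so the updated signature implied by the calls is presumably something like the hypothetical sketch below (only the argument order is grounded in the call sites; the parameter names are assumptions):

# Hypothetical signature implied by the updated call sites above;
# the real implementation is not shown in this commit.
def save_model_checkpoint(self, save_path: str, PER: float, val_loss: float):
    ...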
@@ -689,7 +729,8 @@ class BrainToTextDecoder_Trainer:

         # Save final model
         if self.args['save_final_model']:
-            self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1])
+            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
+            self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)

         train_stats = {}
         train_stats['train_losses'] = train_losses
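A note on the `torch.permute(..., [1, 0, 2])` calls in the training hunk: `torch.nn.CTCLoss` expects log-probabilities in time-major shape (T, N, C), while the model emits batch-major (N, T, C). The snippet below is a self-contained illustration of that convention with made-up sizes; it is not code from this repository.

import torch

# Illustrative shapes only: T time steps, N batch items, C classes.
T, N, C = 50, 4, 41
ctc = torch.nn.CTCLoss(blank=0, reduction='none')  # 'none' matches the explicit torch.mean() in the trainer

logits = torch.randn(N, T, C)                                # model output, batch-major
log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2])  # -> (T, N, C), as CTCLoss expects
targets = torch.randint(1, C, (N, 30), dtype=torch.long)     # dummy label sequences
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 30, dtype=torch.long)

loss = ctc(log_probs, targets, input_lengths, target_lengths)  # per-item losses, shape (N,)
loss = torch.mean(loss)                                        # mean over the batch, as in the trainer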