This commit is contained in:
Zchen
2025-10-17 11:38:57 +08:00
parent 6c7abfcca8
commit eb4e3fc69f

View File

@@ -10,6 +10,13 @@ import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf
# For accurate PER calculation
try:
import editdistance
except ImportError:
print("Warning: editdistance not available, falling back to approximation")
editdistance = None
from rnn_model_tf import (
TripleGRUDecoder,
CTCLoss,
@@ -99,6 +106,30 @@ class BrainToTextDecoderTrainerTF:
self.lr_scheduler = self._create_lr_scheduler()
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
# Create unified checkpoint management
self.ckpt = tf.train.Checkpoint(
optimizer=self.optimizer,
model=self.model
)
self.ckpt_manager = tf.train.CheckpointManager(
self.ckpt,
directory=self.args['checkpoint_dir'],
max_to_keep=5 # Keep only the 5 most recent checkpoints
)
# Try to restore from latest checkpoint
if self.ckpt_manager.latest_checkpoint:
print(f"🔄 Restoring from {self.ckpt_manager.latest_checkpoint}")
if self.logger:
self.logger.info(f"Restoring from {self.ckpt_manager.latest_checkpoint}")
# expect_partial() avoids failures due to model structure changes
self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial()
print("✅ Checkpoint restored successfully")
else:
print("🆕 Initializing training from scratch")
if self.logger:
self.logger.info("Initializing training from scratch")
# Log model information
self._log_model_info()
@@ -110,12 +141,10 @@ class BrainToTextDecoderTrainerTF:
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
# Manual weight decay handling for all environments (since we use Adam)
# Manual weight decay handling - disable since AdamW handles it
self.manual_weight_decay = False
if self.args.get('weight_decay', 0.0) > 0:
self.manual_weight_decay = True
self.weight_decay_rate = self.args['weight_decay']
print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}")
else:
print("💡 No weight decay configured")
@@ -402,39 +431,18 @@ class BrainToTextDecoderTrainerTF:
return model
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Create AdamW optimizer with parameter groups"""
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
# We'll use a single optimizer and handle different learning rates in the scheduler
"""Create AdamW optimizer"""
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
print("Using AdamW optimizer for TPU training")
# For TPU training, we need to be more explicit about optimizer configuration
# to avoid strategy context issues
# IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
# AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
# We implement manual L2 regularization (weight decay) in the training step instead
print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
print("💡 Manual L2 regularization will be applied in training step")
# Use legacy Adam optimizer for better TPU distributed training compatibility
# Legacy optimizers have more stable distributed training implementations
try:
optimizer = tf.keras.optimizers.legacy.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("✅ Using legacy Adam optimizer for better TPU compatibility")
except AttributeError:
# Fallback to standard Adam if legacy is not available
optimizer = tf.keras.optimizers.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("⚠️ Using standard Adam optimizer (legacy not available)")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=self.args.get('weight_decay', 0.0)
)
print("✅ Using AdamW optimizer")
return optimizer
@@ -486,6 +494,7 @@ class BrainToTextDecoderTrainerTF:
else:
print(f"Model has {total_params:,} trainable parameters")
@tf.function
def _train_step(self, batch, step):
"""Single training step with gradient tape"""
features = batch['input_features']
@@ -554,16 +563,7 @@ class BrainToTextDecoderTrainerTF:
loss = self.ctc_loss(loss_input, clean_logits)
loss = tf.reduce_mean(loss)
# Add manual L2 regularization for TPU (since weight_decay is disabled)
if self.manual_weight_decay:
l2_loss = tf.constant(0.0, dtype=loss.dtype)
for var in self.model.trainable_variables:
# Ensure dtype consistency for mixed precision training
var_l2 = tf.nn.l2_loss(var)
var_l2 = tf.cast(var_l2, dtype=loss.dtype) # Cast to match loss dtype
l2_loss += var_l2
loss += self.weight_decay_rate * l2_loss
# AdamW handles weight decay automatically - no manual L2 regularization needed
# TensorFlow混合精度处理不需要手动scalingKeras policy自动处理
# TPU v5e-8使用bfloat16不需要loss scaling
@@ -599,7 +599,7 @@ class BrainToTextDecoderTrainerTF:
@tf.function
def _validation_step(self, batch):
"""Single validation step"""
"""Single validation step - returns data for accurate PER calculation"""
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
@@ -630,38 +630,11 @@ class BrainToTextDecoderTrainerTF:
loss = self.ctc_loss(loss_input, logits)
loss = tf.reduce_mean(loss)
# Calculate simplified PER approximation (TPU-compatible)
# For TPU training, we use a simplified metric that avoids complex loops
# This gives an approximation of PER but is much faster and TPU-compatible
# Greedy decoding for PER calculation
predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
# Greedy decoding
predicted_ids = tf.argmax(logits, axis=-1)
# Simple approximation: count exact matches vs mismatches
# This is less accurate than true edit distance but TPU-compatible
batch_size = tf.shape(logits)[0]
# For each sample, compare predicted vs true sequences
total_mismatches = tf.constant(0, dtype=tf.int32)
for i in tf.range(batch_size):
# Get sequences for this sample
pred_seq = predicted_ids[i, :adjusted_lens[i]]
true_seq = labels[i, :phone_seq_lens[i]]
# Pad to same length for comparison
max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0])
pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0)
true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0)
# Count mismatches
mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32))
total_mismatches += mismatches
# Approximate edit distance as number of mismatches
batch_edit_distance = tf.cast(total_mismatches, tf.float32)
return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32)
# Return all necessary data for accurate PER calculation on CPU
return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens
def train(self) -> Dict[str, Any]:
"""Main training loop"""
@@ -671,28 +644,28 @@ class BrainToTextDecoderTrainerTF:
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# Create distributed datasets
train_dataset = create_input_fn(
self.train_dataset_tf,
self.args['dataset']['data_transforms'],
training=True
)
val_dataset = create_input_fn(
self.val_dataset_tf,
self.args['dataset']['data_transforms'],
training=False
)
# Distribute datasets with timing
# Create datasets using modern distribution API
def create_dist_dataset_fn(input_dataset_tf, training):
"""Create distributed dataset function for modern TPU strategy"""
def dataset_fn(input_context):
# create_input_fn returns a complete, batched tf.data.Dataset
return create_input_fn(
input_dataset_tf,
self.args['dataset']['data_transforms'],
training=training
)
return self.strategy.distribute_datasets_from_function(dataset_fn)
# Distribute datasets using modern API
self.logger.info("🔄 Distributing training dataset across TPU cores...")
dist_start_time = time.time()
train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
train_dist_time = time.time() - dist_start_time
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
val_start_time = time.time()
val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
val_dist_time = time.time() - val_start_time
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
@@ -709,18 +682,13 @@ class BrainToTextDecoderTrainerTF:
# Training loop
step = 0
# Add timing diagnostic for first batch iteration
self.logger.info("🔄 Starting training loop iteration...")
loop_start_time = time.time()
self.logger.info("🔄 Starting training loop...")
for batch in train_dist_dataset:
if step == 0:
first_batch_iteration_time = time.time() - loop_start_time
self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
if step >= self.args['num_training_batches']:
self.logger.info("Reached maximum training batches, stopping training")
break
start_time = time.time()
# Distributed training step
@@ -794,7 +762,7 @@ class BrainToTextDecoderTrainerTF:
if new_best:
if self.args.get('save_best_checkpoint', True):
self.logger.info("Checkpointing model")
self._save_checkpoint('best_checkpoint', step)
self._save_checkpoint(step)
if self.args.get('save_val_metrics', True):
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
@@ -806,7 +774,7 @@ class BrainToTextDecoderTrainerTF:
# Optional save all validation checkpoints
if self.args.get('save_all_val_steps', False):
self._save_checkpoint(f'checkpoint_batch_{step}', step)
self._save_checkpoint(step)
# Early stopping
if (self.args.get('early_stopping', False) and
@@ -825,7 +793,7 @@ class BrainToTextDecoderTrainerTF:
# Save final model
if self.args.get('save_final_model', False):
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1)
self._save_checkpoint(step-1)
return {
'train_losses': train_losses,
@@ -835,25 +803,58 @@ class BrainToTextDecoderTrainerTF:
}
def _validate(self, val_dataset) -> Dict[str, Any]:
"""Run validation on entire validation dataset"""
"""Run validation on entire validation dataset with accurate PER calculation"""
total_loss = 0.0
total_edit_distance = 0
total_seq_length = 0
num_batches = 0
for batch in val_dataset:
per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
# Get predictions and labels from all TPU cores
per_replica_losses, per_replica_preds, per_replica_labels, per_replica_pred_lens, per_replica_label_lens = (
self.strategy.run(self._validation_step, args=(batch,))
)
# Reduce across replicas
# Reduce loss across replicas
batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)
total_loss += float(batch_loss.numpy())
total_edit_distance += float(batch_edit_distance.numpy())
total_seq_length += float(batch_seq_length.numpy())
# Gather all data from TPU cores to CPU for accurate PER calculation
all_preds = self.strategy.gather(per_replica_preds, axis=0)
all_labels = self.strategy.gather(per_replica_labels, axis=0)
all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0)
all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0)
# Calculate accurate edit distance on CPU
batch_size = all_preds.shape[0]
for i in range(batch_size):
pred_len = int(all_pred_lens[i].numpy())
label_len = int(all_label_lens[i].numpy())
# Extract sequences and remove CTC blanks (assuming blank_index=0)
pred_seq = all_preds[i, :pred_len].numpy()
pred_seq = [p for p in pred_seq if p != 0] # Remove blanks
# Remove consecutive duplicates (CTC decoding)
if len(pred_seq) > 0:
deduped_pred = [pred_seq[0]]
for j in range(1, len(pred_seq)):
if pred_seq[j] != pred_seq[j-1]:
deduped_pred.append(pred_seq[j])
pred_seq = deduped_pred
true_seq = all_labels[i, :label_len].numpy().tolist()
# Calculate edit distance using proper library if available
if editdistance is not None:
edit_dist = editdistance.eval(pred_seq, true_seq)
else:
# Fallback to simple approximation if editdistance not available
edit_dist = self._simple_edit_distance(pred_seq, true_seq)
total_edit_distance += edit_dist
total_seq_length += label_len
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
@@ -867,50 +868,95 @@ class BrainToTextDecoderTrainerTF:
'num_batches': num_batches
}
def _save_checkpoint(self, name: str, step: int):
"""Save model checkpoint"""
checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)
def _simple_edit_distance(self, seq1, seq2):
"""Simple edit distance implementation as fallback"""
# Dynamic programming implementation of edit distance
m, n = len(seq1), len(seq2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Save model weights
self.model.save_weights(checkpoint_path + '.weights.h5')
# Initialize base cases
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
# Save optimizer state
optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
optimizer_checkpoint.save(checkpoint_path + '.optimizer')
# Fill the DP table
for i in range(1, m + 1):
for j in range(1, n + 1):
if seq1[i-1] == seq2[j-1]:
dp[i][j] = dp[i-1][j-1]
else:
dp[i][j] = 1 + min(
dp[i-1][j], # deletion
dp[i][j-1], # insertion
dp[i-1][j-1] # substitution
)
# Save training state
return dp[m][n]
def _save_checkpoint(self, step: int, name: str = ""):
"""Save checkpoint using the unified CheckpointManager"""
# CheckpointManager automatically handles naming and numbering
# The 'name' parameter is kept for backward compatibility but not used
save_path = self.ckpt_manager.save(checkpoint_number=step)
if self.logger:
self.logger.info(f"Saved checkpoint for step {step}: {save_path}")
else:
print(f"Saved checkpoint for step {step}: {save_path}")
# Save non-TensorFlow Python state separately
state = {
'step': step,
'best_val_per': float(self.best_val_per),
'best_val_loss': float(self.best_val_loss)
}
with open(checkpoint_path + '.state.json', 'w') as f:
# Associate state file with checkpoint
state_path = os.path.join(self.args['checkpoint_dir'], f'state-{step}.json')
with open(state_path, 'w') as f:
json.dump(state, f)
# Save config
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
self.logger.info(f"Saved checkpoint: {checkpoint_path}")
# Save config file (only once)
config_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml')
if not os.path.exists(config_path):
with open(config_path, 'w') as f:
OmegaConf.save(config=self.args, f=f)
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint"""
# Load model weights
self.model.load_weights(checkpoint_path + '.weights.h5')
"""Load a specific checkpoint and its associated training state"""
if self.logger:
self.logger.info(f"Loading checkpoint from: {checkpoint_path}")
else:
print(f"Loading checkpoint from: {checkpoint_path}")
# Load optimizer state
optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')
# Restore TensorFlow objects (model and optimizer)
self.ckpt.restore(checkpoint_path).expect_partial()
# Load training state
with open(checkpoint_path + '.state.json', 'r') as f:
state = json.load(f)
# Restore non-TensorFlow training state
try:
# Extract step number from checkpoint path (e.g., ckpt-123 -> 123)
step = int(checkpoint_path.split('-')[-1])
state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json')
self.best_val_per = state['best_val_per']
self.best_val_loss = state['best_val_loss']
with open(state_path, 'r') as f:
state = json.load(f)
self.logger.info(f"Loaded checkpoint: {checkpoint_path}")
self.best_val_per = state['best_val_per']
self.best_val_loss = state['best_val_loss']
if self.logger:
self.logger.info(f"Restored training state from: {state_path}")
else:
print(f"Restored training state from: {state_path}")
except (IOError, ValueError, KeyError) as e:
warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. "
f"Starting with fresh state. Error: {e}")
if self.logger:
self.logger.warning(warning_msg)
else:
print(f"⚠️ {warning_msg}")
def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor: