@@ -10,6 +10,13 @@ import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf

# For accurate PER calculation
try:
    import editdistance
except ImportError:
    print("Warning: editdistance not available, falling back to approximation")
    editdistance = None

from rnn_model_tf import (
    TripleGRUDecoder,
    CTCLoss,
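# Illustrative sketch, not from this commit: how editdistance is typically used to
# compute a phoneme error rate once predictions and references are plain Python lists.
# The phoneme IDs below are made up.
import editdistance

pred_ids = [12, 5, 9]             # decoded phonemes, CTC blanks already removed
ref_ids = [12, 5, 9, 20]          # reference phonemes
per = editdistance.eval(pred_ids, ref_ids) / max(len(ref_ids), 1)
print(f"PER = {per:.3f}")         # Levenshtein distance / reference length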
@@ -99,6 +106,30 @@ class BrainToTextDecoderTrainerTF:
        self.lr_scheduler = self._create_lr_scheduler()
        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Create unified checkpoint management
        self.ckpt = tf.train.Checkpoint(
            optimizer=self.optimizer,
            model=self.model
        )
        self.ckpt_manager = tf.train.CheckpointManager(
            self.ckpt,
            directory=self.args['checkpoint_dir'],
            max_to_keep=5  # Keep only the 5 most recent checkpoints
        )

        # Try to restore from latest checkpoint
        if self.ckpt_manager.latest_checkpoint:
            print(f"🔄 Restoring from {self.ckpt_manager.latest_checkpoint}")
            if self.logger:
                self.logger.info(f"Restoring from {self.ckpt_manager.latest_checkpoint}")
            # expect_partial() avoids failures due to model structure changes
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial()
            print("✅ Checkpoint restored successfully")
        else:
            print("🆕 Initializing training from scratch")
            if self.logger:
                self.logger.info("Initializing training from scratch")

        # Log model information
        self._log_model_info()
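# Illustrative sketch, not from this commit: the restore-or-start-fresh pattern above,
# reduced to a standalone example with a toy model; the directory name is an assumption.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam(1e-3)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, directory="./checkpoints", max_to_keep=5)
if manager.latest_checkpoint:
    # expect_partial() suppresses warnings when the object graph changed between runs
    ckpt.restore(manager.latest_checkpoint).expect_partial()
else:
    print("No checkpoint found, starting fresh")
manager.save(checkpoint_number=0)   # writes ./checkpoints/ckpt-0.*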
@@ -110,12 +141,10 @@ class BrainToTextDecoderTrainerTF:
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))

        # Manual weight decay handling for all environments (since we use Adam)
        # Manual weight decay handling - disable since AdamW handles it
        self.manual_weight_decay = False
        if self.args.get('weight_decay', 0.0) > 0:
            self.manual_weight_decay = True
            self.weight_decay_rate = self.args['weight_decay']
            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
            print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}")
        else:
            print("💡 No weight decay configured")
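# Illustrative sketch, not from this commit: the two regularization styles this change
# switches between, shown on a toy variable with arbitrary values.
import tensorflow as tf

w = tf.Variable([[1.0, -2.0]])
decay = 1e-2

# (a) Manual L2 (the old path): add the penalty to the loss so plain Adam sees it.
with tf.GradientTape() as tape:
    base_loss = tf.reduce_sum(w ** 2)
    loss_with_l2 = base_loss + decay * tf.nn.l2_loss(w)

# (b) Decoupled weight decay (the new path): give AdamW the plain loss and let the
# optimizer shrink the weights at update time, outside the gradient statistics.
opt = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=decay)
opt.apply_gradients(zip(tape.gradient(base_loss, [w]), [w]))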
@@ -402,39 +431,18 @@ class BrainToTextDecoderTrainerTF:
        return model

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create AdamW optimizer with parameter groups"""
        # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
        # We'll use a single optimizer and handle different learning rates in the scheduler

        """Create AdamW optimizer"""
        print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
        print("Using AdamW optimizer for TPU training")

        # For TPU training, we need to be more explicit about optimizer configuration
        # to avoid strategy context issues
        # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
        # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
        # We implement manual L2 regularization (weight decay) in the training step instead
        print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
        print("💡 Manual L2 regularization will be applied in training step")

        # Use legacy Adam optimizer for better TPU distributed training compatibility
        # Legacy optimizers have more stable distributed training implementations
        try:
            optimizer = tf.keras.optimizers.legacy.Adam(
                learning_rate=self.args['lr_max'],
                beta_1=self.args['beta0'],
                beta_2=self.args['beta1'],
                epsilon=self.args['epsilon']
            )
            print("✅ Using legacy Adam optimizer for better TPU compatibility")
        except AttributeError:
            # Fallback to standard Adam if legacy is not available
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.args['lr_max'],
                beta_1=self.args['beta0'],
                beta_2=self.args['beta1'],
                epsilon=self.args['epsilon']
            )
            print("⚠️ Using standard Adam optimizer (legacy not available)")
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=self.args['lr_max'],
            beta_1=self.args['beta0'],
            beta_2=self.args['beta1'],
            epsilon=self.args['epsilon'],
            weight_decay=self.args.get('weight_decay', 0.0)
        )
        print("✅ Using AdamW optimizer")

        return optimizer
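# Illustrative sketch, not from this commit: with AdamW it is common to keep biases and
# normalization parameters out of weight decay. Recent Keras releases expose
# exclude_from_weight_decay for this; the hasattr guard treats its availability as an
# assumption, and the name substrings below are illustrative.
import tensorflow as tf

opt = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-2)
if hasattr(opt, "exclude_from_weight_decay"):
    # Match variables whose names contain these substrings.
    opt.exclude_from_weight_decay(var_names=["bias", "layer_norm"])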
@@ -486,6 +494,7 @@ class BrainToTextDecoderTrainerTF:
        else:
            print(f"Model has {total_params:,} trainable parameters")

    @tf.function
    def _train_step(self, batch, step):
        """Single training step with gradient tape"""
        features = batch['input_features']
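# Illustrative sketch, not from this commit: pinning an input_signature keeps a
# @tf.function from retracing when padded batch shapes vary; the feature width of 512
# is an assumption, not the trainer's real dimension.
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 512], dtype=tf.float32)])
def forward(features):
    return tf.reduce_mean(features, axis=1)   # stand-in for the model call

print(forward(tf.zeros([2, 10, 512])).shape)  # (2, 512); other lengths reuse the same trace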
@@ -554,16 +563,7 @@ class BrainToTextDecoderTrainerTF:
            loss = self.ctc_loss(loss_input, clean_logits)
            loss = tf.reduce_mean(loss)

            # Add manual L2 regularization for TPU (since weight_decay is disabled)
            if self.manual_weight_decay:
                l2_loss = tf.constant(0.0, dtype=loss.dtype)
                for var in self.model.trainable_variables:
                    # Ensure dtype consistency for mixed precision training
                    var_l2 = tf.nn.l2_loss(var)
                    var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
                    l2_loss += var_l2
                loss += self.weight_decay_rate * l2_loss

            # AdamW handles weight decay automatically - no manual L2 regularization needed
            # TensorFlow mixed precision: no manual loss scaling needed, the Keras policy handles it
            # TPU v5e-8 uses bfloat16, which does not require loss scaling
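# Illustrative sketch, not from this commit: the mixed-precision setup those comments
# describe. With 'mixed_bfloat16' on TPU no loss scaling is needed, unlike 'mixed_float16'
# on GPU where a LossScaleOptimizer is the usual companion.
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)   # bfloat16 float32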
@@ -599,7 +599,7 @@ class BrainToTextDecoderTrainerTF:

    @tf.function
    def _validation_step(self, batch):
        """Single validation step"""
        """Single validation step - returns data for accurate PER calculation"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
@@ -630,38 +630,11 @@ class BrainToTextDecoderTrainerTF:
        loss = self.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Calculate simplified PER approximation (TPU-compatible)
        # For TPU training, we use a simplified metric that avoids complex loops
        # This gives an approximation of PER but is much faster and TPU-compatible
        # Greedy decoding for PER calculation
        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)

        # Greedy decoding
        predicted_ids = tf.argmax(logits, axis=-1)

        # Simple approximation: count exact matches vs mismatches
        # This is less accurate than true edit distance but TPU-compatible
        batch_size = tf.shape(logits)[0]

        # For each sample, compare predicted vs true sequences
        total_mismatches = tf.constant(0, dtype=tf.int32)

        for i in tf.range(batch_size):
            # Get sequences for this sample
            pred_seq = predicted_ids[i, :adjusted_lens[i]]
            true_seq = labels[i, :phone_seq_lens[i]]

            # Pad to same length for comparison
            max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0])
            pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0)
            true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0)

            # Count mismatches
            mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32))
            total_mismatches += mismatches

        # Approximate edit distance as number of mismatches
        batch_edit_distance = tf.cast(total_mismatches, tf.float32)

        return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32)
        # Return all necessary data for accurate PER calculation on CPU
        return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens

    def train(self) -> Dict[str, Any]:
        """Main training loop"""
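# Illustrative sketch, not from this commit: TensorFlow also ships a built-in greedy CTC
# decoder that merges repeats and drops blanks. It expects time-major logits and, by
# default, treats the last class as the blank, so a model with blank_index=0 either needs
# the blank_index argument (available in recent TF versions) or a manual collapse like the
# one later in this file. Shapes below are made up.
import tensorflow as tf

logits = tf.random.normal([50, 4, 41])      # [time, batch, num_classes]
seq_len = tf.fill([4], 50)
decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
dense = tf.sparse.to_dense(decoded[0])      # [batch, max_decoded_len]
print(dense.shape)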
@@ -671,28 +644,28 @@ class BrainToTextDecoderTrainerTF:
        initial_tpu_status = self._get_detailed_tpu_status()
        self.logger.info(f"Initial TPU Status: {initial_tpu_status}")

        # Create distributed datasets
        train_dataset = create_input_fn(
            self.train_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=True
        )

        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )
        # Distribute datasets with timing
        # Create datasets using modern distribution API
        def create_dist_dataset_fn(input_dataset_tf, training):
            """Create distributed dataset function for modern TPU strategy"""
            def dataset_fn(input_context):
                # create_input_fn returns a complete, batched tf.data.Dataset
                return create_input_fn(
                    input_dataset_tf,
                    self.args['dataset']['data_transforms'],
                    training=training
                )
            return self.strategy.distribute_datasets_from_function(dataset_fn)

        # Distribute datasets using modern API
        self.logger.info("🔄 Distributing training dataset across TPU cores...")
        dist_start_time = time.time()
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
        train_dist_time = time.time() - dist_start_time
        self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")

        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
        val_start_time = time.time()
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
        val_dist_time = time.time() - val_start_time
        self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
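# Illustrative sketch, not from this commit: distribute_datasets_from_function hands each
# input pipeline an InputContext. The trainer's dataset_fn ignores it, but a common pattern
# is to use it for per-replica batch sizes and sharding, as below (toy dataset, assumed
# global batch size of 64).
import tensorflow as tf

def dataset_fn(input_context):
    per_replica_batch = input_context.get_per_replica_batch_size(64)
    ds = tf.data.Dataset.range(1024).shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)
    return ds.batch(per_replica_batch)

strategy = tf.distribute.get_strategy()   # default strategy here, TPUStrategy in the trainer
dist_ds = strategy.distribute_datasets_from_function(dataset_fn)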
@@ -709,18 +682,13 @@ class BrainToTextDecoderTrainerTF:
        # Training loop
        step = 0

        # Add timing diagnostic for first batch iteration
        self.logger.info("🔄 Starting training loop iteration...")
        loop_start_time = time.time()
        self.logger.info("🔄 Starting training loop...")

        for batch in train_dist_dataset:
            if step == 0:
                first_batch_iteration_time = time.time() - loop_start_time
                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
            if step >= self.args['num_training_batches']:
                self.logger.info("Reached maximum training batches, stopping training")
                break


            start_time = time.time()

            # Distributed training step
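# Illustrative sketch, not from this commit: the run-then-reduce shape of a distributed
# step, with a toy replica function. In the trainer, strategy is a TPUStrategy and the
# replica function is _train_step.
import tensorflow as tf

strategy = tf.distribute.get_strategy()

@tf.function
def replica_step(x):
    return tf.reduce_mean(x)            # stand-in for the per-replica loss

def distributed_step(batch):
    per_replica = strategy.run(replica_step, args=(batch,))
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

print(distributed_step(tf.ones([8, 4])))  # 1.0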
@@ -794,7 +762,7 @@ class BrainToTextDecoderTrainerTF:
                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)
                        self._save_checkpoint(step)

                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
@@ -806,7 +774,7 @@ class BrainToTextDecoderTrainerTF:

                # Optional save all validation checkpoints
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)
                    self._save_checkpoint(step)

                # Early stopping
                if (self.args.get('early_stopping', False) and
@@ -825,7 +793,7 @@ class BrainToTextDecoderTrainerTF:
        # Save final model
        if self.args.get('save_final_model', False):
            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
            self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1)
            self._save_checkpoint(step-1)

        return {
            'train_losses': train_losses,
@@ -835,25 +803,58 @@ class BrainToTextDecoderTrainerTF:
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation on entire validation dataset"""
        """Run validation on entire validation dataset with accurate PER calculation"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0

        for batch in val_dataset:
            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
            # Get predictions and labels from all TPU cores
            per_replica_losses, per_replica_preds, per_replica_labels, per_replica_pred_lens, per_replica_label_lens = (
                self.strategy.run(self._validation_step, args=(batch,))
            )

            # Reduce across replicas
            # Reduce loss across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)

            total_loss += float(batch_loss.numpy())
            total_edit_distance += float(batch_edit_distance.numpy())
            total_seq_length += float(batch_seq_length.numpy())

            # Gather all data from TPU cores to CPU for accurate PER calculation
            all_preds = self.strategy.gather(per_replica_preds, axis=0)
            all_labels = self.strategy.gather(per_replica_labels, axis=0)
            all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0)
            all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0)

            # Calculate accurate edit distance on CPU
            batch_size = all_preds.shape[0]
            for i in range(batch_size):
                pred_len = int(all_pred_lens[i].numpy())
                label_len = int(all_label_lens[i].numpy())

                # Extract sequences and remove CTC blanks (assuming blank_index=0)
                pred_seq = all_preds[i, :pred_len].numpy()
                pred_seq = [p for p in pred_seq if p != 0]  # Remove blanks

                # Remove consecutive duplicates (CTC decoding)
                if len(pred_seq) > 0:
                    deduped_pred = [pred_seq[0]]
                    for j in range(1, len(pred_seq)):
                        if pred_seq[j] != pred_seq[j-1]:
                            deduped_pred.append(pred_seq[j])
                    pred_seq = deduped_pred

                true_seq = all_labels[i, :label_len].numpy().tolist()

                # Calculate edit distance using proper library if available
                if editdistance is not None:
                    edit_dist = editdistance.eval(pred_seq, true_seq)
                else:
                    # Fallback to simple approximation if editdistance not available
                    edit_dist = self._simple_edit_distance(pred_seq, true_seq)

                total_edit_distance += edit_dist
                total_seq_length += label_len

            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
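# Illustrative sketch, not from this commit: a compact version of the greedy-CTC collapse
# done above. The textbook order is to merge repeats first and then drop blanks, so repeats
# separated by a blank survive as two symbols; blank id 0 matches this trainer's CTCLoss.
def ctc_collapse(ids, blank=0):
    out, prev = [], None
    for t in ids:
        if t != prev and t != blank:    # merge consecutive repeats, then drop blanks
            out.append(t)
        prev = t
    return out

print(ctc_collapse([0, 7, 7, 0, 7, 3, 3, 0]))   # [7, 7, 3]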
@@ -867,50 +868,95 @@ class BrainToTextDecoderTrainerTF:
            'num_batches': num_batches
        }

    def _save_checkpoint(self, name: str, step: int):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)
    def _simple_edit_distance(self, seq1, seq2):
        """Simple edit distance implementation as fallback"""
        # Dynamic programming implementation of edit distance
        m, n = len(seq1), len(seq2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')
        # Initialize base cases
        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j

        # Save optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.save(checkpoint_path + '.optimizer')
        # Fill the DP table
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if seq1[i-1] == seq2[j-1]:
                    dp[i][j] = dp[i-1][j-1]
                else:
                    dp[i][j] = 1 + min(
                        dp[i-1][j],    # deletion
                        dp[i][j-1],    # insertion
                        dp[i-1][j-1]   # substitution
                    )

        # Save training state
        return dp[m][n]

    def _save_checkpoint(self, step: int, name: str = ""):
        """Save checkpoint using the unified CheckpointManager"""
        # CheckpointManager automatically handles naming and numbering
        # The 'name' parameter is kept for backward compatibility but not used
        save_path = self.ckpt_manager.save(checkpoint_number=step)

        if self.logger:
            self.logger.info(f"Saved checkpoint for step {step}: {save_path}")
        else:
            print(f"Saved checkpoint for step {step}: {save_path}")

        # Save non-TensorFlow Python state separately
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss)
        }

        with open(checkpoint_path + '.state.json', 'w') as f:
        # Associate state file with checkpoint
        state_path = os.path.join(self.args['checkpoint_dir'], f'state-{step}.json')
        with open(state_path, 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")
        # Save config file (only once)
        config_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml')
        if not os.path.exists(config_path):
            with open(config_path, 'w') as f:
                OmegaConf.save(config=self.args, f=f)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint"""
        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')
        """Load a specific checkpoint and its associated training state"""
        if self.logger:
            self.logger.info(f"Loading checkpoint from: {checkpoint_path}")
        else:
            print(f"Loading checkpoint from: {checkpoint_path}")

        # Load optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')
        # Restore TensorFlow objects (model and optimizer)
        self.ckpt.restore(checkpoint_path).expect_partial()

        # Load training state
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)
        # Restore non-TensorFlow training state
        try:
            # Extract step number from checkpoint path (e.g., ckpt-123 -> 123)
            step = int(checkpoint_path.split('-')[-1])
            state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json')

            self.best_val_per = state['best_val_per']
            self.best_val_loss = state['best_val_loss']
            with open(state_path, 'r') as f:
                state = json.load(f)

            self.logger.info(f"Loaded checkpoint: {checkpoint_path}")
            self.best_val_per = state['best_val_per']
            self.best_val_loss = state['best_val_loss']

            if self.logger:
                self.logger.info(f"Restored training state from: {state_path}")
            else:
                print(f"Restored training state from: {state_path}")

        except (IOError, ValueError, KeyError) as e:
            warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. "
                           f"Starting with fresh state. Error: {e}")
            if self.logger:
                self.logger.warning(warning_msg)
            else:
                print(f"⚠️ {warning_msg}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
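# Illustrative sketch, not from this commit: locating the newest CheckpointManager file in
# a fresh process before calling load_checkpoint; the directory name is an assumption.
import tensorflow as tf

latest = tf.train.latest_checkpoint("./checkpoints")   # e.g. "./checkpoints/ckpt-1200"
if latest:
    print("would call trainer.load_checkpoint on", latest)
else:
    print("no checkpoint found")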