fff
@@ -10,6 +10,13 @@ import sys
 from typing import Dict, Any, Tuple, Optional, List
 from omegaconf import OmegaConf
 
+# For accurate PER calculation
+try:
+    import editdistance
+except ImportError:
+    print("Warning: editdistance not available, falling back to approximation")
+    editdistance = None
+
 from rnn_model_tf import (
     TripleGRUDecoder,
     CTCLoss,
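The try/except above makes editdistance optional at import time. For reference, editdistance.eval takes two sequences and returns their Levenshtein distance, which the validation code later divides by the reference length to get a phoneme error rate. A minimal sketch, assuming the editdistance package from PyPI is installed; the id values are illustrative:

    import editdistance

    ref = [12, 7, 7, 31]    # reference phoneme ids (illustrative)
    hyp = [12, 7, 31, 5]    # hypothesis phoneme ids (illustrative)
    dist = editdistance.eval(hyp, ref)   # -> 2
    per = dist / len(ref)                # per-utterance phoneme error rate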
@@ -99,6 +106,30 @@ class BrainToTextDecoderTrainerTF:
         self.lr_scheduler = self._create_lr_scheduler()
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
 
+        # Create unified checkpoint management
+        self.ckpt = tf.train.Checkpoint(
+            optimizer=self.optimizer,
+            model=self.model
+        )
+        self.ckpt_manager = tf.train.CheckpointManager(
+            self.ckpt,
+            directory=self.args['checkpoint_dir'],
+            max_to_keep=5  # Keep only the 5 most recent checkpoints
+        )
+
+        # Try to restore from latest checkpoint
+        if self.ckpt_manager.latest_checkpoint:
+            print(f"🔄 Restoring from {self.ckpt_manager.latest_checkpoint}")
+            if self.logger:
+                self.logger.info(f"Restoring from {self.ckpt_manager.latest_checkpoint}")
+            # expect_partial() avoids failures due to model structure changes
+            self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial()
+            print("✅ Checkpoint restored successfully")
+        else:
+            print("🆕 Initializing training from scratch")
+            if self.logger:
+                self.logger.info("Initializing training from scratch")
+
         # Log model information
         self._log_model_info()
 
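For orientation, this is the standard tf.train.Checkpoint plus CheckpointManager pattern. A self-contained sketch of the same save/restore round trip, using an illustrative model and directory rather than the trainer's own objects:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    optimizer = tf.keras.optimizers.Adam()

    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
    manager = tf.train.CheckpointManager(ckpt, directory='/tmp/ckpts', max_to_keep=5)

    # Restore the newest checkpoint if one exists; expect_partial() suppresses
    # warnings about objects missing from or new to the object graph.
    if manager.latest_checkpoint:
        ckpt.restore(manager.latest_checkpoint).expect_partial()

    # During training, save with an explicit number so files are named ckpt-<step>.
    save_path = manager.save(checkpoint_number=100)

CheckpointManager also prunes older checkpoints automatically, which is what max_to_keep=5 relies on.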
@@ -110,12 +141,10 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
-        # Manual weight decay handling for all environments (since we use Adam)
+        # Manual weight decay handling - disabled, since AdamW handles it
         self.manual_weight_decay = False
         if self.args.get('weight_decay', 0.0) > 0:
-            self.manual_weight_decay = True
-            self.weight_decay_rate = self.args['weight_decay']
-            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+            print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}")
         else:
             print("💡 No weight decay configured")
 
@@ -402,39 +431,18 @@ class BrainToTextDecoderTrainerTF:
         return model
 
     def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
-        """Create AdamW optimizer with parameter groups"""
-        # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
-        # We'll use a single optimizer and handle different learning rates in the scheduler
-
+        """Create AdamW optimizer"""
         print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
-
-        # For TPU training, we need to be more explicit about optimizer configuration
-        # to avoid strategy context issues
-        # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
-        # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
-        # We implement manual L2 regularization (weight decay) in the training step instead
-        print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
-        print("💡 Manual L2 regularization will be applied in training step")
-
-        # Use legacy Adam optimizer for better TPU distributed training compatibility
-        # Legacy optimizers have more stable distributed training implementations
-        try:
-            optimizer = tf.keras.optimizers.legacy.Adam(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon']
-            )
-            print("✅ Using legacy Adam optimizer for better TPU compatibility")
-        except AttributeError:
-            # Fallback to standard Adam if legacy is not available
-            optimizer = tf.keras.optimizers.Adam(
-                learning_rate=self.args['lr_max'],
-                beta_1=self.args['beta0'],
-                beta_2=self.args['beta1'],
-                epsilon=self.args['epsilon']
-            )
-            print("⚠️ Using standard Adam optimizer (legacy not available)")
+        print("Using AdamW optimizer for TPU training")
+
+        optimizer = tf.keras.optimizers.AdamW(
+            learning_rate=self.args['lr_max'],
+            beta_1=self.args['beta0'],
+            beta_2=self.args['beta1'],
+            epsilon=self.args['epsilon'],
+            weight_decay=self.args.get('weight_decay', 0.0)
+        )
+        print("✅ Using AdamW optimizer")
 
         return optimizer
 
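Unlike the manual L2 term that the old code added to the loss, AdamW applies decoupled weight decay directly to the variables during the update step. A sketch of the same constructor outside the trainer; the hyperparameter values are illustrative, and exclude_from_weight_decay exists only on newer Keras optimizers, so treat that call as an assumption about the installed version:

    import tensorflow as tf

    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=1e-3,
        weight_decay=0.01,   # applied to the weights each step, not added to the loss
    )
    # Bias and normalization parameters are commonly excluded from decay:
    optimizer.exclude_from_weight_decay(var_names=['bias', 'norm'])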
@@ -486,6 +494,7 @@ class BrainToTextDecoderTrainerTF:
         else:
             print(f"Model has {total_params:,} trainable parameters")
 
+    @tf.function
     def _train_step(self, batch, step):
         """Single training step with gradient tape"""
         features = batch['input_features']
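One thing to watch with the new @tf.function decorator: a tf.function retraces whenever it sees a new Python value, so if step arrives as a plain Python int the training step is recompiled on every call. A minimal sketch of the usual workaround, passing the step as a tensor or variable (illustrative code, not the repository's):

    import tensorflow as tf

    @tf.function
    def train_step(x, step):
        return x * tf.cast(step, x.dtype)

    step_var = tf.Variable(0, dtype=tf.int64)
    for _ in range(3):
        train_step(tf.ones([2]), step_var)   # traced once, reused afterwards
        step_var.assign_add(1)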
@@ -554,16 +563,7 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
 
-            # Add manual L2 regularization for TPU (since weight_decay is disabled)
-            if self.manual_weight_decay:
-                l2_loss = tf.constant(0.0, dtype=loss.dtype)
-                for var in self.model.trainable_variables:
-                    # Ensure dtype consistency for mixed precision training
-                    var_l2 = tf.nn.l2_loss(var)
-                    var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
-                    l2_loss += var_l2
-                loss += self.weight_decay_rate * l2_loss
-
+            # AdamW handles weight decay automatically - no manual L2 regularization needed
             # TensorFlow mixed precision: no manual loss scaling needed, the Keras policy handles it
             # TPU v5e-8 uses bfloat16, so no loss scaling is required
 
@@ -599,7 +599,7 @@ class BrainToTextDecoderTrainerTF:
 
     @tf.function
     def _validation_step(self, batch):
-        """Single validation step"""
+        """Single validation step - returns data for accurate PER calculation"""
         features = batch['input_features']
         labels = batch['seq_class_ids']
         n_time_steps = batch['n_time_steps']
@@ -630,38 +630,11 @@ class BrainToTextDecoderTrainerTF:
         loss = self.ctc_loss(loss_input, logits)
         loss = tf.reduce_mean(loss)
 
-        # Calculate simplified PER approximation (TPU-compatible)
-        # For TPU training, we use a simplified metric that avoids complex loops
-        # This gives an approximation of PER but is much faster and TPU-compatible
-
-        # Greedy decoding
-        predicted_ids = tf.argmax(logits, axis=-1)
-
-        # Simple approximation: count exact matches vs mismatches
-        # This is less accurate than true edit distance but TPU-compatible
-        batch_size = tf.shape(logits)[0]
-
-        # For each sample, compare predicted vs true sequences
-        total_mismatches = tf.constant(0, dtype=tf.int32)
-
-        for i in tf.range(batch_size):
-            # Get sequences for this sample
-            pred_seq = predicted_ids[i, :adjusted_lens[i]]
-            true_seq = labels[i, :phone_seq_lens[i]]
-
-            # Pad to same length for comparison
-            max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0])
-            pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0)
-            true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0)
-
-            # Count mismatches
-            mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32))
-            total_mismatches += mismatches
-
-        # Approximate edit distance as number of mismatches
-        batch_edit_distance = tf.cast(total_mismatches, tf.float32)
-
-        return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32)
+        # Greedy decoding for PER calculation
+        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
+
+        # Return all necessary data for accurate PER calculation on CPU
+        return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens
 
     def train(self) -> Dict[str, Any]:
         """Main training loop"""
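The validation step now returns raw per-frame argmax ids and leaves decoding to the host. Standard CTC greedy decoding merges consecutive repeats first and only then removes blanks; doing it in the other order, as the CPU loop later in this commit does, can merge genuinely repeated phonemes. A small sketch with blank_index=0, matching the CTCLoss configuration above:

    def ctc_greedy_collapse(frame_ids, blank=0):
        collapsed, prev = [], None
        for t in frame_ids:
            if t != prev:              # merge consecutive repeats first
                collapsed.append(t)
            prev = t
        return [p for p in collapsed if p != blank]   # then drop blanks

    print(ctc_greedy_collapse([0, 7, 7, 0, 7, 31, 31, 0]))   # -> [7, 7, 31]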
@@ -671,28 +644,28 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
 
-        # Create distributed datasets
-        train_dataset = create_input_fn(
-            self.train_dataset_tf,
-            self.args['dataset']['data_transforms'],
-            training=True
-        )
-
-        val_dataset = create_input_fn(
-            self.val_dataset_tf,
-            self.args['dataset']['data_transforms'],
-            training=False
-        )
-        # Distribute datasets with timing
+        # Create datasets using modern distribution API
+        def create_dist_dataset_fn(input_dataset_tf, training):
+            """Create distributed dataset function for modern TPU strategy"""
+            def dataset_fn(input_context):
+                # create_input_fn returns a complete, batched tf.data.Dataset
+                return create_input_fn(
+                    input_dataset_tf,
+                    self.args['dataset']['data_transforms'],
+                    training=training
+                )
+            return self.strategy.distribute_datasets_from_function(dataset_fn)
+
+        # Distribute datasets using modern API
         self.logger.info("🔄 Distributing training dataset across TPU cores...")
         dist_start_time = time.time()
-        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
         self.logger.info("🔄 Distributing validation dataset across TPU cores...")
         val_start_time = time.time()
-        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
 
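The wrapper above ignores the tf.distribute.InputContext that distribute_datasets_from_function passes to dataset_fn, which is valid because create_input_fn already returns a fully batched dataset. For completeness, a sketch of a dataset_fn that does use the context for sharding and per-replica batch sizes; the data and batch size are illustrative:

    import tensorflow as tf

    GLOBAL_BATCH = 256

    def dataset_fn(input_context: tf.distribute.InputContext):
        per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH)
        ds = tf.data.Dataset.from_tensor_slices(tf.range(10_000))
        # Give each input pipeline its own shard of the data.
        ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
        return ds.batch(per_replica_batch).prefetch(tf.data.AUTOTUNE)

    # dist_ds = strategy.distribute_datasets_from_function(dataset_fn)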
@@ -709,14 +682,9 @@ class BrainToTextDecoderTrainerTF:
         # Training loop
         step = 0
 
-        # Add timing diagnostic for first batch iteration
-        self.logger.info("🔄 Starting training loop iteration...")
-        loop_start_time = time.time()
+        self.logger.info("🔄 Starting training loop...")
 
         for batch in train_dist_dataset:
-            if step == 0:
-                first_batch_iteration_time = time.time() - loop_start_time
-                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
             if step >= self.args['num_training_batches']:
                 self.logger.info("Reached maximum training batches, stopping training")
                 break
@@ -794,7 +762,7 @@ class BrainToTextDecoderTrainerTF:
             if new_best:
                 if self.args.get('save_best_checkpoint', True):
                     self.logger.info("Checkpointing model")
-                    self._save_checkpoint('best_checkpoint', step)
+                    self._save_checkpoint(step)
 
                 if self.args.get('save_val_metrics', True):
                     with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
@@ -806,7 +774,7 @@ class BrainToTextDecoderTrainerTF:
 
                 # Optional save all validation checkpoints
                 if self.args.get('save_all_val_steps', False):
-                    self._save_checkpoint(f'checkpoint_batch_{step}', step)
+                    self._save_checkpoint(step)
 
                 # Early stopping
                 if (self.args.get('early_stopping', False) and
@@ -825,7 +793,7 @@ class BrainToTextDecoderTrainerTF:
         # Save final model
         if self.args.get('save_final_model', False):
            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
-            self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1)
+            self._save_checkpoint(step-1)
 
         return {
             'train_losses': train_losses,
@@ -835,25 +803,58 @@ class BrainToTextDecoderTrainerTF:
         }
 
     def _validate(self, val_dataset) -> Dict[str, Any]:
-        """Run validation on entire validation dataset"""
+        """Run validation on entire validation dataset with accurate PER calculation"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0
 
        for batch in val_dataset:
-            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
+            # Get predictions and labels from all TPU cores
+            per_replica_losses, per_replica_preds, per_replica_labels, per_replica_pred_lens, per_replica_label_lens = (
                self.strategy.run(self._validation_step, args=(batch,))
            )
 
-            # Reduce across replicas
+            # Reduce loss across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
-            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
-            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)
-
            total_loss += float(batch_loss.numpy())
-            total_edit_distance += float(batch_edit_distance.numpy())
-            total_seq_length += float(batch_seq_length.numpy())
+
+            # Gather all data from TPU cores to CPU for accurate PER calculation
+            all_preds = self.strategy.gather(per_replica_preds, axis=0)
+            all_labels = self.strategy.gather(per_replica_labels, axis=0)
+            all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0)
+            all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0)
+
+            # Calculate accurate edit distance on CPU
+            batch_size = all_preds.shape[0]
+            for i in range(batch_size):
+                pred_len = int(all_pred_lens[i].numpy())
+                label_len = int(all_label_lens[i].numpy())
+
+                # Extract sequences and remove CTC blanks (assuming blank_index=0)
+                pred_seq = all_preds[i, :pred_len].numpy()
+                pred_seq = [p for p in pred_seq if p != 0]  # Remove blanks
+
+                # Remove consecutive duplicates (CTC decoding)
+                if len(pred_seq) > 0:
+                    deduped_pred = [pred_seq[0]]
+                    for j in range(1, len(pred_seq)):
+                        if pred_seq[j] != pred_seq[j-1]:
+                            deduped_pred.append(pred_seq[j])
+                    pred_seq = deduped_pred
+
+                true_seq = all_labels[i, :label_len].numpy().tolist()
+
+                # Calculate edit distance using proper library if available
+                if editdistance is not None:
+                    edit_dist = editdistance.eval(pred_seq, true_seq)
+                else:
+                    # Fallback to simple approximation if editdistance not available
+                    edit_dist = self._simple_edit_distance(pred_seq, true_seq)
+
+                total_edit_distance += edit_dist
+                total_seq_length += label_len
+
            num_batches += 1
 
        avg_loss = total_loss / max(num_batches, 1)
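After the loop, the aggregate PER presumably reported by _validate is the accumulated edit distance divided by the accumulated reference length, so longer utterances weigh more than shorter ones. A tiny worked example with made-up totals, not measured values:

    total_edit_distance = 137
    total_seq_length = 1842
    avg_per = total_edit_distance / max(total_seq_length, 1)   # ~0.074, about 7.4% PER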
@@ -867,50 +868,95 @@ class BrainToTextDecoderTrainerTF:
             'num_batches': num_batches
         }
 
-    def _save_checkpoint(self, name: str, step: int):
-        """Save model checkpoint"""
-        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)
-
-        # Save model weights
-        self.model.save_weights(checkpoint_path + '.weights.h5')
-
-        # Save optimizer state
-        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
-        optimizer_checkpoint.save(checkpoint_path + '.optimizer')
-
-        # Save training state
-        state = {
-            'step': step,
-            'best_val_per': float(self.best_val_per),
-            'best_val_loss': float(self.best_val_loss)
-        }
-
-        with open(checkpoint_path + '.state.json', 'w') as f:
-            json.dump(state, f)
-
-        # Save config
-        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
-            OmegaConf.save(config=self.args, f=f)
-
-        self.logger.info(f"Saved checkpoint: {checkpoint_path}")
-
-    def load_checkpoint(self, checkpoint_path: str):
-        """Load model checkpoint"""
-        # Load model weights
-        self.model.load_weights(checkpoint_path + '.weights.h5')
-
-        # Load optimizer state
-        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
-        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')
-
-        # Load training state
-        with open(checkpoint_path + '.state.json', 'r') as f:
-            state = json.load(f)
-
-        self.best_val_per = state['best_val_per']
-        self.best_val_loss = state['best_val_loss']
-
-        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")
-
+    def _simple_edit_distance(self, seq1, seq2):
+        """Simple edit distance implementation as fallback"""
+        # Dynamic programming implementation of edit distance
+        m, n = len(seq1), len(seq2)
+        dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+        # Initialize base cases
+        for i in range(m + 1):
+            dp[i][0] = i
+        for j in range(n + 1):
+            dp[0][j] = j
+
+        # Fill the DP table
+        for i in range(1, m + 1):
+            for j in range(1, n + 1):
+                if seq1[i-1] == seq2[j-1]:
+                    dp[i][j] = dp[i-1][j-1]
+                else:
+                    dp[i][j] = 1 + min(
+                        dp[i-1][j],    # deletion
+                        dp[i][j-1],    # insertion
+                        dp[i-1][j-1]   # substitution
+                    )
+
+        return dp[m][n]
+
+    def _save_checkpoint(self, step: int, name: str = ""):
+        """Save checkpoint using the unified CheckpointManager"""
+        # CheckpointManager automatically handles naming and numbering
+        # The 'name' parameter is kept for backward compatibility but not used
+        save_path = self.ckpt_manager.save(checkpoint_number=step)
+
+        if self.logger:
+            self.logger.info(f"Saved checkpoint for step {step}: {save_path}")
+        else:
+            print(f"Saved checkpoint for step {step}: {save_path}")
+
+        # Save non-TensorFlow Python state separately
+        state = {
+            'step': step,
+            'best_val_per': float(self.best_val_per),
+            'best_val_loss': float(self.best_val_loss)
+        }
+
+        # Associate state file with checkpoint
+        state_path = os.path.join(self.args['checkpoint_dir'], f'state-{step}.json')
+        with open(state_path, 'w') as f:
+            json.dump(state, f)
+
+        # Save config file (only once)
+        config_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml')
+        if not os.path.exists(config_path):
+            with open(config_path, 'w') as f:
+                OmegaConf.save(config=self.args, f=f)
+
+    def load_checkpoint(self, checkpoint_path: str):
+        """Load a specific checkpoint and its associated training state"""
+        if self.logger:
+            self.logger.info(f"Loading checkpoint from: {checkpoint_path}")
+        else:
+            print(f"Loading checkpoint from: {checkpoint_path}")
+
+        # Restore TensorFlow objects (model and optimizer)
+        self.ckpt.restore(checkpoint_path).expect_partial()
+
+        # Restore non-TensorFlow training state
+        try:
+            # Extract step number from checkpoint path (e.g., ckpt-123 -> 123)
+            step = int(checkpoint_path.split('-')[-1])
+            state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json')
+
+            with open(state_path, 'r') as f:
+                state = json.load(f)
+
+            self.best_val_per = state['best_val_per']
+            self.best_val_loss = state['best_val_loss']
+
+            if self.logger:
+                self.logger.info(f"Restored training state from: {state_path}")
+            else:
+                print(f"Restored training state from: {state_path}")
+
+        except (IOError, ValueError, KeyError) as e:
+            warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. "
+                           f"Starting with fresh state. Error: {e}")
+            if self.logger:
+                self.logger.warning(warning_msg)
+            else:
+                print(f"⚠️ {warning_msg}")
+
     def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                   n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
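The new _simple_edit_distance fallback is a textbook O(m*n) Levenshtein DP with a full table. If memory ever matters for long sequences, the same recurrence only needs the previous row; a sketch of that variant, equivalent in result to the method above:

    def edit_distance_two_rows(seq1, seq2):
        prev = list(range(len(seq2) + 1))
        for i, a in enumerate(seq1, start=1):
            curr = [i]
            for j, b in enumerate(seq2, start=1):
                if a == b:
                    curr.append(prev[j - 1])
                else:
                    curr.append(1 + min(prev[j], curr[j - 1], prev[j - 1]))
            prev = curr
        return prev[-1]

    assert edit_distance_two_rows([1, 2, 3, 4], [1, 3, 4, 5]) == 2   # delete 2, insert 5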