trainer
model_training_nnn_tpu/trainer_tf.py (new file, 651 lines)
@@ -0,0 +1,651 @@
import os
import tensorflow as tf
import numpy as np
import time
import json
import pickle
import logging
import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf

from rnn_model_tf import (
    TripleGRUDecoder,
    CTCLoss,
    create_tpu_strategy,
    build_model_for_tpu,
    configure_mixed_precision
)
from dataset_tf import (
    BrainToTextDatasetTF,
    DataAugmentationTF,
    train_test_split_indices,
    create_input_fn
)


class BrainToTextDecoderTrainerTF:
    """
    TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8

    This trainer implements the same training logic as the PyTorch version but uses
    TensorFlow operations optimized for TPU hardware.
    """

    def __init__(self, args: Dict[str, Any]):
        """
        Initialize the TensorFlow trainer

        Args:
            args: Configuration dictionary containing all training parameters
        """
        self.args = args
        self.logger = None

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision for TPU v5e-8
        if args.get('use_amp', True):
            configure_mixed_precision()
            self.mixed_precision = True
        else:
            self.mixed_precision = False

        # Initialize tracking variables
        self.best_val_per = float('inf')
        self.best_val_loss = float('inf')

        # Setup directories
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        if (args.get('save_best_checkpoint', True) or
                args.get('save_all_val_steps', False) or
                args.get('save_final_model', False)):
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Setup logging
        self._setup_logging()

        # Set random seeds
        if self.args['seed'] != -1:
            tf.random.set_seed(self.args['seed'])
            np.random.seed(self.args['seed'])

        # Initialize datasets
        self._initialize_datasets()

        # Build model within strategy scope
        with self.strategy.scope():
            self.model = self._build_model()
            # Create the LR schedule first so it can drive the optimizer's learning rate
            self.lr_scheduler = self._create_lr_scheduler()
            self.optimizer = self._create_optimizer()
            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Log model information
        self._log_model_info()

        # Adversarial training configuration
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))

        if self.adv_enabled:
            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
                             f"warmup_steps={self.adv_warmup_steps}")

    def _setup_logging(self):
        """Setup logging configuration"""
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if self.args['mode'] == 'train':
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
        if self.mixed_precision:
            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')

    def _initialize_datasets(self):
        """Initialize training and validation datasets"""
        # Create file paths
        train_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
            for s in self.args['dataset']['sessions']
        ]
        val_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
            for s in self.args['dataset']['sessions']
        ]

        # Validate no duplicates
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("Duplicate sessions in train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("Duplicate sessions in val dataset")

        # Split trials
        train_trials, _ = train_test_split_indices(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )
        _, val_trials = train_test_split_indices(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )

        # Save trial splits
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Create TensorFlow datasets
        self.train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=self.args['num_training_batches'],
            split='train',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=self.args['dataset']['days_per_batch'],
            random_seed=self.args['dataset']['seed'],
            must_include_days=self.args['dataset'].get('must_include_days'),
            feature_subset=self.args['dataset'].get('feature_subset')
        )

        self.val_dataset_tf = BrainToTextDatasetTF(
            trial_indices=val_trials,
            n_batches=None,  # Use all validation data
            split='test',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=1,  # One day per validation batch
            random_seed=self.args['dataset']['seed'],
            feature_subset=self.args['dataset'].get('feature_subset')
        )

        self.logger.info("Successfully initialized TensorFlow datasets")

    def _build_model(self) -> TripleGRUDecoder:
        """Build the TripleGRUDecoder model"""
        model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride']
        )
        return model

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create AdamW optimizer"""
        # Note: TensorFlow doesn't offer PyTorch-style per-parameter-group learning
        # rates, so a single AdamW optimizer is used and its learning rate is
        # driven by the schedule built in _create_lr_scheduler.
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=self.lr_scheduler,
            beta_1=self.args['beta0'],
            beta_2=self.args['beta1'],
            epsilon=self.args['epsilon'],
            weight_decay=self.args['weight_decay']
        )
        return optimizer

    def _create_lr_scheduler(self):
        """Create learning rate scheduler"""
        if self.args['lr_scheduler_type'] == 'cosine':
            return self._create_cosine_scheduler()
        elif self.args['lr_scheduler_type'] == 'linear':
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args['lr_max'],
                decay_steps=self.args['lr_decay_steps'],
                end_learning_rate=self.args['lr_min'],
                power=1.0  # Linear decay
            )
        else:
            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")

    def _create_cosine_scheduler(self):
        """Create cosine learning rate scheduler"""
        return tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=self.args['lr_max'],
            first_decay_steps=self.args['lr_decay_steps'],
            t_mul=1.0,
            m_mul=1.0,
            alpha=self.args['lr_min'] / self.args['lr_max']
        )
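
    # Illustrative example (values are placeholders, not taken from any config):
    # with lr_max=5e-3, lr_min=1e-4 and lr_decay_steps=120000, the cosine schedule
    # above starts each cycle at 5e-3 and decays toward
    # alpha * lr_max = (lr_min / lr_max) * lr_max = lr_min = 1e-4
    # over 120k steps before restarting.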

    def _log_model_info(self):
        """Log model architecture and parameter information"""
        self.logger.info("Initialized TripleGRUDecoder model")

        # Build the model by calling it once with dummy data
        dummy_batch_size = 2
        dummy_time_steps = 100
        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)

        # Call the model to build it
        _ = self.model(dummy_features, dummy_day_idx, training=False)

        # Count parameters
        total_params = sum(tf.size(w).numpy() for w in self.model.trainable_weights)
        self.logger.info(f"Model has {total_params:,} trainable parameters")

    @tf.function
    def _train_step(self, batch, use_full):
        """Single training step with gradient tape.

        `use_full` is a Python bool (adversarial path on/off), computed by the
        caller, so that tf.function only retraces when the flag flips rather
        than on every training step.
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        with tf.GradientTape() as tape:
            # Apply data transformations
            features, n_time_steps = DataAugmentationTF.transform_data(
                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
            )

            # Calculate adjusted lengths for CTC (number of patches after patching)
            adjusted_lens = tf.cast(
                (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
                self.args['model']['patch_stride'] + 1,
                tf.int32
            )

            # Forward pass
            if use_full:
                clean_logits, noisy_logits, noise_output = self.model(
                    features, day_indices, None, False, 'full',
                    grl_lambda=self.adv_grl_lambda, training=True
                )
            else:
                clean_logits = self.model(
                    features, day_indices, None, False, 'inference', training=True
                )

            # Calculate losses
            if use_full:
                # Clean CTC loss
                clean_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                clean_loss = tf.reduce_mean(self.ctc_loss(clean_loss_input, clean_logits))

                # Noisy CTC loss
                noisy_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                noisy_loss = tf.reduce_mean(self.ctc_loss(noisy_loss_input, noisy_logits))

                # Optional noise L2 regularization
                noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = tf.reduce_mean(tf.square(noise_output))

                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
            else:
                loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                loss = tf.reduce_mean(self.ctc_loss(loss_input, clean_logits))

            # bfloat16 mixed precision on TPU needs no loss scaling; scaling only
            # applies when the optimizer is a LossScaleOptimizer (e.g. float16 on GPU)
            use_loss_scaling = isinstance(
                self.optimizer, tf.keras.mixed_precision.LossScaleOptimizer)
            scaled_loss = self.optimizer.get_scaled_loss(loss) if use_loss_scaling else loss

        # Calculate gradients
        if use_loss_scaling:
            scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)

        # Clip gradients
        if self.args['grad_norm_clip_value'] > 0:
            gradients, grad_norm = tf.clip_by_global_norm(
                gradients, self.args['grad_norm_clip_value']
            )
        else:
            grad_norm = tf.linalg.global_norm(gradients)

        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return loss, grad_norm
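
    # Worked example for the CTC input-length formula above (illustrative values,
    # not taken from the config): with patch_size=14, patch_stride=4 and a trial
    # of n_time_steps=512, the model emits floor((512 - 14) / 4) + 1 = 125 patches,
    # so 125 is the input length reported to the CTC loss.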

    @tf.function
    def _validation_step(self, batch):
        """Single validation step"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation for validation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Calculate adjusted lengths
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
            self.args['model']['patch_stride'] + 1,
            tf.int32
        )

        # Forward pass (inference mode only)
        logits = self.model(features, day_indices, None, False, 'inference', training=False)

        # Calculate loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = tf.reduce_mean(self.ctc_loss(loss_input, logits))

        # Calculate PER (Phoneme Error Rate) via greedy decoding
        predicted_ids = tf.argmax(logits, axis=-1)

        def _ctc_collapse(seq):
            """Remove consecutive duplicates, then blanks (blank_index=0)."""
            prev = tf.concat([tf.constant([-1], dtype=seq.dtype), seq[:-1]], axis=0)
            seq = tf.boolean_mask(seq, tf.not_equal(seq, prev))
            return tf.boolean_mask(seq, tf.not_equal(seq, 0))

        def _to_sparse(seq):
            """Wrap a 1-D int64 sequence as a single-row rank-2 SparseTensor."""
            n = tf.size(seq, out_type=tf.int64)
            positions = tf.range(n, dtype=tf.int64)
            indices = tf.stack([tf.zeros_like(positions), positions], axis=1)
            dense_shape = tf.stack([tf.constant(1, tf.int64), tf.maximum(n, 1)])
            return tf.SparseTensor(indices=indices, values=seq, dense_shape=dense_shape)

        batch_edit_distance = tf.constant(0.0, dtype=tf.float32)
        for i in tf.range(tf.shape(logits)[0]):
            pred_seq = _ctc_collapse(predicted_ids[i, :adjusted_lens[i]])
            true_seq = tf.cast(labels[i, :phone_seq_lens[i]], tf.int64)

            # Unnormalized edit distance between prediction and reference
            edit_dist = tf.edit_distance(
                _to_sparse(pred_seq), _to_sparse(true_seq), normalize=False
            )
            batch_edit_distance += tf.reduce_sum(edit_dist)

        return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)

    def train(self) -> Dict[str, Any]:
        """Main training loop"""
        self.logger.info("Starting training loop...")

        # Create distributed datasets
        train_dataset = create_input_fn(
            self.train_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=True
        )
        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )

        # Distribute datasets
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)

        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        train_start_time = time.time()

        # Training loop
        step = 0
        for batch in train_dist_dataset:
            if step >= self.args['num_training_batches']:
                break

            start_time = time.time()

            # Decide the adversarial path outside the tf.function so the graph
            # only retraces when the flag flips (see _train_step)
            use_full = self.adv_enabled and (step >= self.adv_warmup_steps)

            # Distributed training step
            per_replica_losses, per_replica_grad_norms = self.strategy.run(
                self._train_step, args=(batch, use_full)
            )

            # Reduce across replicas
            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)

            train_step_duration = time.time() - start_time
            train_losses.append(float(loss.numpy()))

            # Log training progress
            if step % self.args['batches_per_train_log'] == 0:
                self.logger.info(f'Train batch {step}: '
                                 f'loss: {float(loss.numpy()):.2f} '
                                 f'grad norm: {float(grad_norm.numpy()):.2f} '
                                 f'time: {train_step_duration:.3f}')

            # Validation step
            if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
                self.logger.info(f"Running validation after training batch: {step}")

                val_start_time = time.time()
                val_metrics = self._validate(val_dist_dataset)
                val_step_duration = time.time() - val_start_time

                self.logger.info(f'Val batch {step}: '
                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                                 f'time: {val_step_duration:.3f}')

                val_pers.append(val_metrics['avg_per'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Check for improvement
                new_best = False
                if val_metrics['avg_per'] < self.best_val_per:
                    self.logger.info(f"New best val PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                    self.best_val_per = val_metrics['avg_per']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif (val_metrics['avg_per'] == self.best_val_per and
                      val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best val loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)

                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optionally save a checkpoint at every validation step
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)

                # Early stopping
                early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)
                if (self.args.get('early_stopping', False) and
                        val_steps_since_improvement >= early_stopping_val_steps):
                    self.logger.info(f'Validation PER has not improved in {early_stopping_val_steps} '
                                     f'validation steps. Stopping training early at batch: {step}')
                    break

            step += 1

        # Training completed
        training_duration = time.time() - train_start_time
        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model
        if self.args.get('save_final_model', False):
            self._save_checkpoint(f'final_checkpoint_batch_{step - 1}', step - 1)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_pers': val_pers,
            'val_metrics': val_results
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation on entire validation dataset"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0

        for batch in val_dataset:
            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
                self.strategy.run(self._validation_step, args=(batch,))
            )

            # Reduce across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)

            total_loss += float(batch_loss.numpy())
            total_edit_distance += float(batch_edit_distance.numpy())
            total_seq_length += float(batch_seq_length.numpy())
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        avg_per = total_edit_distance / max(total_seq_length, 1e-6)

        return {
            'avg_loss': avg_loss,
            'avg_per': avg_per,
            'total_edit_distance': total_edit_distance,
            'total_seq_length': total_seq_length,
            'num_batches': num_batches
        }

    def _save_checkpoint(self, name: str, step: int):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')

        # Save optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.save(checkpoint_path + '.optimizer')

        # Save training state
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss)
        }
        with open(checkpoint_path + '.state.json', 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")
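
    # Files written per checkpoint by _save_checkpoint:
    #   <checkpoint_dir>/<name>.weights.h5   - model weights
    #   <checkpoint_dir>/<name>.optimizer-1  - optimizer state via tf.train.Checkpoint
    #                                          (the '-1' counter is appended by Checkpoint.save)
    #   <checkpoint_dir>/<name>.state.json   - step and best-metric bookkeeping
    #   <checkpoint_dir>/args.yaml           - full training config (OmegaConf)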

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint"""
        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')

        # Load optimizer state (tf.train.Checkpoint.save appends a save counter,
        # so the first save is written as '<path>.optimizer-1')
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')

        # Load training state
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)

        self.best_val_per = state['best_val_per']
        self.best_val_loss = state['best_val_loss']

        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
        """
        Run inference on input features

        Args:
            features: Input neural features [batch_size, time_steps, features]
            day_indices: Day indices [batch_size]
            n_time_steps: Number of valid time steps [batch_size]
            mode: 'inference' or 'full'

        Returns:
            Phoneme logits [batch_size, time_steps, n_classes]
        """
        # Apply data transformations (no augmentation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Run model inference
        logits = self.model(features, day_indices, None, False, mode, training=False)

        return logits
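

# Usage sketch (illustrative, not part of the trainer class): drive the trainer
# from a YAML config loaded with OmegaConf. The default path 'rnn_args.yaml' is
# a placeholder; any config providing the keys referenced above ('mode',
# 'output_dir', 'checkpoint_dir', 'dataset', 'model', ...) would work.
if __name__ == '__main__':
    config_path = sys.argv[1] if len(sys.argv) > 1 else 'rnn_args.yaml'
    config = OmegaConf.load(config_path)
    trainer = BrainToTextDecoderTrainerTF(config)
    if config.get('mode', 'train') == 'train':
        stats = trainer.train()
        print(f"Best validation PER: {trainer.best_val_per:.4f}")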