# b2txt25/model_training_nnn_tpu/trainer_tf.py
import os
import tensorflow as tf
import numpy as np
import time
import json
import pickle
import logging
import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf
from rnn_model_tf import (
TripleGRUDecoder,
CTCLoss,
create_tpu_strategy,
build_model_for_tpu,
configure_mixed_precision
)
from dataset_tf import (
BrainToTextDatasetTF,
DataAugmentationTF,
train_test_split_indices,
create_input_fn
)
class BrainToTextDecoderTrainerTF:
"""
TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8
This trainer implements the same training logic as the PyTorch version but uses
TensorFlow operations optimized for TPU hardware.
"""
def __init__(self, args: Dict[str, Any]):
"""
Initialize the TensorFlow trainer
Args:
args: Configuration dictionary containing all training parameters
"""
self.args = args
self.logger = None
# Optimize CPU utilization for the data pipeline (make use of the 224 available cores)
self._configure_cpu_optimization()
# Initialize TPU strategy
self.strategy = create_tpu_strategy()
if self.strategy is None:
raise RuntimeError("Failed to create TPU strategy - strategy is None")
print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
print(f"Strategy type: {type(self.strategy).__name__}")
# Configure mixed precision for TPU v5e-8
if args.get('use_amp', True):
configure_mixed_precision()
self.mixed_precision = True
else:
self.mixed_precision = False
# Initialize tracking variables
self.best_val_per = float('inf')
self.best_val_loss = float('inf')
# Setup directories
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
if (args.get('save_best_checkpoint', True) or
args.get('save_all_val_steps', False) or
args.get('save_final_model', False)):
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Setup logging
self._setup_logging()
# Set random seeds
if self.args['seed'] != -1:
tf.random.set_seed(self.args['seed'])
np.random.seed(self.args['seed'])
# Initialize datasets
self._initialize_datasets()
# Build model within strategy scope
with self.strategy.scope():
print("🔨 Building model within TPU strategy scope...")
self.model = self._build_model()
print("✅ Model built successfully")
print("⚙️ Creating optimizer...")
self.optimizer = self._create_optimizer()
print("✅ Optimizer created")
print("🔧 Pre-building optimizer state for TPU...")
# For TPU, we must ensure optimizer is completely ready before training
# since @tf.function doesn't allow dynamic building
try:
print("✅ Building optimizer with model variables...")
# Explicitly build the optimizer with model variables
print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
self.optimizer.build(self.model.trainable_variables)
print("✅ Optimizer built with model variables")
# Verify optimizer is properly built - just check iterations
print(f"Optimizer iterations: {self.optimizer.iterations}")
# For TPU training, we should also ensure the optimizer has all its state
# variables created. We can do this by creating dummy variables that match
# the model variables, but we don't apply them (avoid the replica context issue)
print("🔄 Ensuring optimizer state variables are created...")
# Force creation of optimizer variables by accessing them
# This is safe and doesn't require replica context
_ = self.optimizer.iterations # This ensures basic state is created
print("✅ Optimizer fully ready for TPU training")
print("📝 Note: Optimizer will work correctly in @tf.function context")
except Exception as e:
print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
print(f"Error type: {type(e).__name__}")
import traceback
print(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler()
print("✅ LR scheduler ready")
print("🎯 Initializing CTC loss...")
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
print("✅ CTC loss initialized")
# Log model information
self._log_model_info()
# Adversarial training configuration
adv_cfg = self.args.get('adversarial', {})
self.adv_enabled = adv_cfg.get('enabled', False)
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
# TPU-specific weight decay handling
self.manual_weight_decay = False
if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
self.manual_weight_decay = True
self.weight_decay_rate = self.args['weight_decay']
print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
if self.adv_enabled:
if self.logger:
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
f"noise_l2_weight={self.adv_noise_l2_weight}, "
f"warmup_steps={self.adv_warmup_steps}")
else:
print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
f"noise_l2_weight={self.adv_noise_l2_weight}, "
f"warmup_steps={self.adv_warmup_steps}")
def _setup_logging(self):
"""Setup logging configuration"""
self.logger = logging.getLogger(__name__)
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
if self.args['mode'] == 'train':
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
fh.setFormatter(formatter)
self.logger.addHandler(fh)
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
self.logger.addHandler(sh)
self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
if self.mixed_precision:
self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
def _configure_cpu_optimization(self):
"""Configure CPU utilization to make use of 224 cores for data pipeline"""
import multiprocessing
# Get available CPU cores
available_cores = multiprocessing.cpu_count()
print(f"💻 Available CPU cores: {available_cores}")
# Optimize for data pipeline parallelism
# For 224 cores, use more threads for better data loading performance
if available_cores >= 200: # High core count system
inter_op_threads = min(64, available_cores // 3) # More aggressive for 224 cores
intra_op_threads = min(32, available_cores // 6)
else:
# Use ~1/4 of cores for inter-op (between operations)
# Use ~1/8 of cores for intra-op (within operations)
inter_op_threads = min(32, available_cores // 4)
intra_op_threads = min(16, available_cores // 8)
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
print(f"🔧 CPU optimization configured:")
print(f" Inter-op parallelism: {inter_op_threads} threads")
print(f" Intra-op parallelism: {intra_op_threads} threads")
print(f" This will accelerate data loading and preprocessing while TPU handles training")
def _get_tpu_status(self) -> str:
"""Get current TPU status and HBM utilization info"""
try:
# Get TPU devices
tpu_devices = tf.config.list_logical_devices('TPU')
if not tpu_devices:
return "TPU: No devices"
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
# Get TPU memory info using the working /device:TPU:X format
try:
# Check all TPU devices for memory usage
active_cores = 0
total_current_mb = 0
max_peak_mb = 0
for device in tpu_devices:
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # >1MB considered active
active_cores += 1
total_current_mb += current_mb
max_peak_mb = max(max_peak_mb, peak_mb)
except Exception:
continue
if active_cores > 0:
if active_cores == 1:
hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
else:
hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"
else:
hbm_info = "HBM:idle"
except Exception:
# Fallback: simple TPU activity check
try:
with tf.device('/TPU:0'):
_ = tf.constant(1.0)
hbm_info = "HBM:active"
except Exception:
hbm_info = "HBM:inactive"
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
f"{hbm_info}")
except Exception as e:
return f"TPU: status_error({str(e)[:20]})"
def _get_detailed_tpu_status(self) -> str:
"""Get detailed TPU status for training start"""
try:
# Get TPU devices
tpu_devices = tf.config.list_logical_devices('TPU')
if not tpu_devices:
return "❌ No TPU devices detected"
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
strategy_type = type(self.strategy).__name__
# Get TPU HBM memory info using working device format
try:
active_cores = 0
total_current_gb = 0
max_peak_gb = 0
memory_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_gb = memory_info['current'] // (1024 * 1024 * 1024)
peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
if current_gb > 0 or memory_info['current'] > 1024*1024: # >1MB
active_cores += 1
total_current_gb += current_gb
max_peak_gb = max(max_peak_gb, peak_gb)
if current_gb > 0:
memory_details.append(f"Core{i}:{current_gb}GB")
except Exception:
continue
if active_cores > 0:
# Empirically, a single core peaked at about 14.5GB during testing, suggesting ~16GB of HBM per core
estimated_per_core = 16 # Conservative estimate
estimated_total_gb = estimated_per_core * len(tpu_devices)
hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores"
else:
hbm_usage = "HBM: 0GB/256GB (idle)"
except Exception:
hbm_usage = "HBM: unavailable"
# Simple TPU test
try:
with tf.device('/TPU:0'):
test_result = tf.constant([1.0, 2.0])
_ = tf.reduce_sum(test_result)
tpu_test = "✅ responsive"
except Exception as e:
tpu_test = f"❌ test_failed({str(e)[:15]})"
return (f"TPU Devices: {len(tpu_devices)} | "
f"Strategy: {strategy_type} | "
f"Cores: {num_replicas} | "
f"{hbm_usage} | "
f"Test: {tpu_test}")
except Exception as e:
return f"❌ TPU status check failed: {str(e)[:50]}"
def _initialize_datasets(self):
"""Initialize training and validation datasets"""
# Create file paths
train_file_paths = [
os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
for s in self.args['dataset']['sessions']
]
val_file_paths = [
os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
for s in self.args['dataset']['sessions']
]
# Validate no duplicates
if len(set(train_file_paths)) != len(train_file_paths):
raise ValueError("Duplicate sessions in train dataset")
if len(set(val_file_paths)) != len(val_file_paths):
raise ValueError("Duplicate sessions in val dataset")
# Split trials
train_trials, _ = train_test_split_indices(
file_paths=train_file_paths,
test_percentage=0,
seed=self.args['dataset']['seed'],
bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
)
_, val_trials = train_test_split_indices(
file_paths=val_file_paths,
test_percentage=1,
seed=self.args['dataset']['seed'],
bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
)
# Save trial splits
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
json.dump({'train': train_trials, 'val': val_trials}, f)
# Create TensorFlow datasets with aggressive data preloading for TPU optimization
# Monitor memory usage during data preloading
import psutil
initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
print("🔄 Initializing training dataset with GPU-style memory management...")
init_start_time = time.time()
self.train_dataset_tf = BrainToTextDatasetTF(
trial_indices=train_trials,
n_batches=self.args['num_training_batches'],
split='train',
batch_size=self.args['dataset']['batch_size'],
days_per_batch=self.args['dataset']['days_per_batch'],
random_seed=self.args['dataset']['seed'],
must_include_days=self.args['dataset'].get('must_include_days'),
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # Enable smart caching, like the GPU version
preload_all_data=False # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
)
# Log training dataset initialization performance
train_init_time = time.time() - init_start_time
train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
train_memory_used = train_memory_mb - initial_memory_mb
print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")
print("🔄 Initializing validation dataset with GPU-style memory management...")
val_init_start_time = time.time()
self.val_dataset_tf = BrainToTextDatasetTF(
trial_indices=val_trials,
n_batches=None, # Use all validation data
split='test',
batch_size=self.args['dataset']['batch_size'],
days_per_batch=1, # One day per validation batch
random_seed=self.args['dataset']['seed'],
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # Enable smart caching, like the GPU version
preload_all_data=False # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
)
# Log validation dataset initialization performance
val_init_time = time.time() - val_init_start_time
final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
total_memory_used = final_memory_mb - initial_memory_mb
val_memory_used = final_memory_mb - train_memory_mb
print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM")
print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")
self.logger.info("Successfully initialized TensorFlow datasets")
def _build_model(self) -> TripleGRUDecoder:
"""Build the TripleGRUDecoder model"""
model = TripleGRUDecoder(
neural_dim=self.args['model']['n_input_features'],
n_units=self.args['model']['n_units'],
n_days=len(self.args['dataset']['sessions']),
n_classes=self.args['dataset']['n_classes'],
rnn_dropout=self.args['model']['rnn_dropout'],
input_dropout=self.args['model']['input_network']['input_layer_dropout'],
patch_size=self.args['model']['patch_size'],
patch_stride=self.args['model']['patch_stride']
)
return model
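# Example of the 'model' and 'dataset' config keys consumed above (a sketch;
# the values below are illustrative placeholders, not the project's defaults):
#
# model:
#   n_input_features: 512
#   n_units: 768
#   rnn_dropout: 0.4
#   patch_size: 14
#   patch_stride: 4
#   input_network:
#     input_layer_dropout: 0.2
# dataset:
#   sessions: [...]      # one entry per recording session
#   n_classes: 41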
def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Create AdamW optimizer with parameter groups"""
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
# We'll use a single optimizer and handle different learning rates in the scheduler
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
# For TPU training, we need to be more explicit about optimizer configuration
# to avoid strategy context issues
if isinstance(self.strategy, tf.distribute.TPUStrategy):
print("Using TPU-optimized optimizer configuration")
# TPU-specific optimizer configuration
# IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
# We'll implement manual L2 regularization instead
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=0.0, # Disabled for TPU compatibility
# TPU-specific settings
global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
)
print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
print("💡 Consider implementing manual L2 regularization if needed")
else:
print("Using standard optimizer configuration")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=self.args['weight_decay']
)
return optimizer
def _create_lr_scheduler(self):
"""Create learning rate scheduler"""
if self.args['lr_scheduler_type'] == 'cosine':
return self._create_cosine_scheduler()
elif self.args['lr_scheduler_type'] == 'linear':
return tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=self.args['lr_max'],
decay_steps=self.args['lr_decay_steps'],
end_learning_rate=self.args['lr_min'],
power=1.0 # Linear decay
)
else:
raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")
def _create_cosine_scheduler(self):
"""Create cosine learning rate scheduler"""
return tf.keras.optimizers.schedules.CosineDecayRestarts(
initial_learning_rate=self.args['lr_max'],
first_decay_steps=self.args['lr_decay_steps'],
t_mul=1.0,
m_mul=1.0,
alpha=self.args['lr_min'] / self.args['lr_max']
)
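# With t_mul = m_mul = 1.0 and alpha = lr_min / lr_max, each cycle of
# first_decay_steps steps follows
#   lr(t) = lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * t / first_decay_steps))
# annealing from lr_max down to lr_min before restarting.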
def _log_model_info(self):
"""Log model architecture and parameter information"""
if self.logger:
self.logger.info("Initialized TripleGRUDecoder model")
else:
print("Initialized TripleGRUDecoder model")
# Build the model by calling it once with dummy data
dummy_batch_size = 2
dummy_time_steps = 100
dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
# Call the model to build it
_ = self.model(dummy_features, dummy_day_idx, training=False)
# Count parameters
total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
if self.logger:
self.logger.info(f"Model has {total_params:,} trainable parameters")
else:
print(f"Model has {total_params:,} trainable parameters")
@tf.function
def _train_step(self, batch, step):
"""Single training step with gradient tape"""
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indices = batch['day_indices']
with tf.GradientTape() as tape:
# Apply data transformations
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, self.args['dataset']['data_transforms'], training=True
)
# Calculate adjusted lengths for CTC
adjusted_lens = tf.cast(
(tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
self.args['model']['patch_stride'] + 1,
tf.int32
)
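# The patching front-end maps T valid input frames to
#   T' = floor((T - patch_size) / patch_stride) + 1
# output frames, which is what the cast above computes; e.g. with hypothetical
# values T = 500, patch_size = 14, patch_stride = 4, the CTC input length is 122.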
# Forward pass
use_full = self.adv_enabled and (step >= self.adv_warmup_steps)
if use_full:
clean_logits, noisy_logits, noise_output = self.model(
features, day_indices, None, False, 'full',
grl_lambda=self.adv_grl_lambda, training=True
)
else:
clean_logits = self.model(
features, day_indices, None, False, 'inference', training=True
)
# Calculate losses
if use_full:
# Clean CTC loss
clean_loss_input = {
'labels': labels,
'input_lengths': adjusted_lens,
'label_lengths': phone_seq_lens
}
clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss
noisy_loss_input = {
'labels': labels,
'input_lengths': adjusted_lens,
'label_lengths': phone_seq_lens
}
noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
noisy_loss = tf.reduce_mean(noisy_loss)
# Optional noise L2 regularization
noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = tf.reduce_mean(tf.square(noise_output))
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
loss_input = {
'labels': labels,
'input_lengths': adjusted_lens,
'label_lengths': phone_seq_lens
}
loss = self.ctc_loss(loss_input, clean_logits)
loss = tf.reduce_mean(loss)
# Add manual L2 regularization for TPU (since weight_decay is disabled)
if self.manual_weight_decay:
l2_loss = tf.constant(0.0, dtype=loss.dtype)
for var in self.model.trainable_variables:
l2_loss += tf.nn.l2_loss(var)
loss += self.weight_decay_rate * l2_loss
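# Note: tf.nn.l2_loss(v) is sum(v ** 2) / 2, so the penalty added here is
# weight_decay_rate * ||w||^2 / 2 and, unlike AdamW's decoupled weight decay,
# it flows through the Adam moment estimates.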
# TensorFlow mixed precision needs no manual loss scaling here: the Keras policy
# handles it, and TPU v5e-8 trains in bfloat16, which does not need loss scaling
# Calculate gradients - TensorFlow handles mixed precision automatically
gradients = tape.gradient(loss, self.model.trainable_variables)
# Filter out None gradients (for h0 variables that don't need gradients)
filtered_grads_and_vars = []
for grad, var in zip(gradients, self.model.trainable_variables):
if grad is not None:
filtered_grads_and_vars.append((grad, var))
else:
# Log which variables don't have gradients (informational)
tf.print(f"No gradient for variable: {var.name}")
# Extract filtered gradients and variables
filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
filtered_variables = [var for _, var in filtered_grads_and_vars]
# Clip gradients
if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
filtered_gradients, grad_norm = tf.clip_by_global_norm(
filtered_gradients, self.args['grad_norm_clip_value']
)
elif len(filtered_gradients) > 0:
grad_norm = tf.linalg.global_norm(filtered_gradients)
else:
grad_norm = tf.constant(0.0)
# Apply gradients (only for variables that have gradients)
if len(filtered_gradients) > 0:
# Apply gradients directly - optimizer should be pre-built and ready
# In @tf.function, we need to keep error handling simple
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
return loss, grad_norm
@tf.function
def _validation_step(self, batch):
"""Single validation step"""
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indices = batch['day_indices']
# Apply data transformations (no augmentation for validation)
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, self.args['dataset']['data_transforms'], training=False
)
# Calculate adjusted lengths
adjusted_lens = tf.cast(
(tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
self.args['model']['patch_stride'] + 1,
tf.int32
)
# Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Calculate loss
loss_input = {
'labels': labels,
'input_lengths': adjusted_lens,
'label_lengths': phone_seq_lens
}
loss = self.ctc_loss(loss_input, logits)
loss = tf.reduce_mean(loss)
# Calculate PER (Phoneme Error Rate) via greedy CTC decoding: take the argmax
# class per frame, collapse repeated frames, drop blanks (blank_index=0), and
# accumulate un-normalized edit distances against the reference labels
predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int64)
# A frame starts a new symbol if it differs from the previous frame
previous_ids = tf.pad(predicted_ids[:, :-1], [[0, 0], [1, 0]], constant_values=-1)
valid_mask = tf.sequence_mask(adjusted_lens, maxlen=tf.shape(predicted_ids)[1])
keep_mask = tf.logical_and(
tf.logical_and(tf.not_equal(predicted_ids, previous_ids), valid_mask),
tf.not_equal(predicted_ids, 0)
)
# Sparse representations preserve per-sequence ordering for tf.edit_distance;
# positions zeroed out below are dropped by tf.sparse.from_dense
pred_sparse = tf.sparse.from_dense(tf.where(keep_mask, predicted_ids, tf.zeros_like(predicted_ids)))
labels64 = tf.cast(labels, tf.int64)
label_mask = tf.sequence_mask(phone_seq_lens, maxlen=tf.shape(labels)[1])
true_sparse = tf.sparse.from_dense(tf.where(label_mask, labels64, tf.zeros_like(labels64)))
batch_edit_distance = tf.reduce_sum(
tf.edit_distance(pred_sparse, true_sparse, normalize=False)
)
return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)
def train(self) -> Dict[str, Any]:
"""Main training loop"""
self.logger.info("Starting training loop...")
# Log initial TPU status
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# Create distributed datasets
train_dataset = create_input_fn(
self.train_dataset_tf,
self.args['dataset']['data_transforms'],
training=True
)
val_dataset = create_input_fn(
self.val_dataset_tf,
self.args['dataset']['data_transforms'],
training=False
)
# Distribute datasets with timing
self.logger.info("🔄 Distributing training dataset across TPU cores...")
dist_start_time = time.time()
train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
train_dist_time = time.time() - dist_start_time
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
val_start_time = time.time()
val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
val_dist_time = time.time() - val_start_time
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
self.logger.info("Created distributed training and validation datasets")
# Training metrics
train_losses = []
val_losses = []
val_pers = []
val_results = []
val_steps_since_improvement = 0
self.logger.info("Training time count beginning...")
train_start_time = time.time()
# Training loop
step = 0
# Add timing diagnostic for first batch iteration
self.logger.info("🔄 Starting training loop iteration...")
loop_start_time = time.time()
for batch in train_dist_dataset:
if step == 0:
first_batch_iteration_time = time.time() - loop_start_time
self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
if step >= self.args['num_training_batches']:
self.logger.info("Reached maximum training batches, stopping training")
break
start_time = time.time()
# Distributed training step
self.logger.info("Running distributed training step...")
# Ensure we're in the correct TPU strategy scope
try:
with self.strategy.scope():
per_replica_losses, per_replica_grad_norms = self.strategy.run(
self._train_step, args=(batch, step)
)
except AttributeError as e:
if "merge_call" in str(e):
error_msg = f"Strategy merge_call error at step {step}: {e}"
print(error_msg)
if self.logger:
self.logger.error(error_msg)
self.logger.error("This indicates the strategy is not properly initialized")
raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
else:
raise
# Reduce across replicas
self.logger.info("Reducing results across replicas...")
loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
train_step_duration = time.time() - start_time
train_losses.append(float(loss.numpy()))
# Log training progress with TPU status
if step % self.args['batches_per_train_log'] == 0:
tpu_status = self._get_tpu_status()
self.logger.info(f'Train batch {step}: '
f'loss: {float(loss.numpy()):.2f} '
f'grad norm: {float(grad_norm.numpy()):.2f} '
f'time: {train_step_duration:.3f}s '
f'| {tpu_status}')
# Validation step
if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
self.logger.info(f"Running validation after training batch: {step}")
val_start_time = time.time()
val_metrics = self._validate(val_dist_dataset)
val_step_duration = time.time() - val_start_time
tpu_status = self._get_tpu_status()
self.logger.info(f'Val batch {step}: '
f'PER (avg): {val_metrics["avg_per"]:.4f} '
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
f'time: {val_step_duration:.3f}s '
f'| {tpu_status}')
val_pers.append(val_metrics['avg_per'])
val_losses.append(val_metrics['avg_loss'])
val_results.append(val_metrics)
# Check for improvement
new_best = False
if val_metrics['avg_per'] < self.best_val_per:
self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
self.best_val_per = val_metrics['avg_per']
self.best_val_loss = val_metrics['avg_loss']
new_best = True
elif (val_metrics['avg_per'] == self.best_val_per and
val_metrics['avg_loss'] < self.best_val_loss):
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
self.best_val_loss = val_metrics['avg_loss']
new_best = True
if new_best:
if self.args.get('save_best_checkpoint', True):
self.logger.info("Checkpointing model")
self._save_checkpoint('best_checkpoint', step)
if self.args.get('save_val_metrics', True):
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
pickle.dump(val_metrics, f)
val_steps_since_improvement = 0
else:
val_steps_since_improvement += 1
# Optional save all validation checkpoints
if self.args.get('save_all_val_steps', False):
self._save_checkpoint(f'checkpoint_batch_{step}', step)
# Early stopping
if (self.args.get('early_stopping', False) and
val_steps_since_improvement >= self.args.get('early_stopping_val_steps', 20)):
self.logger.info(f'Validation PER has not improved in {self.args["early_stopping_val_steps"]} '
f'validation steps. Stopping training early at batch: {step}')
break
step += 1
# Training completed
training_duration = time.time() - train_start_time
self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
# Save final model
if self.args.get('save_final_model', False):
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1)
return {
'train_losses': train_losses,
'val_losses': val_losses,
'val_pers': val_pers,
'val_metrics': val_results
}
def _validate(self, val_dataset) -> Dict[str, Any]:
"""Run validation on entire validation dataset"""
total_loss = 0.0
total_edit_distance = 0
total_seq_length = 0
num_batches = 0
for batch in val_dataset:
per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
self.strategy.run(self._validation_step, args=(batch,))
)
# Reduce across replicas
batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)
total_loss += float(batch_loss.numpy())
total_edit_distance += float(batch_edit_distance.numpy())
total_seq_length += float(batch_seq_length.numpy())
num_batches += 1
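# Micro-averaged metrics: mean CTC loss over batches, and PER as total edit
# distance divided by the total number of reference phonemes in the split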
avg_loss = total_loss / max(num_batches, 1)
avg_per = total_edit_distance / max(total_seq_length, 1e-6)
return {
'avg_loss': avg_loss,
'avg_per': avg_per,
'total_edit_distance': total_edit_distance,
'total_seq_length': total_seq_length,
'num_batches': num_batches
}
def _save_checkpoint(self, name: str, step: int):
"""Save model checkpoint"""
checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)
# Save model weights
self.model.save_weights(checkpoint_path + '.weights.h5')
# Save optimizer state
optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
optimizer_checkpoint.save(checkpoint_path + '.optimizer')
# Save training state
state = {
'step': step,
'best_val_per': float(self.best_val_per),
'best_val_loss': float(self.best_val_loss)
}
with open(checkpoint_path + '.state.json', 'w') as f:
json.dump(state, f)
# Save config
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
self.logger.info(f"Saved checkpoint: {checkpoint_path}")
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint"""
# Load model weights
self.model.load_weights(checkpoint_path + '.weights.h5')
# Load optimizer state
optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')
# Load training state
with open(checkpoint_path + '.state.json', 'r') as f:
state = json.load(f)
self.best_val_per = state['best_val_per']
self.best_val_loss = state['best_val_loss']
self.logger.info(f"Loaded checkpoint: {checkpoint_path}")
def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
"""
Run inference on input features
Args:
features: Input neural features [batch_size, time_steps, features]
day_indices: Day indices [batch_size]
n_time_steps: Number of valid time steps [batch_size]
mode: 'inference' or 'full'
Returns:
Phoneme logits [batch_size, time_steps, n_classes]
"""
# Apply data transformations (no augmentation)
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, self.args['dataset']['data_transforms'], training=False
)
# Run model inference
logits = self.model(features, day_indices, None, False, mode, training=False)
return logits
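# Minimal usage sketch (hypothetical entry point). The config file name and the
# assumption that it mirrors the keys accessed above (mode, output_dir,
# checkpoint_dir, dataset, model, lr_*, num_training_batches, ...) are
# illustrative; the repo's real launcher may differ.
if __name__ == '__main__':
    args = OmegaConf.load('rnn_args.yaml')  # hypothetical config path
    trainer = BrainToTextDecoderTrainerTF(args)
    stats = trainer.train()
    print(f"Best validation PER: {trainer.best_val_per:.4f}")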