import os
import tensorflow as tf
import numpy as np
import time
import json
import pickle
import logging
import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf

from rnn_model_tf import (
    TripleGRUDecoder,
    CTCLoss,
    create_tpu_strategy,
    build_model_for_tpu,
    configure_mixed_precision
)
from dataset_tf import (
    BrainToTextDatasetTF,
    DataAugmentationTF,
    train_test_split_indices,
    create_input_fn
)


class BrainToTextDecoderTrainerTF:
    """
    TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8

    This trainer implements the same training logic as the PyTorch version but uses
    TensorFlow operations optimized for TPU hardware.
    """

    def __init__(self, args: Dict[str, Any]):
        """
        Initialize the TensorFlow trainer

        Args:
            args: Configuration dictionary containing all training parameters
        """
        self.args = args
        self.logger = None

        # Optimize CPU utilization for the data pipeline (make use of the 224 host cores)
        self._configure_cpu_optimization()

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        if self.strategy is None:
            raise RuntimeError("Failed to create TPU strategy - strategy is None")

        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
        print(f"Strategy type: {type(self.strategy).__name__}")

        # Configure mixed precision for TPU v5e-8
        if args.get('use_amp', True):
            configure_mixed_precision()
            self.mixed_precision = True
        else:
            self.mixed_precision = False

        # Initialize tracking variables
        self.best_val_per = float('inf')
        self.best_val_loss = float('inf')

        # Setup directories
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        if (args.get('save_best_checkpoint', True) or
                args.get('save_all_val_steps', False) or
                args.get('save_final_model', False)):
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Setup logging
        self._setup_logging()

        # Set random seeds
        if self.args['seed'] != -1:
            tf.random.set_seed(self.args['seed'])
            np.random.seed(self.args['seed'])

        # Initialize datasets
        self._initialize_datasets()

        # Build model within strategy scope
        with self.strategy.scope():
            print("🔨 Building model within TPU strategy scope...")
            self.model = self._build_model()
            print("✅ Model built successfully")

            print("⚙️ Creating optimizer...")
            self.optimizer = self._create_optimizer()
            print("✅ Optimizer created")

            print("🔧 Pre-building optimizer state for TPU...")
            # For TPU, the optimizer must be completely ready before training,
            # since @tf.function doesn't allow dynamic building
            try:
                print("Building optimizer with model variables...")

                # Explicitly build the optimizer with the model variables
                print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
                self.optimizer.build(self.model.trainable_variables)
                print("✅ Optimizer built with model variables")

                # Verify the optimizer is properly built - just check iterations
                print(f"Optimizer iterations: {self.optimizer.iterations}")

                # For TPU training, also make sure the optimizer state variables exist.
                # Accessing the iteration counter is enough and is safe here because it
                # does not require a replica context (no dummy gradients are applied).
                print("🔄 Ensuring optimizer state variables are created...")
                _ = self.optimizer.iterations  # This ensures basic state is created

                print("✅ Optimizer fully ready for TPU training")
                print("📝 Note: Optimizer will work correctly in @tf.function context")

            except Exception as e:
                print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
                print(f"Error type: {type(e).__name__}")
                import traceback
                print(f"Full traceback: {traceback.format_exc()}")
                raise RuntimeError(f"Optimizer pre-build failed: {e}") from e

            print("📅 Setting up learning rate scheduler...")
            self.lr_scheduler = self._create_lr_scheduler()
            print("✅ LR scheduler ready")

            print("🎯 Initializing CTC loss...")
            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
            print("✅ CTC loss initialized")

        # Log model information
        self._log_model_info()

        # Adversarial training configuration
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))

        # TPU-specific weight decay handling
        self.manual_weight_decay = False
        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
            self.manual_weight_decay = True
            self.weight_decay_rate = self.args['weight_decay']
            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")

        if self.adv_enabled:
            if self.logger:
                self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                                 f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                                 f"noise_l2_weight={self.adv_noise_l2_weight}, "
                                 f"warmup_steps={self.adv_warmup_steps}")
            else:
                print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                      f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                      f"noise_l2_weight={self.adv_noise_l2_weight}, "
                      f"warmup_steps={self.adv_warmup_steps}")

    def _setup_logging(self):
        """Setup logging configuration"""
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if self.args['mode'] == 'train':
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
        if self.mixed_precision:
            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')

    def _configure_cpu_optimization(self):
        """Configure CPU utilization to make use of 224 cores for data pipeline"""
        import multiprocessing

        # Get available CPU cores
        available_cores = multiprocessing.cpu_count()
        print(f"💻 Available CPU cores: {available_cores}")

        # Optimize for data pipeline parallelism
        # For 224 cores, use more threads for better data loading performance
        if available_cores >= 200:  # High core count system
            inter_op_threads = min(64, available_cores // 3)  # More aggressive for 224 cores
            intra_op_threads = min(32, available_cores // 6)
        else:
            # Use ~1/4 of cores for inter-op (between operations)
            # Use ~1/8 of cores for intra-op (within operations)
            inter_op_threads = min(32, available_cores // 4)
            intra_op_threads = min(16, available_cores // 8)

        tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
        tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)

        print("🔧 CPU optimization configured:")
        print(f"   Inter-op parallelism: {inter_op_threads} threads")
        print(f"   Intra-op parallelism: {intra_op_threads} threads")
        print("   This will accelerate data loading and preprocessing while TPU handles training")

    def _get_tpu_status(self) -> str:
        """Get current TPU status and HBM utilization info"""
        try:
            # Get TPU devices
            tpu_devices = tf.config.list_logical_devices('TPU')

            if not tpu_devices:
                return "TPU: No devices"

            # Get strategy info
            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1

            # Get TPU memory info using the /device:TPU:X device names
            try:
                # Check all TPU devices for memory usage
                active_cores = 0
                total_current_mb = 0
                max_peak_mb = 0

                for device in tpu_devices:
                    try:
                        memory_info = tf.config.experimental.get_memory_info(device.name)
                        if memory_info and 'current' in memory_info:
                            current_mb = memory_info['current'] // (1024 * 1024)
                            peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)

                            if current_mb > 1:  # >1MB is considered active
                                active_cores += 1
                                total_current_mb += current_mb
                                max_peak_mb = max(max_peak_mb, peak_mb)
                    except Exception:
                        continue

                if active_cores > 0:
                    if active_cores == 1:
                        hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
                    else:
                        hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"
                else:
                    hbm_info = "HBM:idle"

            except Exception:
                # Fallback: simple TPU activity check
                try:
                    with tf.device('/TPU:0'):
                        _ = tf.constant(1.0)
                    hbm_info = "HBM:active"
                except Exception:
                    hbm_info = "HBM:inactive"

            return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
                    f"{hbm_info}")

        except Exception as e:
            return f"TPU: status_error({str(e)[:20]})"

    def _get_detailed_tpu_status(self) -> str:
        """Get detailed TPU status for training start"""
        try:
            # Get TPU devices
            tpu_devices = tf.config.list_logical_devices('TPU')

            if not tpu_devices:
                return "❌ No TPU devices detected"

            # Get strategy info
            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
            strategy_type = type(self.strategy).__name__

            # Get TPU HBM memory info using the logical device names
            try:
                active_cores = 0
                total_current_gb = 0
                max_peak_gb = 0
                memory_details = []

                for i, device in enumerate(tpu_devices):
                    try:
                        memory_info = tf.config.experimental.get_memory_info(device.name)
                        if memory_info and 'current' in memory_info:
                            current_gb = memory_info['current'] // (1024 * 1024 * 1024)
                            peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)

                            if current_gb > 0 or memory_info['current'] > 1024 * 1024:  # >1MB
                                active_cores += 1
                                total_current_gb += current_gb
                                max_peak_gb = max(max_peak_gb, peak_gb)
                                if current_gb > 0:
                                    memory_details.append(f"Core{i}:{current_gb}GB")
                    except Exception:
                        continue

                if active_cores > 0:
                    # Empirically, a single core peaked at ~14.5GB, suggesting ~16GB of HBM per core
                    estimated_per_core = 16  # Conservative estimate
                    estimated_total_gb = estimated_per_core * len(tpu_devices)
                    hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores"
                else:
                    hbm_usage = "HBM: 0GB/256GB (idle)"

            except Exception:
                hbm_usage = "HBM: unavailable"

            # Simple TPU test
            try:
                with tf.device('/TPU:0'):
                    test_result = tf.constant([1.0, 2.0])
                    _ = tf.reduce_sum(test_result)
                tpu_test = "✅ responsive"
            except Exception as e:
                tpu_test = f"❌ test_failed({str(e)[:15]})"

            return (f"TPU Devices: {len(tpu_devices)} | "
                    f"Strategy: {strategy_type} | "
                    f"Cores: {num_replicas} | "
                    f"{hbm_usage} | "
                    f"Test: {tpu_test}")

        except Exception as e:
            return f"❌ TPU status check failed: {str(e)[:50]}"

    def _initialize_datasets(self):
        """Initialize training and validation datasets"""
        # Create file paths
        train_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
            for s in self.args['dataset']['sessions']
        ]
        val_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
            for s in self.args['dataset']['sessions']
        ]

        # Validate no duplicates
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("Duplicate sessions in train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("Duplicate sessions in val dataset")

        # Split trials
        train_trials, _ = train_test_split_indices(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )
        _, val_trials = train_test_split_indices(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )

        # Save trial splits
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Create TensorFlow datasets with on-demand loading plus caching,
        # and monitor host memory usage while they are initialized
        import psutil
        initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024

        print("🔄 Initializing training dataset with GPU-style memory management...")
        init_start_time = time.time()
        self.train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=self.args['num_training_batches'],
            split='train',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=self.args['dataset']['days_per_batch'],
            random_seed=self.args['dataset']['seed'],
            must_include_days=self.args['dataset'].get('must_include_days'),
            feature_subset=self.args['dataset'].get('feature_subset'),
            cache_data=True,  # Enable smart caching (as in the GPU version)
            preload_all_data=False  # 🚨 Follow the GPU version's strategy: load on demand to avoid exhausting host memory
        )

        # Log training dataset initialization performance
        train_init_time = time.time() - init_start_time
        train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
        train_memory_used = train_memory_mb - initial_memory_mb
        print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")

        print("🔄 Initializing validation dataset with GPU-style memory management...")
        val_init_start_time = time.time()
        self.val_dataset_tf = BrainToTextDatasetTF(
            trial_indices=val_trials,
            n_batches=None,  # Use all validation data
            split='test',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=1,  # One day per validation batch
            random_seed=self.args['dataset']['seed'],
            feature_subset=self.args['dataset'].get('feature_subset'),
            cache_data=True,  # Enable smart caching (as in the GPU version)
            preload_all_data=False  # 🚨 Follow the GPU version's strategy: load on demand to avoid exhausting host memory
        )

        # Log validation dataset initialization performance
        val_init_time = time.time() - val_init_start_time
        final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
        total_memory_used = final_memory_mb - initial_memory_mb
        val_memory_used = final_memory_mb - train_memory_mb
        print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM")
        print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")

        self.logger.info("Successfully initialized TensorFlow datasets")

    def _build_model(self) -> TripleGRUDecoder:
        """Build the TripleGRUDecoder model"""
        model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride']
        )
        return model

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create the AdamW optimizer"""
        # Note: TensorFlow doesn't have the same parameter-group functionality as PyTorch,
        # so a single optimizer is used and any per-group learning rates are handled by the scheduler.

        print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")

        # For TPU training, be more explicit about the optimizer configuration
        # to avoid strategy context issues
        if isinstance(self.strategy, tf.distribute.TPUStrategy):
            print("Using TPU-optimized optimizer configuration")
            # TPU-specific optimizer configuration.
            # IMPORTANT: weight_decay is disabled on TPU due to distributed-training
            # compatibility issues; manual L2 regularization is applied instead.
            optimizer = tf.keras.optimizers.AdamW(
                learning_rate=self.args['lr_max'],
                beta_1=self.args['beta0'],
                beta_2=self.args['beta1'],
                epsilon=self.args['epsilon'],
                weight_decay=0.0,  # Disabled for TPU compatibility
                # TPU-specific settings
                global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
            )
            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
            print("💡 Manual L2 regularization is applied in the train step when weight_decay > 0")
        else:
            print("Using standard optimizer configuration")
            optimizer = tf.keras.optimizers.AdamW(
                learning_rate=self.args['lr_max'],
                beta_1=self.args['beta0'],
                beta_2=self.args['beta1'],
                epsilon=self.args['epsilon'],
                weight_decay=self.args['weight_decay']
            )

        return optimizer

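    # Note on the TPU weight-decay workaround above: when manual_weight_decay is set,
    # _train_step adds weight_decay_rate * sum_w ||w||^2 / 2 to the CTC loss
    # (tf.nn.l2_loss(w) == sum(w**2) / 2). That is classic L2 regularization coupled
    # to the gradient, not AdamW-style decoupled weight decay, so its effective
    # strength also depends on the learning-rate schedule.
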
    def _create_lr_scheduler(self):
        """Create learning rate scheduler"""
        if self.args['lr_scheduler_type'] == 'cosine':
            return self._create_cosine_scheduler()
        elif self.args['lr_scheduler_type'] == 'linear':
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args['lr_max'],
                decay_steps=self.args['lr_decay_steps'],
                end_learning_rate=self.args['lr_min'],
                power=1.0  # Linear decay
            )
        else:
            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")

    def _create_cosine_scheduler(self):
        """Create cosine learning rate scheduler"""
        return tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=self.args['lr_max'],
            first_decay_steps=self.args['lr_decay_steps'],
            t_mul=1.0,
            m_mul=1.0,
            alpha=self.args['lr_min'] / self.args['lr_max']
        )

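    # With t_mul=1.0 and m_mul=1.0 the schedule above is a repeating cosine cycle:
    # every lr_decay_steps steps the learning rate decays from lr_max down to
    # alpha * lr_max = lr_min and then restarts at lr_max. For a single one-way
    # decay, tf.keras.optimizers.schedules.CosineDecay would be the alternative.
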
    def _log_model_info(self):
        """Log model architecture and parameter information"""
        if self.logger:
            self.logger.info("Initialized TripleGRUDecoder model")
        else:
            print("Initialized TripleGRUDecoder model")

        # Build the model by calling it once with dummy data
        dummy_batch_size = 2
        dummy_time_steps = 100
        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)

        # Call the model to build it
        _ = self.model(dummy_features, dummy_day_idx, training=False)

        # Count parameters
        total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])

        if self.logger:
            self.logger.info(f"Model has {total_params:,} trainable parameters")
        else:
            print(f"Model has {total_params:,} trainable parameters")

    @tf.function
    def _train_step(self, batch, step):
        """Single training step with gradient tape"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        with tf.GradientTape() as tape:
            # Apply data transformations
            features, n_time_steps = DataAugmentationTF.transform_data(
                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
            )

            # Calculate adjusted lengths for CTC
            adjusted_lens = tf.cast(
                (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
                self.args['model']['patch_stride'] + 1,
                tf.int32
            )

            # Forward pass
            use_full = self.adv_enabled and (step >= self.adv_warmup_steps)
            if use_full:
                clean_logits, noisy_logits, noise_output = self.model(
                    features, day_indices, None, False, 'full',
                    grl_lambda=self.adv_grl_lambda, training=True
                )
            else:
                clean_logits = self.model(
                    features, day_indices, None, False, 'inference', training=True
                )

            # Calculate losses
            if use_full:
                # Clean CTC loss
                clean_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
                clean_loss = tf.reduce_mean(clean_loss)

                # Noisy CTC loss
                noisy_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
                noisy_loss = tf.reduce_mean(noisy_loss)

                # Optional noise L2 regularization
                noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = tf.reduce_mean(tf.square(noise_output))

                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
            else:
                loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                loss = self.ctc_loss(loss_input, clean_logits)
                loss = tf.reduce_mean(loss)

            # Add manual L2 regularization for TPU (since weight_decay is disabled)
            if self.manual_weight_decay:
                l2_loss = tf.constant(0.0, dtype=loss.dtype)
                for var in self.model.trainable_variables:
                    l2_loss += tf.nn.l2_loss(var)
                loss += self.weight_decay_rate * l2_loss

            # Mixed precision in TensorFlow: no manual loss scaling is needed here,
            # the Keras mixed-precision policy handles it. TPU v5e-8 uses bfloat16,
            # which does not require loss scaling at all.

        # Calculate gradients - TensorFlow handles mixed precision automatically
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Filter out None gradients (for h0 variables that don't need gradients)
        filtered_grads_and_vars = []
        for grad, var in zip(gradients, self.model.trainable_variables):
            if grad is not None:
                filtered_grads_and_vars.append((grad, var))
            else:
                # Log which variables don't have gradients (informational)
                tf.print(f"No gradient for variable: {var.name}")

        # Extract filtered gradients and variables
        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
        filtered_variables = [var for _, var in filtered_grads_and_vars]

        # Clip gradients
        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
            filtered_gradients, grad_norm = tf.clip_by_global_norm(
                filtered_gradients, self.args['grad_norm_clip_value']
            )
        elif len(filtered_gradients) > 0:
            grad_norm = tf.linalg.global_norm(filtered_gradients)
        else:
            grad_norm = tf.constant(0.0)

        # Apply gradients (only for variables that have gradients)
        if len(filtered_gradients) > 0:
            # Apply gradients directly - the optimizer is pre-built and ready.
            # Inside @tf.function, error handling must stay simple.
            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))

        return loss, grad_norm

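    # Length bookkeeping used by both _train_step (above) and _validation_step (below):
    # the model patches the time axis, so the CTC input length becomes
    #     adjusted_len = floor((n_time_steps - patch_size) / patch_stride) + 1.
    # For example (illustrative values only), 512 input steps with patch_size=14 and
    # patch_stride=4 give floor((512 - 14) / 4) + 1 = 125 patched frames.
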
    @tf.function
    def _validation_step(self, batch):
        """Single validation step"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation for validation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Calculate adjusted lengths
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
            self.args['model']['patch_stride'] + 1,
            tf.int32
        )

        # Forward pass (inference mode only)
        logits = self.model(features, day_indices, None, False, 'inference', training=False)

        # Calculate loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = self.ctc_loss(loss_input, logits)
        loss = tf.reduce_mean(loss)

        # Calculate PER (Phoneme Error Rate) via greedy decoding
        predicted_ids = tf.argmax(logits, axis=-1)

        # Remove blanks and consecutive duplicates
        batch_edit_distance = 0
        for i in range(tf.shape(logits)[0]):
            pred_seq = predicted_ids[i, :adjusted_lens[i]]
            # Remove consecutive duplicates
            pred_seq = tf.py_function(
                func=lambda x: tf.constant([x[0]] + [x[j] for j in range(1, len(x)) if x[j] != x[j - 1]]),
                inp=[pred_seq],
                Tout=tf.int64
            )
            # Remove blanks (assuming blank_index=0)
            pred_seq = tf.boolean_mask(pred_seq, pred_seq != 0)

            true_seq = labels[i, :phone_seq_lens[i]]

            # Calculate edit distance (SparseTensor indices and shapes must be int64)
            edit_dist = tf.edit_distance(
                tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(pred_seq, out_type=tf.int64)), 1),
                    values=tf.cast(pred_seq, tf.int64),
                    dense_shape=[tf.size(pred_seq, out_type=tf.int64)]
                ),
                tf.SparseTensor(
                    indices=tf.expand_dims(tf.range(tf.size(true_seq, out_type=tf.int64)), 1),
                    values=tf.cast(true_seq, tf.int64),
                    dense_shape=[tf.size(true_seq, out_type=tf.int64)]
                ),
                normalize=False
            )

            batch_edit_distance += edit_dist

        return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)

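    # Greedy CTC decoding rule used above (illustrative): a prediction such as
    # [0, 7, 7, 0, 3, 3, 3, 0] is first collapsed over consecutive repeats to
    # [0, 7, 0, 3, 0], then blanks (index 0) are dropped, giving [7, 3]. The
    # edit distance between that sequence and the reference phonemes, summed over
    # trials and divided by the total reference length in _validate(), is the
    # reported phoneme error rate (PER).
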
    def train(self) -> Dict[str, Any]:
        """Main training loop"""
        self.logger.info("Starting training loop...")

        # Log initial TPU status
        initial_tpu_status = self._get_detailed_tpu_status()
        self.logger.info(f"Initial TPU Status: {initial_tpu_status}")

        # Create distributed datasets
        train_dataset = create_input_fn(
            self.train_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=True
        )

        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )

        # Distribute datasets with timing
        self.logger.info("🔄 Distributing training dataset across TPU cores...")
        dist_start_time = time.time()
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        train_dist_time = time.time() - dist_start_time
        self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")

        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
        val_start_time = time.time()
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
        val_dist_time = time.time() - val_start_time
        self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")

        self.logger.info("Created distributed training and validation datasets")

        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        self.logger.info("Starting training timer...")
        train_start_time = time.time()

        # Training loop
        step = 0

        # Timing diagnostic for the first batch iteration
        self.logger.info("🔄 Starting training loop iteration...")
        loop_start_time = time.time()

        for batch in train_dist_dataset:
            if step == 0:
                first_batch_iteration_time = time.time() - loop_start_time
                self.logger.info(f"✅ First batch iteration completed in {first_batch_iteration_time:.2f}s")
            if step >= self.args['num_training_batches']:
                self.logger.info("Reached maximum training batches, stopping training")
                break

            start_time = time.time()

            # Distributed training step
            self.logger.info("Running distributed training step...")
            # Ensure we're in the correct TPU strategy scope
            try:
                with self.strategy.scope():
                    per_replica_losses, per_replica_grad_norms = self.strategy.run(
                        self._train_step, args=(batch, step)
                    )
            except AttributeError as e:
                if "merge_call" in str(e):
                    error_msg = f"Strategy merge_call error at step {step}: {e}"
                    print(error_msg)
                    if self.logger:
                        self.logger.error(error_msg)
                        self.logger.error("This indicates the strategy is not properly initialized")
                    raise RuntimeError(f"TPU strategy failed during training step {step}: {e}")
                else:
                    raise

            # Reduce across replicas
            self.logger.info("Reducing results across replicas...")
            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)

            train_step_duration = time.time() - start_time
            train_losses.append(float(loss.numpy()))

            # Log training progress with TPU status
            if step % self.args['batches_per_train_log'] == 0:
                tpu_status = self._get_tpu_status()
                self.logger.info(f'Train batch {step}: '
                                 f'loss: {float(loss.numpy()):.2f} '
                                 f'grad norm: {float(grad_norm.numpy()):.2f} '
                                 f'time: {train_step_duration:.3f}s '
                                 f'| {tpu_status}')

            # Validation step
            if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
                self.logger.info(f"Running validation after training batch: {step}")

                val_start_time = time.time()
                val_metrics = self._validate(val_dist_dataset)
                val_step_duration = time.time() - val_start_time

                tpu_status = self._get_tpu_status()
                self.logger.info(f'Val batch {step}: '
                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                                 f'time: {val_step_duration:.3f}s '
                                 f'| {tpu_status}')

                val_pers.append(val_metrics['avg_per'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Check for improvement
                new_best = False
                if val_metrics['avg_per'] < self.best_val_per:
                    self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                    self.best_val_per = val_metrics['avg_per']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif (val_metrics['avg_per'] == self.best_val_per and
                      val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)

                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optionally save a checkpoint at every validation step
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)

                # Early stopping
                if (self.args.get('early_stopping', False) and
                        val_steps_since_improvement >= self.args.get('early_stopping_val_steps', 20)):
                    self.logger.info(f'Validation PER has not improved in {self.args["early_stopping_val_steps"]} '
                                     f'validation steps. Stopping training early at batch: {step}')
                    break

            step += 1

        # Training completed
        training_duration = time.time() - train_start_time
        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model
        if self.args.get('save_final_model', False):
            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
            self._save_checkpoint(f'final_checkpoint_batch_{step - 1}', step - 1)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_pers': val_pers,
            'val_metrics': val_results
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation on entire validation dataset"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0

        for batch in val_dataset:
            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
                self.strategy.run(self._validation_step, args=(batch,))
            )

            # Reduce across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)

            total_loss += float(batch_loss.numpy())
            total_edit_distance += float(batch_edit_distance.numpy())
            total_seq_length += float(batch_seq_length.numpy())
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        avg_per = total_edit_distance / max(total_seq_length, 1e-6)

        return {
            'avg_loss': avg_loss,
            'avg_per': avg_per,
            'total_edit_distance': total_edit_distance,
            'total_seq_length': total_seq_length,
            'num_batches': num_batches
        }

    def _save_checkpoint(self, name: str, step: int):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')

        # Save optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.save(checkpoint_path + '.optimizer')

        # Save training state
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss)
        }

        with open(checkpoint_path + '.state.json', 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")

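    # A checkpoint named <name> therefore consists of several files under checkpoint_dir:
    #   <name>.weights.h5        - Keras model weights
    #   <name>.optimizer-<N>.*   - tf.train.Checkpoint files for the optimizer state
    #                              (Checkpoint.save appends an incrementing save counter)
    #   <name>.state.json        - step number and best validation metrics
    #   args.yaml                - the full training config (shared across checkpoints)
    # load_checkpoint() below assumes the first optimizer save, hence the '.optimizer-1' suffix.
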
    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint"""
        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')

        # Load optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')

        # Load training state
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)

        self.best_val_per = state['best_val_per']
        self.best_val_loss = state['best_val_loss']

        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
        """
        Run inference on input features

        Args:
            features: Input neural features [batch_size, time_steps, features]
            day_indices: Day indices [batch_size]
            n_time_steps: Number of valid time steps [batch_size]
            mode: 'inference' or 'full'

        Returns:
            Phoneme logits [batch_size, time_steps, n_classes]
        """
        # Apply data transformations (no augmentation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Run model inference
        logits = self.model(features, day_indices, None, False, mode, training=False)

        return logits
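

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not the original launch script): the
# config file name and command-line handling below are assumptions. It simply
# shows how the trainer above is meant to be driven: load an OmegaConf config,
# instantiate the trainer, and call train().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical default config name; pass an explicit path as the first argument.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "rnn_args.yaml"
    args = OmegaConf.load(config_path)

    trainer = BrainToTextDecoderTrainerTF(args)
    stats = trainer.train()

    print(f"Finished: best val PER = {trainer.best_val_per:.4f} "
          f"after {len(stats['train_losses'])} training batches")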