Zchen
2025-10-15 19:04:42 +08:00
parent 7965f7dbfe
commit 3b242b908d


@@ -0,0 +1,651 @@
import os
import tensorflow as tf
import numpy as np
import time
import json
import pickle
import logging
import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List

from omegaconf import OmegaConf
from rnn_model_tf import (
    TripleGRUDecoder,
    CTCLoss,
    create_tpu_strategy,
    build_model_for_tpu,
    configure_mixed_precision
)
from dataset_tf import (
    BrainToTextDatasetTF,
    DataAugmentationTF,
    train_test_split_indices,
    create_input_fn
)

class BrainToTextDecoderTrainerTF:
    """
    TensorFlow/Keras trainer for a brain-to-text phoneme decoder, optimized for TPU v5e-8.

    This trainer implements the same training logic as the PyTorch version but uses
    TensorFlow operations optimized for TPU hardware.
    """

    def __init__(self, args: Dict[str, Any]):
        """
        Initialize the TensorFlow trainer.

        Args:
            args: Configuration dictionary containing all training parameters
        """
        self.args = args
        self.logger = None

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision (bfloat16) for TPU v5e-8
        if args.get('use_amp', True):
            configure_mixed_precision()
            self.mixed_precision = True
        else:
            self.mixed_precision = False

        # Initialize tracking variables
        self.best_val_per = float('inf')
        self.best_val_loss = float('inf')

        # Set up output and checkpoint directories
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)
        if (args.get('save_best_checkpoint', True) or
                args.get('save_all_val_steps', False) or
                args.get('save_final_model', False)):
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Set up logging
        self._setup_logging()

        # Set random seeds for reproducibility
        if self.args['seed'] != -1:
            tf.random.set_seed(self.args['seed'])
            np.random.seed(self.args['seed'])

        # Initialize datasets
        self._initialize_datasets()

        # Build model, optimizer, LR schedule, and loss within the strategy scope
        with self.strategy.scope():
            self.model = self._build_model()
            self.optimizer = self._create_optimizer()
            self.lr_scheduler = self._create_lr_scheduler()
            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Log model information
        self._log_model_info()

        # Adversarial training configuration
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
        if self.adv_enabled:
            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
                             f"warmup_steps={self.adv_warmup_steps}")

    def _setup_logging(self):
        """Set up logging configuration"""
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if self.args['mode'] == 'train':
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
        if self.mixed_precision:
            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')

    def _initialize_datasets(self):
        """Initialize training and validation datasets"""
        # Build per-session HDF5 file paths
        train_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
            for s in self.args['dataset']['sessions']
        ]
        val_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
            for s in self.args['dataset']['sessions']
        ]

        # Ensure no session is listed twice
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("Duplicate sessions in train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("Duplicate sessions in val dataset")

        # Split trials into train/val sets
        train_trials, _ = train_test_split_indices(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )
        _, val_trials = train_test_split_indices(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
        )

        # Save trial splits for reproducibility
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Create TensorFlow datasets
        self.train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=self.args['num_training_batches'],
            split='train',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=self.args['dataset']['days_per_batch'],
            random_seed=self.args['dataset']['seed'],
            must_include_days=self.args['dataset'].get('must_include_days'),
            feature_subset=self.args['dataset'].get('feature_subset')
        )
        self.val_dataset_tf = BrainToTextDatasetTF(
            trial_indices=val_trials,
            n_batches=None,  # Use all validation data
            split='test',
            batch_size=self.args['dataset']['batch_size'],
            days_per_batch=1,  # One day per validation batch
            random_seed=self.args['dataset']['seed'],
            feature_subset=self.args['dataset'].get('feature_subset')
        )

        self.logger.info("Successfully initialized TensorFlow datasets")

    def _build_model(self) -> TripleGRUDecoder:
        """Build the TripleGRUDecoder model"""
        model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride']
        )
        return model

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create the AdamW optimizer"""
        # Note: TensorFlow does not expose PyTorch-style per-parameter groups directly,
        # so a single optimizer is used here and any per-group learning rates have to
        # be handled elsewhere (e.g. by the scheduler or separate optimizers).
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=self.args['lr_max'],
            beta_1=self.args['beta0'],
            beta_2=self.args['beta1'],
            epsilon=self.args['epsilon'],
            weight_decay=self.args['weight_decay']
        )
        return optimizer
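
    # Sketch (not wired in): PyTorch-style parameter groups could be emulated by
    # partitioning trainable variables and applying separate optimizers, e.g.:
    #
    #   day_vars  = [v for v in self.model.trainable_variables if 'day' in v.name]
    #   core_vars = [v for v in self.model.trainable_variables if 'day' not in v.name]
    #   day_opt   = tf.keras.optimizers.AdamW(learning_rate=self.args['lr_max_day'])
    #   core_opt  = tf.keras.optimizers.AdamW(learning_rate=self.args['lr_max'])
    #
    # and then calling each optimizer's apply_gradients on its own variable subset.
    # The 'day' name filter and the 'lr_max_day' key are assumptions for illustration.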

    def _create_lr_scheduler(self):
        """Create the learning rate schedule"""
        if self.args['lr_scheduler_type'] == 'cosine':
            return self._create_cosine_scheduler()
        elif self.args['lr_scheduler_type'] == 'linear':
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args['lr_max'],
                decay_steps=self.args['lr_decay_steps'],
                end_learning_rate=self.args['lr_min'],
                power=1.0  # power=1.0 makes the polynomial decay linear
            )
        else:
            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")

    def _create_cosine_scheduler(self):
        """Create a cosine learning rate schedule with restarts"""
        return tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=self.args['lr_max'],
            first_decay_steps=self.args['lr_decay_steps'],
            t_mul=1.0,
            m_mul=1.0,
            alpha=self.args['lr_min'] / self.args['lr_max']
        )
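
    # For reference, with t_mul=1.0 and m_mul=1.0 the CosineDecayRestarts schedule
    # within each period of lr_decay_steps follows
    #   lr(t) = lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * t / lr_decay_steps))
    # since alpha = lr_min / lr_max sets the decay floor as a fraction of lr_max.
    #
    # As written, self.lr_scheduler is created but never attached to the optimizer
    # (AdamW above is built with the constant lr_max); to take effect the schedule
    # would normally be passed as the optimizer's learning_rate argument.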

    def _log_model_info(self):
        """Log model architecture and parameter information"""
        self.logger.info("Initialized TripleGRUDecoder model")

        # Build the model by calling it once with dummy data
        dummy_batch_size = 2
        dummy_time_steps = 100
        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps,
                                   self.args['model']['n_input_features']))
        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
        _ = self.model(dummy_features, dummy_day_idx, training=False)

        # Count trainable parameters
        total_params = sum(tf.size(w).numpy() for w in self.model.trainable_weights)
        self.logger.info(f"Model has {total_params:,} trainable parameters")

    @tf.function
    def _train_step(self, batch, use_full):
        """Single training step with gradient tape.

        `use_full` is a Python bool chosen by the caller (adversarial path on/off), so
        the function is retraced at most once, when the adversarial path switches on.
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        with tf.GradientTape() as tape:
            # Apply data transformations (augmentation enabled for training)
            features, n_time_steps = DataAugmentationTF.transform_data(
                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
            )

            # Adjusted sequence lengths for CTC after patching, i.e. the strided-window
            # output length floor((T - patch_size) / patch_stride) + 1
            adjusted_lens = tf.cast(
                (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
                self.args['model']['patch_stride'] + 1,
                tf.int32
            )

            # Forward pass
            if use_full:
                clean_logits, noisy_logits, noise_output = self.model(
                    features, day_indices, None, False, 'full',
                    grl_lambda=self.adv_grl_lambda, training=True
                )
            else:
                clean_logits = self.model(
                    features, day_indices, None, False, 'inference', training=True
                )

            # Calculate losses
            if use_full:
                # Clean CTC loss
                clean_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                clean_loss = tf.reduce_mean(self.ctc_loss(clean_loss_input, clean_logits))

                # Noisy CTC loss
                noisy_loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                noisy_loss = tf.reduce_mean(self.ctc_loss(noisy_loss_input, noisy_logits))

                # Optional L2 regularization on the estimated noise
                noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = tf.reduce_mean(tf.square(noise_output))

                loss = (clean_loss
                        + self.adv_noisy_loss_weight * noisy_loss
                        + self.adv_noise_l2_weight * noise_l2)
            else:
                loss_input = {
                    'labels': labels,
                    'input_lengths': adjusted_lens,
                    'label_lengths': phone_seq_lens
                }
                loss = tf.reduce_mean(self.ctc_loss(loss_input, clean_logits))

        # bfloat16 mixed precision on TPU does not require loss scaling, so gradients
        # are taken directly from the unscaled loss.
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Clip gradients
        if self.args['grad_norm_clip_value'] > 0:
            gradients, grad_norm = tf.clip_by_global_norm(
                gradients, self.args['grad_norm_clip_value']
            )
        else:
            grad_norm = tf.linalg.global_norm(gradients)

        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return loss, grad_norm
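
    # Sketch for reference: if this trainer were run with float16 on GPU instead of
    # bfloat16 on TPU, dynamic loss scaling would be needed. That is normally done by
    # wrapping the optimizer, e.g.:
    #
    #   opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.AdamW(...))
    #   with tf.GradientTape() as tape:
    #       scaled_loss = opt.get_scaled_loss(loss)
    #   scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    #   grads = opt.get_unscaled_gradients(scaled_grads)
    #
    # bfloat16 has the same exponent range as float32, so no scaling is used above.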

    @tf.function
    def _validation_step(self, batch):
        """Single validation step"""
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation for validation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Adjusted sequence lengths for CTC after patching
        adjusted_lens = tf.cast(
            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
            self.args['model']['patch_stride'] + 1,
            tf.int32
        )

        # Forward pass (inference mode only)
        logits = self.model(features, day_indices, None, False, 'inference', training=False)

        # CTC loss
        loss_input = {
            'labels': labels,
            'input_lengths': adjusted_lens,
            'label_lengths': phone_seq_lens
        }
        loss = tf.reduce_mean(self.ctc_loss(loss_input, logits))

        # Phoneme error rate (PER) via greedy decoding: argmax over classes, collapse
        # repeats, drop CTC blanks, then accumulate edit distance against the labels.
        # Note: this per-sequence loop relies on dynamic shapes, so in practice it may
        # need to run outside the TPU-compiled graph (e.g. on CPU).
        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)

        batch_edit_distance = tf.constant(0.0, dtype=tf.float32)
        for i in tf.range(tf.shape(logits)[0]):
            pred_seq = predicted_ids[i, :adjusted_lens[i]]

            # Collapse consecutive duplicates (keep a value only if it differs from its predecessor)
            prev = tf.concat([[-1], pred_seq[:-1]], axis=0)
            pred_seq = tf.boolean_mask(pred_seq, tf.not_equal(pred_seq, prev))

            # Remove blanks (blank_index=0)
            pred_seq = tf.boolean_mask(pred_seq, tf.not_equal(pred_seq, 0))

            true_seq = labels[i, :phone_seq_lens[i]]

            # Edit distance between predicted and reference phoneme sequences
            # (SparseTensor indices and dense_shape must be int64)
            pred_sparse = tf.SparseTensor(
                indices=tf.cast(tf.expand_dims(tf.range(tf.size(pred_seq)), 1), tf.int64),
                values=tf.cast(pred_seq, tf.int64),
                dense_shape=tf.cast([tf.size(pred_seq)], tf.int64)
            )
            true_sparse = tf.SparseTensor(
                indices=tf.cast(tf.expand_dims(tf.range(tf.size(true_seq)), 1), tf.int64),
                values=tf.cast(true_seq, tf.int64),
                dense_shape=tf.cast([tf.size(true_seq)], tf.int64)
            )
            edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
            batch_edit_distance += tf.reduce_sum(tf.cast(edit_dist, tf.float32))

        return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)
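
    # Worked example of the greedy CTC post-processing above (blank_index=0):
    #   argmax ids : [0, 7, 7, 0, 0, 31, 31, 31, 0, 7]
    #   collapse   : [0, 7, 0, 31, 0, 7]
    #   drop blanks: [7, 31, 7]
    # The edit distance is then computed against the reference phoneme ids.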

    def train(self) -> Dict[str, Any]:
        """Main training loop"""
        self.logger.info("Starting training loop...")

        # Create tf.data input pipelines
        train_dataset = create_input_fn(
            self.train_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=True
        )
        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )

        # Distribute datasets across TPU replicas
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)

        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        train_start_time = time.time()

        # Training loop
        step = 0
        for batch in train_dist_dataset:
            if step >= self.args['num_training_batches']:
                break

            start_time = time.time()

            # Decide whether the adversarial path is active for this step; passing a
            # plain bool (rather than the raw step counter) avoids retracing the
            # tf.function on every batch.
            use_full = self.adv_enabled and step >= self.adv_warmup_steps

            # Distributed training step
            per_replica_losses, per_replica_grad_norms = self.strategy.run(
                self._train_step, args=(batch, use_full)
            )

            # Reduce across replicas
            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)

            train_step_duration = time.time() - start_time
            train_losses.append(float(loss.numpy()))

            # Log training progress
            if step % self.args['batches_per_train_log'] == 0:
                self.logger.info(f'Train batch {step}: '
                                 f'loss: {float(loss.numpy()):.2f} '
                                 f'grad norm: {float(grad_norm.numpy()):.2f} '
                                 f'time: {train_step_duration:.3f}')

            # Validation step
            if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
                self.logger.info(f"Running validation after training batch: {step}")
                val_start_time = time.time()
                val_metrics = self._validate(val_dist_dataset)
                val_step_duration = time.time() - val_start_time

                self.logger.info(f'Val batch {step}: '
                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                                 f'time: {val_step_duration:.3f}')

                val_pers.append(val_metrics['avg_per'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Check for improvement
                new_best = False
                if val_metrics['avg_per'] < self.best_val_per:
                    self.logger.info(f"New best val PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                    self.best_val_per = val_metrics['avg_per']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif (val_metrics['avg_per'] == self.best_val_per and
                      val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best val loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)
                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)
                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optionally save a checkpoint at every validation step
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)

                # Early stopping
                if (self.args.get('early_stopping', False) and
                        val_steps_since_improvement >= self.args.get('early_stopping_val_steps', 20)):
                    self.logger.info(f'Validation PER has not improved in {self.args["early_stopping_val_steps"]} '
                                     f'validation steps. Stopping training early at batch: {step}')
                    break

            step += 1

        # Training completed
        training_duration = time.time() - train_start_time
        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model
        if self.args.get('save_final_model', False):
            self._save_checkpoint(f'final_checkpoint_batch_{step - 1}', step - 1)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_pers': val_pers,
            'val_metrics': val_results
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation over the entire validation dataset"""
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0

        for batch in val_dataset:
            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
                self.strategy.run(self._validation_step, args=(batch,))
            )

            # Reduce across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)

            total_loss += float(batch_loss.numpy())
            total_edit_distance += float(batch_edit_distance.numpy())
            total_seq_length += float(batch_seq_length.numpy())
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        avg_per = total_edit_distance / max(total_seq_length, 1e-6)

        return {
            'avg_loss': avg_loss,
            'avg_per': avg_per,
            'total_edit_distance': total_edit_distance,
            'total_seq_length': total_seq_length,
            'num_batches': num_batches
        }
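
    # The aggregate phoneme error rate reported above is
    #   PER = total_edit_distance / total_seq_length
    # e.g. 230 substitutions/insertions/deletions over 1,000 reference phonemes gives
    # avg_per = 0.23. avg_loss is the mean of the per-batch CTC losses.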

    def _save_checkpoint(self, name: str, step: int):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')

        # Save optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.save(checkpoint_path + '.optimizer')

        # Save training state
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss)
        }
        with open(checkpoint_path + '.state.json', 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint"""
        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')

        # Load optimizer state
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')

        # Load training state
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)
        self.best_val_per = state['best_val_per']
        self.best_val_loss = state['best_val_loss']

        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
        """
        Run inference on input features.

        Args:
            features: Input neural features [batch_size, time_steps, features]
            day_indices: Day indices [batch_size]
            n_time_steps: Number of valid time steps [batch_size]
            mode: 'inference' or 'full'

        Returns:
            Phoneme logits [batch_size, time_steps, n_classes]
        """
        # Apply data transformations (no augmentation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Run model inference
        logits = self.model(features, day_indices, None, False, mode, training=False)

        return logits
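

# Minimal usage sketch (assumes a YAML config, here called rnn_args.yaml, laid out
# with the keys referenced above, e.g. mode, output_dir, checkpoint_dir, dataset.*,
# model.*, lr_*; the filename is an assumption for illustration):
#
#   from omegaconf import OmegaConf
#   args = OmegaConf.load('rnn_args.yaml')
#   trainer = BrainToTextDecoderTrainerTF(args)
#   stats = trainer.train()
#   print(f"best val PER: {trainer.best_val_per:.4f}")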