#!/usr/bin/env python3
"""
TensorFlow Training Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
It provides the same functionality as the PyTorch version, but with TensorFlow
operations optimized for TPU performance.

Usage:
    python train_model_tf.py --config_path rnn_args.yaml

Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Access to the brain-to-text HDF5 dataset
"""

import argparse
import logging
import os
import sys

from omegaconf import OmegaConf

# Add the current directory to the Python path so local modules can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF


def setup_tpu_environment():
    """Set up TPU environment variables for optimal performance."""
    # TPU v5e-8 optimizations
    os.environ.setdefault('TPU_ML_PLATFORM', 'TensorFlow')  # Target the TensorFlow TPU runtime
    os.environ.setdefault('XLA_USE_BF16', '1')  # Enable bfloat16 for memory efficiency
    os.environ.setdefault('TF_XLA_FLAGS', '--tf_xla_auto_jit=2')  # Enable XLA JIT compilation

    # TPU memory optimizations
    os.environ.setdefault('TPU_MEGACORE', '1')  # Enable megacore mode for larger models
    os.environ.setdefault('LIBTPU_INIT_ARGS', '--xla_tpu_spmd_threshold_for_allgather_cse=10000')

    # Reduce TensorFlow log noise for cleaner output
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    print("TPU environment configured for v5e-8 optimizations")
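

# Note: the TPU distribution strategy itself is assumed to be created inside
# BrainToTextDecoderTrainerTF (its `strategy` attribute is used below). For
# reference, the standard TensorFlow sequence for initializing a TPU and
# building a strategy is sketched here using the stock tf.distribute APIs;
# this is not this project's code:
#
#   import tensorflow as tf
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)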


def validate_config(config):
    """Validate configuration for TensorFlow TPU training."""
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.dataset_dir',
        'dataset.sessions',
        'dataset.n_classes',
        'dataset.batch_size',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir',
    ]

    for field in required_fields:
        keys = field.split('.')
        value = config
        try:
            for key in keys:
                value = value[key]
        except KeyError:
            raise ValueError(f"Missing required configuration field: {field}")

    # TPU-specific validations
    if config.get('use_tpu', True):
        if config['dataset']['batch_size'] < 8:
            logging.warning("Small batch size may underutilize the TPU. Consider batch_size >= 32")

        if not config.get('use_amp', True):
            logging.warning("Mixed precision is disabled. Consider enabling it for better TPU performance")

    # Dataset validation
    dataset_dir = config['dataset']['dataset_dir']
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # Check that at least one session has a training data file
    session_found = False
    for session in config['dataset']['sessions']:
        train_path = os.path.join(dataset_dir, session, 'data_train.hdf5')
        if os.path.exists(train_path):
            session_found = True
            break

    if not session_found:
        raise FileNotFoundError("No valid session data files found in dataset directory")

    print("Configuration validation passed")
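

# For reference, a minimal rnn_args.yaml that satisfies validate_config() might
# look like the sketch below. All values are illustrative placeholders, not
# recommended settings:
#
#   model:
#     n_input_features: 512
#     n_units: 768
#   dataset:
#     dataset_dir: ./data
#     sessions: [session_01]
#     n_classes: 41
#     batch_size: 32
#     data_transforms: {}
#   num_training_batches: 120000
#   output_dir: ./output
#   checkpoint_dir: ./checkpoints
#   use_tpu: true
#   use_amp: true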


def main():
    """Main training function."""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration YAML file'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        help='Override output directory from config'
    )
    parser.add_argument(
        '--checkpoint_dir',
        default=None,
        help='Override checkpoint directory from config'
    )
    parser.add_argument(
        '--resume_from',
        default=None,
        help='Path to checkpoint to resume training from'
    )
    parser.add_argument(
        '--num_batches',
        type=int,
        default=None,
        help='Override number of training batches'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size'
    )
    parser.add_argument(
        '--mixed_precision',
        action='store_true',
        default=None,
        help='Enable mixed precision training (bfloat16)'
    )
    parser.add_argument(
        '--disable_mixed_precision',
        action='store_true',
        help='Disable mixed precision training'
    )
    parser.add_argument(
        '--validate_only',
        action='store_true',
        help='Only run validation, do not train'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )

    args = parser.parse_args()
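
    # Typical invocations (paths and values are illustrative):
    #   python train_model_tf.py --config_path rnn_args.yaml
    #   python train_model_tf.py --batch_size 64 --num_batches 10000
    #   python train_model_tf.py --validate_only --resume_from ./checkpoints/best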

    # Set up logging
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Set up the TPU environment
    setup_tpu_environment()

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    # Apply command-line overrides
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    if args.num_batches:
        config.num_training_batches = args.num_batches
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.mixed_precision:
        config.use_amp = True
    if args.disable_mixed_precision:
        config.use_amp = False

    # Validate configuration
    validate_config(config)

    try:
        # Initialize trainer
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        # Load checkpoint if specified; the trainer is expected to save
        # Keras-style '<prefix>.weights.h5' checkpoint files
        if args.resume_from:
            if os.path.exists(args.resume_from + '.weights.h5'):
                trainer.load_checkpoint(args.resume_from)
                print(f"Resumed training from checkpoint: {args.resume_from}")
            else:
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")

        if args.validate_only:
            print("Running validation only...")

            # Create the validation dataset
            from dataset_tf import create_input_fn
            val_dataset = create_input_fn(
                trainer.val_dataset_tf,
                trainer.args['dataset']['data_transforms'],
                training=False
            )
            val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)

            # Run validation
            val_metrics = trainer._validate(val_dist_dataset)

            print("Validation Results:")
            print(f"  Average Loss: {val_metrics['avg_loss']:.4f}")
            print(f"  Average PER: {val_metrics['avg_per']:.4f}")
            print(f"  Total Edit Distance: {val_metrics['total_edit_distance']}")
            print(f"  Total Sequence Length: {val_metrics['total_seq_length']}")

        else:
            # Start training
            print("Starting training...")
            train_stats = trainer.train()

            print("\nTraining completed successfully!")
            print(f"Best validation PER: {trainer.best_val_per:.5f}")
            print(f"Final training loss: {train_stats['train_losses'][-1]:.4f}")
            print(f"Final validation loss: {train_stats['val_losses'][-1]:.4f}")
            print(f"Total training batches: {len(train_stats['train_losses'])}")

            # Save final training statistics
            import pickle
            os.makedirs(config.output_dir, exist_ok=True)  # ensure the directory exists
            stats_path = os.path.join(config.output_dir, 'training_stats.pkl')
            with open(stats_path, 'wb') as f:
                pickle.dump(train_stats, f)
            print(f"Training statistics saved to: {stats_path}")
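
            # The saved pickle can be reloaded later for offline analysis, e.g.:
            #   with open(stats_path, 'rb') as f:
            #       stats = pickle.load(f)
            #   print(len(stats['train_losses']))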

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()