Files
b2txt25/model_training_nnn_tpu/train_model_tf.py
Zchen f8fb4d7133 Remove setup script, TPU memory monitor, and training model script
- Deleted `setup_tensorflow_tpu.sh` which was responsible for setting up the TensorFlow environment on TPU v5e-8.
- Removed `tpu_memory_monitor.py`, a tool for monitoring TPU memory usage during training.
- Eliminated `train_model.py`, the script for training the Brain-to-Text RNN model.
2025-10-20 11:05:03 +08:00

265 lines
8.2 KiB
Python

#!/usr/bin/env python3
"""
TensorFlow Training Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8
This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
It provides the same functionality as the PyTorch version but with TensorFlow
operations optimized for TPU performance.
Usage:
python train_model_tf.py --config_path rnn_args.yaml
Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Access to brain-to-text HDF5 dataset
"""
import argparse
import os
import sys
import logging
from omegaconf import OmegaConf
# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from trainer_tf import BrainToTextDecoderTrainerTF
def setup_tpu_environment():
    """Fill in default TPU / XLA / TensorFlow environment variables.

    Every variable is written with ``os.environ.setdefault``, so a value that is
    already exported in the environment wins over the default listed here.

    NOTE(review): ``trainer_tf`` is imported at module load, before ``main``
    calls this function. If that import pulls in TensorFlow, variables such as
    TF_CPP_MIN_LOG_LEVEL / TF_XLA_FLAGS may already have been read. Confirm
    that they still take effect.
    """
    tpu_env_defaults = (
        # NOTE(review): TPU_ML_PLATFORM='PyTorch/XLA' and XLA_USE_BF16 look like
        # PyTorch/XLA settings rather than TensorFlow ones; confirm they are
        # intended in this TensorFlow script.
        ('TPU_ML_PLATFORM', 'PyTorch/XLA'),
        ('XLA_USE_BF16', '1'),
        # Requests TensorFlow XLA auto-clustering (JIT compilation).
        ('TF_XLA_FLAGS', '--tf_xla_auto_jit=2'),
        # TPU runtime settings: megacore mode and a libtpu init flag.
        ('TPU_MEGACORE', '1'),
        ('LIBTPU_INIT_ARGS', '--xla_tpu_spmd_threshold_for_allgather_cse=10000'),
        # Hide TensorFlow C++ INFO and WARNING messages for cleaner output.
        ('TF_CPP_MIN_LOG_LEVEL', '2'),
    )
    for env_name, default_value in tpu_env_defaults:
        os.environ.setdefault(env_name, default_value)

    print("TPU environment configured for v5e-8 optimizations")
def validate_config(config):
    """Validate a training configuration before the trainer is built.

    Args:
        config: Nested mapping of training settings. ``main`` passes the
            OmegaConf object returned by ``OmegaConf.load``; a plain ``dict``
            works too, because only ``[]`` indexing and ``.get`` are used.

    Raises:
        ValueError: If a required (dotted) field is missing, or a parent node
            on its path is not indexable by key (e.g. ``model: null``).
        FileNotFoundError: If ``dataset.dataset_dir`` does not exist, or none
            of ``dataset.sessions`` has a ``data_train.hdf5`` file under it.
    """
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.sessions',
        'dataset.n_classes',
        # Read unconditionally below. Listing it here turns a bare KeyError
        # into the descriptive ValueError.
        'dataset.dataset_dir',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir',
    ]
    for field in required_fields:
        node = config
        try:
            for key in field.split('.'):
                node = node[key]
        except (KeyError, TypeError) as err:
            # TypeError: a parent node is null/scalar instead of a mapping.
            raise ValueError(f"Missing required configuration field: {field}") from err

    # TPU-specific validations (warnings only, never fatal)
    if config.get('use_tpu', True):
        # .get: batch_size is not required here, so a missing value simply
        # skips this advisory check instead of crashing validation.
        batch_size = config['dataset'].get('batch_size')
        # NOTE(review): the check uses 8 but the message suggests >= 32;
        # confirm which threshold is intended.
        if batch_size is not None and batch_size < 8:
            logging.warning("Small batch size may not utilize TPU efficiently. Consider batch_size >= 32")
        if not config.get('use_amp', True):
            logging.warning("Mixed precision disabled. Consider enabling for better TPU performance")

    # Dataset validation
    dataset_dir = config['dataset']['dataset_dir']
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # At least one listed session must provide a training file.
    session_found = any(
        os.path.exists(os.path.join(dataset_dir, session, 'data_train.hdf5'))
        for session in config['dataset']['sessions']
    )
    if not session_found:
        raise FileNotFoundError("No valid session data files found in dataset directory")

    print("Configuration validation passed")
def _positive_int(text):
    """argparse ``type=`` callable that accepts only integers greater than zero."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {text!r}")
    return value


def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration YAML file'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        help='Override output directory from config'
    )
    parser.add_argument(
        '--checkpoint_dir',
        default=None,
        help='Override checkpoint directory from config'
    )
    parser.add_argument(
        '--resume_from',
        default=None,
        help='Path to checkpoint to resume training from'
    )
    # _positive_int: 0 or negatives are rejected up front instead of being
    # silently ignored by a truthiness check.
    parser.add_argument(
        '--num_batches',
        type=_positive_int,
        default=None,
        help='Override number of training batches'
    )
    parser.add_argument(
        '--batch_size',
        type=_positive_int,
        default=None,
        help='Override batch size'
    )
    # The two precision flags contradict each other, so passing both is an error.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument(
        '--mixed_precision',
        action='store_true',
        default=None,
        help='Enable mixed precision training (bfloat16)'
    )
    precision.add_argument(
        '--disable_mixed_precision',
        action='store_true',
        help='Disable mixed precision training'
    )
    parser.add_argument(
        '--validate_only',
        action='store_true',
        help='Only run validation, do not train'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )
    return parser


def _apply_cli_overrides(config, args):
    """Copy the explicitly supplied command-line overrides onto ``config`` in place."""
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    if args.num_batches is not None:
        config.num_training_batches = args.num_batches
    if args.batch_size is not None:
        config.dataset.batch_size = args.batch_size
    if args.mixed_precision:
        config.use_amp = True
    if args.disable_mixed_precision:
        config.use_amp = False


def _run_validation_only(trainer):
    """Run one validation pass over ``trainer``'s validation data and print the metrics."""
    from dataset_tf import create_input_fn

    val_dataset = create_input_fn(
        trainer.val_dataset_tf,
        trainer.args['dataset']['data_transforms'],
        training=False
    )
    val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)
    val_metrics = trainer._validate(val_dist_dataset)
    print("Validation Results:")
    print(f" Average Loss: {val_metrics['avg_loss']:.4f}")
    print(f" Average PER: {val_metrics['avg_per']:.4f}")
    print(f" Total Edit Distance: {val_metrics['total_edit_distance']}")
    print(f" Total Sequence Length: {val_metrics['total_seq_length']}")


def _save_and_report_training(trainer, train_stats, output_dir):
    """Pickle ``train_stats`` into ``output_dir`` and print a summary of the run."""
    import pickle

    # Save first, so a problem while summarising (e.g. an empty loss history)
    # cannot lose the statistics of a run that actually finished.
    os.makedirs(output_dir, exist_ok=True)
    stats_path = os.path.join(output_dir, 'training_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(train_stats, f)

    train_losses = train_stats['train_losses']
    val_losses = train_stats['val_losses']
    print("\nTraining completed successfully!")
    print(f"Best validation PER: {trainer.best_val_per:.5f}")
    if train_losses:
        print(f"Final training loss: {train_losses[-1]:.4f}")
    if val_losses:
        print(f"Final validation loss: {val_losses[-1]:.4f}")
    print(f"Total training batches: {len(train_losses)}")
    print(f"Training statistics saved to: {stats_path}")


def main():
    """Parse CLI args, load, override and validate the config, then train or validate.

    Raises:
        FileNotFoundError: If the configuration file does not exist, or (from
            ``validate_config``) if the dataset files are missing.
        ValueError: From ``validate_config`` on a missing required field.

    Exits with status 1 if the run is interrupted or the trainer raises.
    """
    args = _build_arg_parser().parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # NOTE(review): trainer_tf is imported at module load, before this call.
    # If it imports TensorFlow, some variables (e.g. TF_CPP_MIN_LOG_LEVEL) may
    # already have been read. Confirm that they still take effect.
    setup_tpu_environment()

    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")
    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    _apply_cli_overrides(config, args)
    validate_config(config)

    try:
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        if args.resume_from:
            # Existence is checked on '<resume_from>.weights.h5', but
            # load_checkpoint receives the bare prefix.
            if not os.path.exists(args.resume_from + '.weights.h5'):
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")
            trainer.load_checkpoint(args.resume_from)
            print(f"Resumed training from checkpoint: {args.resume_from}")

        if args.validate_only:
            print("Running validation only...")
            _run_validation_only(trainer)
        else:
            print("Starting training...")
            train_stats = trainer.train()
            _save_and_report_training(trainer, train_stats, config.output_dir)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()