TPU support

model_training_nnn/TPU_SETUP_GUIDE.md (new file, 204 lines)
@@ -0,0 +1,204 @@
# TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

## Prerequisites

### 1. Install PyTorch XLA for TPU Support
```bash
# Install PyTorch XLA (adjust the version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or pin a specific version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install the Accelerate Library
```bash
pip install accelerate
```

### 3. Verify TPU Access
```bash
# Check that a TPU device is visible
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```

## Configuration Setup

### 1. Enable TPU in the Configuration File

Update your `rnn_args.yaml` file with TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                    # Enable TPU training
num_tpu_cores: 8                 # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1   # Gradient accumulation for a larger effective batch size
dataloader_num_workers: 0        # Must be 0 on TPU to avoid multiprocessing issues
use_amp: true                    # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                  # Per-core batch size (total = 8 cores × 8 = 64)
```

### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005        # May need to scale with the number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8      # Per-core batch size
  days_per_batch: 4  # Keep consistent across cores
```

## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check the TPU environment only
python launch_tpu_training.py --check_only

# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```

### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use the provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```

### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```

## Key TPU Features Implemented

### 1. Distributed Training Support
- Automatic data-parallel replication of the model across 8 TPU cores
- Synchronized gradient updates across all cores (see the sketch below)
- Proper checkpoint saving/loading for distributed training
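
A minimal, self-contained sketch of the pattern behind these features, using the Hugging Face Accelerate API. The toy linear model, optimizer, and random data below are placeholders for illustration only, not the project's actual trainer code:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Accelerate handles device placement, sharding batches across processes
# (TPU cores), and gradient synchronization.
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=1)

model = nn.Linear(512, 41)                       # stand-in for the RNN decoder
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
data = TensorDataset(torch.randn(64, 512), torch.randint(0, 41, (64,)))
loader = DataLoader(data, batch_size=8)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for features, labels in loader:
    optimizer.zero_grad()
    with accelerator.autocast():                 # bfloat16 on TPU
        loss = nn.functional.cross_entropy(model(features), labels)
    accelerator.backward(loss)                   # syncs gradients across cores
    optimizer.step()
```

Launching such a loop with `accelerate launch --num_processes 8` runs one process per TPU core; checkpointing then goes through `accelerator.unwrap_model(...)` and `accelerator.is_main_process`, as the trainer changes later in this commit show.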

### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization
- Faster training with maintained numerical stability
- Reduced memory usage

### 3. TPU-Optimized Data Loading
- Single-threaded data loading (num_workers=0) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing

### 4. Inference Support
- TPU-compatible inference methods added to the trainer class
- `inference()` and `inference_batch()` methods for production use (usage sketch below)
- Automatic mixed precision during inference
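
As a hedged usage sketch (not verbatim project code), the new methods are meant to be called on an already-constructed trainer with batches shaped like the training data; `trainer` and `batch` below are assumed to exist:

```python
# Assumes `trainer` is a BrainToTextDecoder_Trainer and `batch` comes from its
# validation DataLoader (keys follow the trainer code shown later in this commit).
logits, adjusted_lens = trainer.inference_batch(batch)   # whole batch at once

# Or pass the tensors explicitly:
logits = trainer.inference(
    batch['input_features'],
    batch['day_indicies'],
    batch['n_time_steps'],
)
```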

## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with a total batch size of 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization while tuning (see the Monitoring and Debugging section below)

### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes (worked example below)
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
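
For example, with the per-core batch size of 8 suggested above, the arithmetic works out as follows:

```python
# Worked example of the effective batch size formula above.
batch_size = 8                      # dataset.batch_size (per core)
num_cores = 8                       # num_tpu_cores
gradient_accumulation_steps = 1

effective = batch_size * num_cores * gradient_accumulation_steps
print(effective)                    # 64; accumulation of 4 would give 256
```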

### 3. Learning Rate Scaling
- Consider scaling the learning rate with the number of cores
- Linear scaling: `lr_new = lr_base × num_cores` (worked example below)
- May need warmup adjustment for large-batch training
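
A worked example of the linear scaling heuristic (treat it as a starting point and re-validate, since the optimal rate is task-dependent):

```python
# Linear learning-rate scaling for data-parallel training across TPU cores.
lr_base = 0.005     # single-core lr_max from rnn_args.yaml
num_cores = 8

lr_scaled = lr_base * num_cores
print(lr_scaled)    # 0.04 -- usually paired with a longer warmup
```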

### 4. Memory Management
- TPU v3-8: 128GB HBM total (16GB per core × 8 cores)
- TPU v4-8: 128GB HBM total (32GB per chip × 4 chips)
- Monitor memory usage to avoid OOM errors (see the snippet below)
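
To spot-check HBM usage from inside a running process, torch_xla exposes a per-device memory query; the exact fields returned vary between torch_xla releases, so treat this as a sketch:

```python
import torch_xla.core.xla_model as xm

# Ask the XLA runtime how much HBM the current device reports.
device = xm.xla_device()
print(xm.get_memory_info(device))   # free/total HBM figures for this core
```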

## Monitoring and Debugging

### 1. TPU Utilization
```bash
# Periodically print the visible TPU world size (a quick liveness check,
# not a true utilization monitor)
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
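
For more detail than the world-size check above, PyTorch/XLA provides a built-in metrics report covering compile counts, op execution times, and host-device transfers, which is useful for spotting recompilations and input-pipeline stalls:

```python
import torch_xla.debug.metrics as met

# Dump the XLA runtime metrics report (compile counts, execution and
# transfer timings). Call occasionally during or after training.
print(met.metrics_report())
```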

### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training

### 3. Common Issues and Solutions

**Issue**: "No TPU devices found"
- **Solution**: Verify that the TPU runtime is started and accessible

**Issue**: "DataLoader workers > 0 causes hangs"
- **Solution**: Set `dataloader_num_workers: 0` in the config

**Issue**: "Mixed precision errors"
- **Solution**: Ensure `use_amp: true` and that your PyTorch XLA build supports bfloat16

**Issue**: "Gradient synchronization timeouts"
- **Solution**: Check network connectivity between TPU hosts/workers
## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update the config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```

## Configuration Reference

### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```

### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8  # Per-core batch size
mixed_precision: bf16
```

This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.
model_training_nnn/accelerate_config_tpu.yaml (new file, 26 lines)
@@ -0,0 +1,26 @@
# Accelerate Configuration for TPU Training
# This file configures Accelerate library for 8-core TPU training
# with mixed precision (bfloat16) support

compute_environment: TPU
distributed_type: TPU
tpu_name: null   # Will use default TPU
tpu_zone: null   # Will use default zone

# Mixed precision settings (use bfloat16 for TPU)
mixed_precision: bf16

# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
num_processes: 8

# Enable TPU debugging (set to false for production)
tpu_use_cluster: false
tpu_use_sudo: false

# Logging settings
main_process_port: null
machine_rank: 0
num_machines: 1

# Enable automatic optimization
use_cpu: false
model_training_nnn/launch_tpu_training.py (new file, 126 lines)
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model

This script provides easy TPU training setup using Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.

Usage:
    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""

import argparse
import yaml
import os
import sys
from pathlib import Path


def update_config_for_tpu(config_path, num_cores=8):
    """
    Update configuration file to enable TPU training
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Required for TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute batch size across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))

        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save updated config
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path


def check_tpu_environment():
    """
    Check if TPU environment is properly set up
    """
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check if TPUs are available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        print(f"TPU world size: {xm.xrt_world_size()}")

        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False


def run_tpu_training(config_path, num_cores=8):
    """
    Launch TPU training using accelerate
    """
    # Check TPU environment
    if not check_tpu_environment():
        sys.exit(1)

    # Update config for TPU
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16

    # Launch training with accelerate
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"

    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Execute training
    os.system(cmd)


def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training
    run_tpu_training(args.config, args.num_cores)


if __name__ == "__main__":
    main()
@@ -18,6 +18,12 @@ gpu_number: '1' # GPU number to use for training, formatted as a string (e.g., '
 mode: train
 use_amp: true # whether to use automatic mixed precision (AMP) for training
+
+# TPU and distributed training settings
+use_tpu: false # whether to use TPU for training (set to true for TPU)
+num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
+gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
+dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
 
 output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
 checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
 init_from_checkpoint: false # whether to initialize the model from a checkpoint
@@ -182,11 +182,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
 
+        # Use TPU-optimized dataloader settings if TPU is enabled
+        num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
+
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = self.args['dataset']['loader_shuffle'],
-            num_workers = self.args['dataset']['num_dataloader_workers'],
+            num_workers = num_workers,
             pin_memory = True
         )
@@ -205,7 +208,7 @@ class BrainToTextDecoder_Trainer:
             self.val_dataset,
             batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = False,
-            num_workers = 0,
+            num_workers = 0, # Keep validation dataloader single-threaded for consistency
             pin_memory = True
         )
@@ -366,33 +369,35 @@ class BrainToTextDecoder_Trainer:
 
     def load_model_checkpoint(self, load_path):
         '''
-        Load a training checkpoint
+        Load a training checkpoint for distributed training
         '''
-        checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict
+        # Load checkpoint on CPU first to avoid OOM issues
+        checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
+
+        # Get unwrapped model for loading state dict
+        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
 
-        self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
         self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
         self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
 
-        self.model.to(self.device)
-
-        # Send optimizer params back to GPU
-        for state in self.optimizer.state.values():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor):
-                    state[k] = v.to(self.device)
+        # Device handling is managed by Accelerator, no need to manually move to device
 
         self.logger.info("Loaded model from checkpoint: " + load_path)
 
     def save_model_checkpoint(self, save_path, PER, loss):
         '''
-        Save a training checkpoint
+        Save a training checkpoint using Accelerator for distributed training
         '''
+        # Only save on main process to avoid conflicts
+        if self.accelerator.is_main_process:
+            # Unwrap model to get base model for saving
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+
             checkpoint = {
-                'model_state_dict' : self.model.state_dict(),
+                'model_state_dict' : unwrapped_model.state_dict(),
                 'optimizer_state_dict' : self.optimizer.state_dict(),
                 'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
                 'val_PER' : PER,
@@ -407,6 +412,9 @@ class BrainToTextDecoder_Trainer:
             with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
                 OmegaConf.save(config=self.args, f=f)
 
+        # Wait for all processes to complete checkpoint saving
+        self.accelerator.wait_for_everyone()
+
     def create_attention_mask(self, sequence_lengths):
 
         max_length = torch.max(sequence_lengths).item()
@@ -687,11 +695,12 @@ class BrainToTextDecoder_Trainer:
 
         for i, batch in enumerate(loader):
 
-            features = batch['input_features'].to(self.device)
-            labels = batch['seq_class_ids'].to(self.device)
-            n_time_steps = batch['n_time_steps'].to(self.device)
-            phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-            day_indicies = batch['day_indicies'].to(self.device)
+            # Data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            labels = batch['seq_class_ids']
+            n_time_steps = batch['n_time_steps']
+            phone_seq_lens = batch['phone_seq_lens']
+            day_indicies = batch['day_indicies']
 
             # Determine if we should perform validation on this batch
             day = day_indicies[0].item()
@@ -702,7 +711,7 @@ class BrainToTextDecoder_Trainer:
 
             with torch.no_grad():
 
-                with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
+                with self.accelerator.autocast():
                     features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
                     adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
@@ -769,3 +778,43 @@ class BrainToTextDecoder_Trainer:
         metrics['avg_loss'] = np.mean(metrics['losses'])
 
         return metrics
+
+    def inference(self, features, day_indicies, n_time_steps, mode='inference'):
+        '''
+        TPU-compatible inference method for generating phoneme logits
+        '''
+        self.model.eval()
+
+        with torch.no_grad():
+            with self.accelerator.autocast():
+                # Apply data transformations (no augmentation for inference)
+                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+                # Get phoneme predictions
+                logits = self.model(features, day_indicies, None, False, mode)
+
+        return logits
+
+    def inference_batch(self, batch, mode='inference'):
+        '''
+        TPU-compatible inference method for processing a full batch
+        '''
+        self.model.eval()
+
+        # Data is automatically moved to device by Accelerator
+        features = batch['input_features']
+        day_indicies = batch['day_indicies']
+        n_time_steps = batch['n_time_steps']
+
+        with torch.no_grad():
+            with self.accelerator.autocast():
+                # Apply data transformations (no augmentation for inference)
+                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+                # Calculate adjusted sequence lengths for CTC
+                adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
+
+                # Get phoneme predictions
+                logits = self.model(features, day_indicies, None, False, mode)
+
+        return logits, adjusted_lens
@@ -1,6 +1,25 @@
+import argparse
 from omegaconf import OmegaConf
 from rnn_trainer import BrainToTextDecoder_Trainer
 
-args = OmegaConf.load('rnn_args.yaml')
-trainer = BrainToTextDecoder_Trainer(args)
-metrics = trainer.train()
+def main():
+    parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
+    parser.add_argument('--config_path', default='rnn_args.yaml',
+                        help='Path to configuration file (default: rnn_args.yaml)')
+
+    args = parser.parse_args()
+
+    # Load configuration
+    config = OmegaConf.load(args.config_path)
+
+    # Initialize trainer
+    trainer = BrainToTextDecoder_Trainer(config)
+
+    # Start training
+    trainer.train()
+
+    print("Training completed successfully!")
+    print(f"Best validation PER: {trainer.best_val_PER:.5f}")
+
+if __name__ == "__main__":
+    main()