This commit is contained in:
Zchen
2025-10-15 14:26:11 +08:00
parent 11ee6ebc51
commit 56fa336af0
23 changed files with 5701 additions and 0 deletions

View File

@@ -0,0 +1,79 @@
# TPU-Optimized Brain-to-Text Model Training
This directory contains TPU-optimized code for training the brain-to-text RNN model with advanced adversarial training architecture. The model is based on "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), enhanced with three-model adversarial training and comprehensive XLA optimizations for efficient TPU training.
## Key Features
- **Triple-Model Adversarial Architecture**: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
- **XLA/TPU Optimizations**: Comprehensive optimizations for fast compilation and efficient TPU utilization
- **Mixed Precision Training**: bfloat16 support with full dtype consistency
- **Distributed Training**: 8-core TPU support with Accelerate library integration
- **687M Parameters**: Large-scale model with patch processing and day-specific adaptations
For detailed technical documentation, see [TPU_MODEL_SUMMARY.md](TPU_MODEL_SUMMARY.md).
## Setup
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. See the main [README.md](../README.md) file for more details on the included datasets and the proper `data` directory structure.
## TPU Training
### Triple-Model Adversarial Architecture
This implementation features an advanced three-model adversarial training system:
- **NoiseModel**: 2-layer GRU that estimates noise in neural data
- **CleanSpeechModel**: 3-layer GRU that processes denoised signals for speech recognition
- **NoisySpeechModel**: 2-layer GRU that processes noise signals for adversarial training
The architecture uses residual connections and gradient reversal layers (GRL) to improve robustness. All models include day-specific input layers (512x512 linear with softsign activation), patch processing (14 timesteps), and are optimized for XLA compilation on TPU.
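The snippet below is a minimal PyTorch sketch of how a gradient reversal layer and the residual denoising step can fit together; the tensor shapes and `lambd` value are illustrative assumptions, and the actual implementation lives in `rnn_model.py`.

```python
import torch

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -lambda in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=0.5):
    return GradientReversal.apply(x, lambd)

# Residual denoising: the clean branch sees the input minus the estimated noise;
# the adversarial branch sees the gradient-reversed noise estimate during training.
x = torch.randn(4, 20, 512)              # (batch, time, features) -- illustrative shapes
noise = torch.randn_like(x)              # stand-in for NoiseModel(x)
denoised_input = x - noise               # input to CleanSpeechModel
adversarial_input = grad_reverse(noise)  # input to NoisySpeechModel
```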
### Training Methods
#### Option 1: Direct Training
```bash
conda activate b2txt25
python train_model.py --config_path rnn_args.yaml
```
#### Option 2: Launcher Script (Recommended)
```bash
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
```
#### Option 3: Accelerate
```bash
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
```
The model trains for 120,000 mini-batches with mixed precision (bfloat16) and distributed training across 8 TPU cores. Expected training time varies based on TPU type and configuration. All hyperparameters are specified in [`rnn_args.yaml`](rnn_args.yaml).
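The training loop itself is managed through the Accelerate library (see `rnn_trainer.py`). The sketch below shows the general shape of such a loop with dummy stand-ins for the model, data, and loss, so it illustrates the API rather than this project's actual trainer; it also assumes an environment where bf16 mixed precision is available.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Dummy stand-ins so the loop runs end to end.
model = nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 512), torch.randint(0, 41, (64,))), batch_size=8)

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for features, labels in loader:
    with accelerator.accumulate(model):
        loss = nn.functional.cross_entropy(model(features), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```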
## Model Configuration
### Key Configuration Files
- **`rnn_args.yaml`**: Main training configuration with adversarial training settings
- **`accelerate_config_tpu.yaml`**: Accelerate library configuration for TPU
- **`launch_tpu_training.py`**: Convenient TPU training launcher
### Adversarial Training Settings
```yaml
adversarial:
enabled: true
grl_lambda: 0.5 # Gradient Reversal Layer strength
noisy_loss_weight: 0.2 # Weight for noisy branch CTC loss
noise_l2_weight: 0.0 # L2 regularization on noise output
warmup_steps: 0 # Steps before enabling adversarial training
```
### TPU-Specific Settings
```yaml
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true # bfloat16 mixed precision
batch_size: 32 # Per-core batch size
num_dataloader_workers: 0 # Required for TPU
```
## Evaluation
Model evaluation using the trained TripleGRUDecoder requires the language model pipeline. Please refer to the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are currently being adapted for TPU compatibility.

View File

@@ -0,0 +1,183 @@
# TPU-Optimized Brain-to-Text Model Code Summary
## Project Overview
This directory contains Brain-to-Text RNN model code optimized specifically for TPU training, based on the paper "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" published in the *New England Journal of Medicine* (2024). The model converts neural signals recorded from the speech motor cortex into text, using an RNN model together with an n-gram language model.
## Core Architecture Improvements
### Three-Model Adversarial Training Architecture (TripleGRUDecoder)
Replacing the original single GRU model, the new architecture consists of three cooperating sub-models:
1. **NoiseModel** (2-layer GRU)
   - Estimates the noise in the neural data
   - Input dimension: 512 → output dimension: same as the input
   - Role: separates the noise component from the raw signal
2. **CleanSpeechModel** (3-layer GRU + classification layer)
   - Performs speech recognition on the denoised signal
   - Includes day-specific input layers
   - Outputs logits over 41 phoneme classes
3. **NoisySpeechModel** (2-layer GRU + classification layer)
   - Performs speech recognition directly on the noise signal
   - Used in adversarial training to improve the robustness of the NoiseModel
   - Outputs logits over 41 phoneme classes
### Adversarial Training Mechanism
- **Residual connection**: `denoised_input = x_processed - noise_output`
- **Gradient Reversal Layer (GRL)**: gradient reversal is applied to the noise output during training
- **Multi-objective loss**: combines the CTC losses of the clean and noisy branches (see the sketch below)
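A rough sketch of how these three terms can be combined into a single training loss (shapes and weights here are illustrative; the exact weighting and masking live in `rnn_trainer.py`):

```python
import torch
import torch.nn.functional as F

T, B, C = 50, 4, 41                                   # time steps, batch, phoneme classes
clean_logits = torch.randn(T, B, C).log_softmax(-1)   # from CleanSpeechModel
noisy_logits = torch.randn(T, B, C).log_softmax(-1)   # from NoisySpeechModel (after GRL)
noise_output = torch.randn(B, T, 512)                 # from NoiseModel

targets = torch.randint(1, C, (B, 10))
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), 10, dtype=torch.long)

clean_ctc = F.ctc_loss(clean_logits, targets, input_lengths, target_lengths, blank=0)
noisy_ctc = F.ctc_loss(noisy_logits, targets, input_lengths, target_lengths, blank=0)

noisy_loss_weight, noise_l2_weight = 0.2, 0.0         # values from rnn_args.yaml
loss = clean_ctc + noisy_loss_weight * noisy_ctc + noise_l2_weight * noise_output.pow(2).mean()
```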
## TPU/XLA Optimization Features
### 1. XLA-Friendly Operation Design
**Static tensor operations instead of dynamic ones**:
```python
# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)
```
**XLA primitive operations**:
```python
# Use batch matrix multiplication (bmm) instead of einsum
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
```
### 2. Dtype Consistency for Mixed Precision Training
**Comprehensive dtype consistency handling** (illustrated below):
- dtype conversion in basic operations
- dtype preservation during patch processing
- dtype matching in the adversarial residual connection
- dtype handling in the gradient reversal layer
- dtype consistency when initializing hidden states
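An illustrative example of the kind of explicit casting this requires (not the literal code from `rnn_model.py`):

```python
import torch

x_processed = torch.randn(4, 20, 512, dtype=torch.bfloat16)
noise_output = torch.randn(4, 20, 512, dtype=torch.float32)   # e.g. produced outside autocast

# Align dtypes before the residual connection and the hidden-state init,
# so XLA never has to compile mixed-dtype arithmetic.
denoised_input = x_processed - noise_output.to(x_processed.dtype)
h0 = torch.zeros(2, 4, 768, dtype=x_processed.dtype)          # GRU hidden state matches input dtype
```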
### 3. Memory and Compilation Optimizations
- **Autocast disabled**: automatic mixed precision is disabled inside the GRU operations to avoid dtype conflicts
- **Static shapes**: dynamic batch-size allocation is avoided
- **Tuple returns**: tuples are returned instead of dictionaries for better XLA compilation performance
## Key File Structure
### Core Training Files
- **`rnn_model.py`**: Complete implementation of TripleGRUDecoder and the three sub-models, with XLA optimizations
- **`rnn_trainer.py`**: TPU trainer integrated with the Accelerate library, supporting distributed training
- **`train_model.py`**: Minimal training launch script
- **`rnn_args.yaml`**: TPU training configuration file
### TPU-Specific Files
- **`accelerate_config_tpu.yaml`**: Accelerate configuration for TPU
- **`launch_tpu_training.py`**: Convenience launcher for TPU training
- **`TPU_SETUP_GUIDE.md`**: TPU environment setup guide
### Supporting Files
- **`dataset.py`**: Dataset loading and batching
- **`data_augmentations.py`**: Data augmentation utilities
- **`evaluate_model_helpers.py`**: Evaluation helper functions
## Training Configuration Highlights
### TPU-Specific Settings
```yaml
# TPU distributed training settings
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true # bfloat16 mixed precision
# Optimized batch configuration
batch_size: 32 # per-TPU-core batch size
num_dataloader_workers: 0 # set to 0 on TPU to avoid multiprocessing issues
```
### Adversarial Training Configuration
```yaml
adversarial:
  enabled: true
  grl_lambda: 0.5 # gradient reversal strength
  noisy_loss_weight: 0.2 # weight of the noisy-branch CTC loss
  noise_l2_weight: 0.0 # L2 regularization on the noise output
  warmup_steps: 0 # warmup steps before adversarial training starts
```
## Model Scale
- **Total parameters**: ~687M
- **Neural input**: 512 features (2 features per electrode × 256 electrodes)
- **GRU hidden units**: 768 per layer
- **Output classes**: 41 phonemes
- **Patch processing**: patches of 14 timesteps with a stride of 4
## Data Flow
1. **Input**: 512-dimensional neural features at 20 ms resolution
2. **Day-specific transform**: a per-day linear transformation with softsign activation
3. **Patch processing**: concatenates 14 timesteps into a larger input vector (see the sketch after this list)
4. **Three-model processing**:
   - NoiseModel estimates the noise
   - CleanSpeechModel processes the denoised signal
   - NoisySpeechModel processes the noise signal (training only)
5. **Output**: CTC-compatible phoneme logits
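An illustrative version of the patching step (the real implementation in `rnn_model.py` may differ in details such as padding and masking):

```python
import torch

patch_size, patch_stride = 14, 4
x = torch.randn(2, 100, 512)                        # (batch, timesteps, features)
patches = x.unfold(1, patch_size, patch_stride)     # (batch, n_patches, features, patch_size)
patches = patches.permute(0, 1, 3, 2).reshape(x.shape[0], -1, patch_size * x.shape[2])
print(patches.shape)                                # torch.Size([2, 22, 7168])
```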
## Training Flow
### Inference mode (`mode='inference'`):
- Uses only the NoiseModel + CleanSpeechModel
- Computes: `clean_logits = CleanSpeechModel(x - NoiseModel(x))`
### Full mode (`mode='full'`, used during training):
- Uses all three models
- Adversarial training with gradient reversal
- Multi-objective CTC loss
## Performance Characteristics
- **Compilation**: XLA optimizations give faster TPU compilation
- **Memory efficiency**: bfloat16 mixed precision reduces memory usage
- **Distributed training**: supports parallel training across 8 TPU cores
- **Data augmentation**: Gaussian smoothing, white noise, temporal jitter, and more
## Usage
### Basic training
```bash
python train_model.py --config_path rnn_args.yaml
```
### Using the launcher script
```bash
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
```
### Using Accelerate
```bash
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
```
## Compatibility with the Original Model
- Keeps the same mathematical operations and model architecture
- Preserves all original interfaces
- Supports both the 'inference' and 'full' modes
- Backward compatible with the existing training scripts
## Technical Contributions
1. **Three-model adversarial architecture**: a novel approach to noise estimation and denoising
2. **XLA optimizations**: comprehensive TPU compilation optimizations
3. **Mixed-precision consistency**: resolves the dtype conflicts that arise in the complex adversarial training setup
4. **Distributed training integration**: seamless multi-core TPU support
This TPU-optimized version preserves the accuracy of the original model while substantially improving training efficiency and scalability, making it especially well suited to training on large-scale neural decoding tasks.

View File

@@ -0,0 +1,204 @@
# TPU Training Setup Guide for Brain-to-Text RNN
This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
## Prerequisites
### 1. Install PyTorch XLA for TPU Support
```bash
# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
# Or for specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```
### 2. Install Accelerate Library
```bash
pip install accelerate
```
### 3. Verify TPU Access
```bash
# Check if TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```
## Configuration Setup
### 1. Enable TPU in Configuration File
Update your `rnn_args.yaml` file with TPU settings:
```yaml
# TPU and distributed training settings
use_tpu: true # Enable TPU training
num_tpu_cores: 8 # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1 # Gradient accumulation for large effective batch size
dataloader_num_workers: 0 # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true # Enable mixed precision (bfloat16)
# Adjust batch size for multi-core TPU
dataset:
batch_size: 8 # Per-core batch size (total = 8 cores × 8 = 64)
```
### 2. TPU-Optimized Hyperparameters
Recommended adjustments for TPU training:
```yaml
# Learning rate scaling for distributed training
lr_max: 0.005 # May need to scale with number of cores
lr_max_day: 0.005
# Batch size considerations
dataset:
batch_size: 8 # Per-core batch size
days_per_batch: 4 # Keep consistent across cores
```
## Training Launch Options
### Method 1: Using the TPU Launch Script (Recommended)
```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
# Check TPU environment only
python launch_tpu_training.py --check_only
# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```
### Method 2: Direct Accelerate Launch
```bash
# Configure accelerate (one-time setup)
accelerate config
# Or use provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```
### Method 3: Manual XLA Launch (Advanced)
```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1
# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```
## Key TPU Features Implemented
### 1. Distributed Training Support
- Automatic model parallelization across 8 TPU cores
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training
### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization
- Faster training with maintained numerical stability
- Reduced memory usage
### 3. TPU-Optimized Data Loading
- Single-threaded data loading (num_workers=0) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing
### 4. Inference Support
- TPU-compatible inference methods added to trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference
## Performance Optimization Tips
### 1. Batch Size Tuning
- Start with total batch size = 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization with `top` command
### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps (see the worked example below)
### 3. Learning Rate Scaling
- Consider scaling learning rate with number of cores
- Linear scaling: `lr_new = lr_base × num_cores`
- May need warmup adjustment for large batch training
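A worked example of the batch-size and learning-rate rules above, assuming the per-core batch size of 8 recommended earlier and the base learning rate of 0.005 from `rnn_args.yaml`:

```python
batch_size, num_cores, grad_accum = 8, 8, 1
lr_base = 0.005

effective_batch_size = batch_size * num_cores * grad_accum   # 64 samples per optimizer step
lr_scaled = lr_base * num_cores                              # 0.04 under naive linear scaling

print(effective_batch_size, lr_scaled)
```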
### 4. Memory Management
- TPU v3-8: 128GB HBM memory total
- TPU v4-8: 512GB HBM memory total
- Monitor memory usage to avoid OOM errors
## Monitoring and Debugging
### 1. TPU Utilization
```bash
# Monitor TPU usage
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training
### 3. Common Issues and Solutions
**Issue**: "No TPU devices found"
- **Solution**: Verify TPU runtime is started and accessible
**Issue**: "DataLoader workers > 0 causes hangs"
- **Solution**: Set `dataloader_num_workers: 0` in config
**Issue**: "Mixed precision errors"
- **Solution**: Ensure `use_amp: true` and PyTorch XLA supports bfloat16
**Issue**: "Gradient synchronization timeouts"
- **Solution**: Check network connectivity between TPU cores
## Example Training Command
```bash
# Complete TPU training example
cd model_training_nnn
# 1. Update config for TPU
vim rnn_args.yaml # Set use_tpu: true, num_tpu_cores: 8
# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```
## Configuration Reference
### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```
### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
batch_size: 8 # Per-core batch size
mixed_precision: bf16
```
This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.

View File

@@ -0,0 +1,26 @@
# Accelerate Configuration for TPU Training
# This file configures Accelerate library for 8-core TPU training
# with mixed precision (bfloat16) support
compute_environment: TPU
distributed_type: TPU
tpu_name: null # Will use default TPU
tpu_zone: null # Will use default zone
# Mixed precision settings (use bfloat16 for TPU)
mixed_precision: bf16
# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
num_processes: 8
# Enable TPU debugging (set to false for production)
tpu_use_cluster: false
tpu_use_sudo: false
# Logging settings
main_process_port: null
machine_rank: 0
num_machines: 1
# Enable automatic optimization
use_cpu: false

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script
Checks whether XLA compilation is correctly using multiple CPU cores.
"""
import os
import psutil
import time
import threading
from concurrent.futures import ThreadPoolExecutor
def set_xla_environment():
    """Set the XLA environment variables for multi-threaded compilation."""
    cpu_count = os.cpu_count()
    # Set XLA environment variables
    os.environ['XLA_FLAGS'] = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
    print("🔧 XLA environment variables set:")
    print(f"   CPU cores: {cpu_count}")
    print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)
def monitor_cpu_usage(duration=30, interval=1):
    """Monitor per-core CPU usage for `duration` seconds."""
    print(f"🔍 Monitoring CPU usage for {duration} seconds...")
    cpu_usage_data = []
    start_time = time.time()
    while time.time() - start_time < duration:
        # Get the utilization of each CPU core
        cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
        cpu_usage_data.append(cpu_percent_per_core)
        # Live display
        active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
        print(f"Active cores: {active_cores}/{len(cpu_percent_per_core)}, "
              f"average usage: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
              end='\r')
    print()  # newline
    # Analyze the results
    if cpu_usage_data:
        avg_usage_per_core = [
            sum(core_data) / len(cpu_usage_data)
            for core_data in zip(*cpu_usage_data)
        ]
        active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
        max_usage = max(avg_usage_per_core)
        print("📊 CPU usage analysis:")
        print(f"   Active CPU cores: {active_cores}/{len(avg_usage_per_core)}")
        print(f"   Highest average usage: {max_usage:.1f}%")
        for i, usage in enumerate(avg_usage_per_core):
            status = "🟢" if usage > 10 else "🔴" if usage > 5 else ""
            print(f"   CPU core {i}: {usage:.1f}% {status}")
        return active_cores > 1
    return False
def test_xla_compilation():
    """Run a small XLA compilation test while monitoring CPU usage."""
    print("🚀 Starting XLA compilation test...")
    try:
        import torch
        import torch_xla.core.xla_model as xm
        print("✅ PyTorch XLA imported successfully")
        print(f"   XLA device: {xm.xla_device()}")
        print(f"   XLA world size: {xm.xrt_world_size()}")
        # Build a simple computation graph to compile
        device = xm.xla_device()
        print("🔄 Building the test computation graph...")
        x = torch.randn(100, 100, device=device)
        y = torch.randn(100, 100, device=device)
        print("🔄 Running matrix operations (this will trigger XLA compilation)...")
        # Start CPU monitoring in a background thread
        monitor_thread = threading.Thread(
            target=lambda: monitor_cpu_usage(20, 0.5),
            daemon=True
        )
        monitor_thread.start()
        # Run the computation, triggering compilation
        for i in range(10):
            z = torch.matmul(x, y)
            z = torch.relu(z)
            z = torch.matmul(z, x.T)
            if i == 0:
                print("🔄 First computation finished (XLA compilation should be in progress)...")
            elif i == 5:
                print("🔄 Sixth computation finished...")
        # Wait for the monitor to finish
        monitor_thread.join(timeout=25)
        print("✅ XLA test complete")
        return True
    except ImportError as e:
        print(f"❌ Failed to import PyTorch XLA: {e}")
        return False
    except Exception as e:
        print(f"❌ XLA test failed: {e}")
        return False
def main():
    print("=" * 60)
    print("🧪 XLA Multi-threaded Compilation Diagnostic Tool")
    print("=" * 60)
    # 1. Set up the environment
    set_xla_environment()
    # 2. Test XLA compilation while monitoring the CPU
    success = test_xla_compilation()
    print("=" * 60)
    if success:
        print("✅ Diagnostics complete")
        print("💡 If several CPU cores were active, XLA multi-threading is working correctly")
        print("💡 If only 1-2 cores were active, other optimizations may be needed")
    else:
        print("❌ Diagnostics failed")
        print("💡 Check the PyTorch XLA installation and the TPU environment")
    print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import gaussian_filter1d
def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
"""
Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis.
Args:
inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N.
Assumed to already be on the correct device (e.g., GPU).
        device (str): Device to use for computation (e.g., 'cuda' or 'cpu').
        smooth_kernel_std (float): Standard deviation of the Gaussian smoothing kernel.
        smooth_kernel_size (int): Length of the smoothing kernel in samples, before near-zero taps are trimmed.
        padding (str): Padding mode, either 'same' or 'valid'.
Returns:
smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N.
"""
# Get Gaussian kernel
inp = np.zeros(smooth_kernel_size, dtype=np.float32)
inp[smooth_kernel_size // 2] = 1
gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
validIdx = np.argwhere(gaussKernel > 0.01)
gaussKernel = gaussKernel[validIdx]
gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))
# Convert to tensor
gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]
# Prepare convolution
B, T, C = inputs.shape
inputs = inputs.permute(0, 2, 1) # [B, C, T]
gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]
# Perform convolution
smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
return smoothed.permute(0, 2, 1) # [B, T, C]
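# Example usage (illustrative, not part of the original module): smooth a random
# (batch, time, features) tensor on CPU; with the default padding='same' the time axis is preserved.
if __name__ == "__main__":
    example = torch.randn(2, 200, 512)
    smoothed = gauss_smooth(example, device='cpu')
    print(smoothed.shape)  # expected: torch.Size([2, 200, 512])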

View File

@@ -0,0 +1,336 @@
import os
import torch
from torch.utils.data import Dataset
import h5py
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import math
class BrainToTextDataset(Dataset):
'''
Dataset for brain-to-text data
Returns an entire batch of data instead of a single example
'''
def __init__(
self,
trial_indicies,
n_batches,
split = 'train',
batch_size = 64,
days_per_batch = 1,
random_seed = -1,
must_include_days = None,
feature_subset = None
):
'''
trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
n_batches: (int) - number of random training batches to create
split: (string) - string specifying if this is a train or test dataset
batch_size: (int) - number of examples to include in batch returned from __getitem_()
days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
        to individual day layers in the GRU are not excessively noisy. Validation data will always have 1 day per batch
random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
must_include_days ([int]) - list of days that must be included in every batch
        feature_subset ([int]) - list of neural feature indices that should be the only features included in the neural data
'''
# Set random seed for reproducibility
if random_seed != -1:
np.random.seed(random_seed)
torch.manual_seed(random_seed)
self.split = split
# Ensure the split is valid
if self.split not in ['train', 'test']:
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
self.days_per_batch = days_per_batch
self.batch_size = batch_size
self.n_batches = n_batches
self.days = {}
self.n_trials = 0
self.trial_indicies = trial_indicies
self.n_days = len(trial_indicies.keys())
self.feature_subset = feature_subset
# Calculate total number of trials in the dataset
for d in trial_indicies:
self.n_trials += len(trial_indicies[d]['trials'])
if must_include_days is not None and len(must_include_days) > days_per_batch:
            raise ValueError(f'The number of days in must_include_days must be less than or equal to days_per_batch. Received {must_include_days} with days_per_batch {days_per_batch}')
if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
            raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} but only {self.n_days} days in the dataset')
if must_include_days is not None:
            # Map must_include_days to correct indices if they are negative
for i, d in enumerate(must_include_days):
if d < 0:
must_include_days[i] = self.n_days + d
self.must_include_days = must_include_days
# Ensure that the days_per_batch is not greater than the number of days in the dataset. Raise error
if self.split == 'train' and self.days_per_batch > self.n_days:
raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
if self.split == 'train':
self.batch_index = self.create_batch_index_train()
else:
self.batch_index = self.create_batch_index_test()
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self):
'''
How many batches are in this dataset.
        Because training data is sampled randomly, there is no fixed dataset length;
        however, this method is required for DataLoader to work
'''
return self.n_batches if self.n_batches is not None else 0
def __getitem__(self, idx):
'''
Gets an entire batch of data from the dataset, not just a single item
'''
batch = {
'input_features' : [],
'seq_class_ids' : [],
'n_time_steps' : [],
'phone_seq_lens' : [],
'day_indicies' : [],
'transcriptions' : [],
'block_nums' : [],
'trial_nums' : [],
}
index = self.batch_index[idx]
# Iterate through each day in the index
for d in index.keys():
# Open the hdf5 file for that day
with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
# For each trial in the selected trials in that day
for t in index[d]:
try:
g = f[f'trial_{t:04d}']
                        # Restrict to the requested feature subset if necessary
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
if self.feature_subset:
input_features = input_features[:,self.feature_subset]
batch['input_features'].append(input_features)
batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
batch['block_nums'].append(g.attrs['block_num'])
batch['trial_nums'].append(g.attrs['trial_num'])
except Exception as e:
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
continue
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
batch['day_indicies'] = torch.tensor(batch['day_indicies'])
batch['transcriptions'] = torch.stack(batch['transcriptions'])
batch['block_nums'] = torch.tensor(batch['block_nums'])
batch['trial_nums'] = torch.tensor(batch['trial_nums'])
return batch
def create_batch_index_train(self):
'''
Create an index that maps a batch_number to batch_size number of trials
Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
(or as even as possible if batch_size is not divisible by days_per_batch)
'''
batch_index = {}
# Precompute the days that are not in must_include_days
if self.must_include_days is not None:
non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
for batch_idx in range(self.n_batches):
batch = {}
# Which days will be used for this batch. Picked randomly without replacement
# TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
# If must_include_days is not empty, we will use those days and then randomly sample the rest
if self.must_include_days is not None and len(self.must_include_days) > 0:
days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
# Otherwise we will select random days without replacement
else:
days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
# How many trials will be sampled from each day
num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
for d in days:
                # Trials are sampled with replacement, so it is fine if a day has fewer than (batch_size / days_per_batch) trials
trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
batch[d] = trial_idxs
# Remove extra trials
extra_trials = (num_trials * len(days)) - self.batch_size
# While we still have extra trials, remove the last trial from a random day
while extra_trials > 0:
d = np.random.choice(days)
batch[d] = batch[d][:-1]
extra_trials -= 1
batch_index[batch_idx] = batch
return batch_index
def create_batch_index_test(self):
'''
Create an index that is all validation/testing data in batches of up to self.batch_size
If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size
        This index ensures that every trial in the validation set is seen exactly once
'''
batch_index = {}
batch_idx = 0
for d in self.trial_indicies.keys():
# Calculate how many batches we need for this day
num_trials = len(self.trial_indicies[d]['trials'])
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
# Create batches for this day
for i in range(num_batches):
start_idx = i * self.batch_size
end_idx = min((i + 1) * self.batch_size, num_trials)
# Get the trial indices for this batch
batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]
# Add to batch_index
batch_index[batch_idx] = {d : batch_trials}
batch_idx += 1
return batch_index
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
'''
Split data from file_paths into train and test splits
Returns two dictionaries that detail which trials in each day will be a part of that split:
Example:
{
0: trials[1,2,3], session_path: 'path'
1: trials[2,5,6], session_path: 'path'
}
Args:
file_paths (list): List of file paths to the hdf5 files containing the data
test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
seed (int): Seed for reproducibility. If set to -1, the split will be random
bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
{
'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
...
}
'''
    # Set seed for reproducibility
if seed != -1:
np.random.seed(seed)
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
# Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = []
if os.path.exists(path):
with h5py.File(path, 'r') as f:
num_trials = len(list(f.keys()))
for t in range(num_trials):
key = f'trial_{t:04d}'
block_num = f[key].attrs['block_num']
trial_num = f[key].attrs['trial_num']
if (
bad_trials_dict is not None
and session in bad_trials_dict
and str(block_num) in bad_trials_dict[session]
and trial_num in bad_trials_dict[session][str(block_num)]
):
# print(f'Bad trial: {session}_{block_num}_{trial_num}')
continue
good_trial_indices.append(t)
trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}
# Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
train_trials = {}
test_trials = {}
for day in trials_per_day.keys():
num_trials = trials_per_day[day]['num_trials']
# Generate all trial indices for this day (assuming 0-indexed)
all_trial_indices = trials_per_day[day]['trial_indices']
# If test_percentage is 0 or 1, we can just assign all trials to either train or test
if test_percentage == 0:
train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
continue
elif test_percentage == 1:
train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
continue
else:
# Calculate how many trials to use for testing
num_test = max(1, int(num_trials * test_percentage))
# Randomly select indices for testing
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
# Remaining indices go to training
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
# Store the split indices
train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}
return train_trials, test_trials
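# Example usage (illustrative; the hdf5 path below is a placeholder and the call only works
# once the Dryad dataset has been downloaded into the expected directory structure):
if __name__ == "__main__":
    example_paths = ['../data/hdf5_data_final/t15.2023.08.11/data_train.hdf5']
    train_idx, test_idx = train_test_split_indicies(example_paths, test_percentage=0.1, seed=1)
    train_dataset = BrainToTextDataset(train_idx, n_batches=100, split='train',
                                       batch_size=64, days_per_batch=1, random_seed=1)
    print(len(train_dataset))  # number of randomly assembled training batches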

View File

@@ -0,0 +1,304 @@
import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from rnn_model import GRUDecoder
from evaluate_model_helpers import *
# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
help='Path to the pretrained model directory (relative to the current working directory).')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set. '
'If "test", ground truth is not available.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
parser.add_argument('--gpu_number', type=int, default=-1,
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
# paths to model and data directories
# Note: these paths are relative to the current working directory
model_path = args.model_path
data_dir = args.data_dir
# define evaluation type
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
# load csv file
b2txt_csv_df = pd.read_csv(args.csv_path)
# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
# set up gpu device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
if gpu_number >= torch.cuda.device_count():
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
device = f'cuda:{gpu_number}'
device = torch.device(device)
print(f'Using {device} for model inference.')
else:
if gpu_number >= 0:
print(f'GPU number {gpu_number} requested but not available.')
print('Using CPU for model inference.')
device = torch.device('cpu')
# define model
model = GRUDecoder(
neural_dim = model_args['model']['n_input_features'],
n_units = model_args['model']['n_units'],
n_days = len(model_args['dataset']['sessions']),
n_classes = model_args['dataset']['n_classes'],
rnn_dropout = model_args['model']['rnn_dropout'],
input_dropout = model_args['model']['input_network']['input_layer_dropout'],
n_layers = model_args['model']['n_layers'],
patch_size = model_args['model']['patch_size'],
patch_stride = model_args['model']['patch_stride'],
)
# load model weights
checkpoint = torch.load(
os.path.join(model_path, 'checkpoint/best_checkpoint'),
map_location=device,
weights_only=False,
)
# rename keys so they don't start with "module." (saved with DataParallel) or "_orig_mod." (saved after torch.compile)
for key in list(checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    checkpoint['model_state_dict'][new_key] = checkpoint['model_state_dict'].pop(key)
model.load_state_dict(checkpoint['model_state_dict'])
# add model to device
model.to(device)
# set model to eval mode
model.eval()
# load data for each session
test_data = {}
total_test_trials = 0
for session in model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# put neural data through the pretrained model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
for session, data in test_data.items():
data['logits'] = []
data['pred_seq'] = []
input_layer = model_args['dataset']['sessions'].index(session)
for trial in range(len(data['neural_features'])):
# get neural input for the trial
neural_input = data['neural_features'][trial]
# add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# convert to torch tensor
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
# run decoding step
logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
data['logits'].append(logits)
pbar.update(1)
pbar.close()
# convert logits to phoneme sequences and print them out
for session, data in test_data.items():
data['pred_seq'] = []
for trial in range(len(data['logits'])):
logits = data['logits'][trial][0]
pred_seq = np.argmax(logits, axis=-1)
# remove blanks (0)
pred_seq = [int(p) for p in pred_seq if p != 0]
# remove consecutive duplicates
pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
# convert to phonemes
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
# add to data
data['pred_seq'].append(pred_seq)
# print out the predicted sequences
block_num = data['block_num'][trial]
trial_num = data['trial_num'][trial]
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
if eval_type == 'val':
sentence_label = data['sentence_label'][trial]
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
print(f'Sentence label: {sentence_label}')
print(f'True sequence: {" ".join(true_seq)}')
print(f'Predicted Sequence: {" ".join(pred_seq)}')
print()
# language model inference via redis
# make sure that the standalone language model is running on the localhost redis ip
# see README.md for instructions on how to run the language model
def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
"""Connect to Redis with retry logic"""
for attempt in range(max_retries):
try:
print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
r = redis.Redis(host=host, port=port, db=db, password=password)
r.ping() # Test the connection
print(f"Successfully connected to Redis at {host}:{port}")
return r
except redis.exceptions.ConnectionError as e:
print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
print("Max retries reached. Could not connect to Redis.")
raise e
except Exception as e:
print(f"Unexpected error connecting to Redis: {e}")
if attempt < max_retries - 1:
print(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
raise e
r = connect_to_redis_with_retry('hs.zchens.cn', 6379, 'admin01')
r.flushall() # clear all streams in redis
# define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'
# set timestamps for last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
lm_results = {
'session': [],
'block': [],
'trial': [],
'true_sentence': [],
'pred_sentence': [],
}
# loop through all trials and put logits into the remote language model to get text predictions
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
for session in test_data.keys():
for trial in range(len(test_data[session]['logits'])):
# get trial logits and rearrange them for the LM
logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
# reset language model
remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
'''
# update language model parameters
remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
r,
remote_lm_done_updating_lastEntrySeen,
acoustic_scale=0.35,
blank_penalty=90.0,
alpha=0.55,
)
'''
# put logits into LM
remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
r,
remote_lm_input_stream,
remote_lm_output_partial_stream,
remote_lm_output_partial_lastEntrySeen,
logits,
)
# finalize remote LM
remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
r,
remote_lm_output_final_stream,
remote_lm_output_final_lastEntrySeen,
)
# get the best candidate sentence
best_candidate_sentence = lm_out['candidate_sentences'][0]
# store results
lm_results['session'].append(session)
lm_results['block'].append(test_data[session]['block_num'][trial])
lm_results['trial'].append(test_data[session]['trial_num'][trial])
if eval_type == 'val':
lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
else:
lm_results['true_sentence'].append(None)
lm_results['pred_sentence'].append(best_candidate_sentence)
# update progress bar
pbar.update(1)
pbar.close()
# if using the validation set, lets calculate the aggregate word error rate (WER)
if eval_type == 'val':
total_true_length = 0
total_edit_distance = 0
lm_results['edit_distance'] = []
lm_results['num_words'] = []
for i in range(len(lm_results['pred_sentence'])):
true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
total_true_length += len(true_sentence.split())
total_edit_distance += ed
lm_results['edit_distance'].append(ed)
lm_results['num_words'].append(len(true_sentence.split()))
print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
print(f'True sentence: {true_sentence}')
print(f'Predicted sentence: {pred_sentence}')
        print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
print()
print(f'Total true sentence length: {total_true_length}')
print(f'Total edit distance: {total_edit_distance}')
print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
ids = [i for i in range(len(lm_results['pred_sentence']))]
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_file, index=False)

View File

@@ -0,0 +1,297 @@
import torch
import numpy as np
import h5py
import time
import re
from data_augmentations import gauss_smooth
LOGIT_TO_PHONEME = [
'BLANK',
'AA', 'AE', 'AH', 'AO', 'AW',
'AY', 'B', 'CH', 'D', 'DH',
'EH', 'ER', 'EY', 'F', 'G',
'HH', 'IH', 'IY', 'JH', 'K',
'L', 'M', 'N', 'NG', 'OW',
'OY', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UW', 'V',
'W', 'Y', 'Z', 'ZH',
' | ',
]
def _extract_transcription(input):
endIdx = np.argwhere(input == 0)[0, 0]
trans = ''
for c in range(endIdx):
trans += chr(input[c])
return trans
def load_h5py_file(file_path, b2txt_csv_df):
data = {
'neural_features': [],
'n_time_steps': [],
'seq_class_ids': [],
'seq_len': [],
'transcriptions': [],
'sentence_label': [],
'session': [],
'block_num': [],
'trial_num': [],
'corpus': [],
}
# Open the hdf5 file for that day
with h5py.File(file_path, 'r') as f:
keys = list(f.keys())
# For each trial in the selected trials in that day
for key in keys:
g = f[key]
neural_features = g['input_features'][:]
n_time_steps = g.attrs['n_time_steps']
seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
transcription = g['transcription'][:] if 'transcription' in g else None
sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
session = g.attrs['session']
block_num = g.attrs['block_num']
trial_num = g.attrs['trial_num']
# match this trial up with the csv to get the corpus name
year, month, day = session.split('.')[1:]
date = f'{year}-{month}-{day}'
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
corpus_name = row['Corpus'].values[0]
data['neural_features'].append(neural_features)
data['n_time_steps'].append(n_time_steps)
data['seq_class_ids'].append(seq_class_ids)
data['seq_len'].append(seq_len)
data['transcriptions'].append(transcription)
data['sentence_label'].append(sentence_label)
data['session'].append(session)
data['block_num'].append(block_num)
data['trial_num'].append(trial_num)
data['corpus'].append(corpus_name)
return data
def rearrange_speech_logits_pt(logits):
# original order is [BLANK, phonemes..., SIL]
# rearrange so the order is [BLANK, SIL, phonemes...]
logits = np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
return logits
# single decoding step function.
# smooths data and puts it through the model.
def runSingleDecodingStep(x, input_layer, model, model_args, device):
# Use autocast for efficiency
with torch.autocast(device_type = "cuda", enabled = model_args['use_amp'], dtype = torch.bfloat16):
x = gauss_smooth(
inputs = x,
device = device,
smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],
smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],
padding = 'valid',
)
with torch.no_grad():
logits, _ = model(
x = x,
day_idx = torch.tensor([input_layer], device=device),
states = None, # no initial states
return_state = True,
)
# convert logits from bfloat16 to float32
logits = logits.float().cpu().numpy()
# # original order is [BLANK, phonemes..., SIL]
# # rearrange so the order is [BLANK, SIL, phonemes...]
# logits = rearrange_speech_logits_pt(logits)
return logits
def remove_punctuation(sentence):
# Remove punctuation
sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
sentence = sentence.replace('- ', ' ').lower()
sentence = sentence.replace('--', '').lower()
sentence = sentence.replace(" '", "'").lower()
sentence = sentence.strip()
sentence = ' '.join([word for word in sentence.split() if word != ''])
return sentence
def get_current_redis_time_ms(redis_conn):
t = redis_conn.time()
return int(t[0]*1000 + t[1]/1000)
######### language model helper functions ##########
def reset_remote_language_model(
r,
remote_lm_done_resetting_lastEntrySeen,
):
r.xadd('remote_lm_reset', {'done': 0})
time.sleep(0.001)
# print('Resetting remote language model before continuing...')
remote_lm_done_resetting = []
while len(remote_lm_done_resetting) == 0:
remote_lm_done_resetting = r.xread(
{'remote_lm_done_resetting': remote_lm_done_resetting_lastEntrySeen},
count=1,
block=10000,
)
if len(remote_lm_done_resetting) == 0:
print(f'Still waiting for remote lm reset from ts {remote_lm_done_resetting_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_done_resetting[0][1]:
remote_lm_done_resetting_lastEntrySeen = entry_id
# print('Remote language model reset.')
return remote_lm_done_resetting_lastEntrySeen
def update_remote_lm_params(
r,
remote_lm_done_updating_lastEntrySeen,
acoustic_scale=0.35,
blank_penalty=90.0,
alpha=0.55,
):
# update remote lm params
entry_dict = {
# 'max_active': max_active,
# 'min_active': min_active,
# 'beam': beam,
# 'lattice_beam': lattice_beam,
'acoustic_scale': acoustic_scale,
# 'ctc_blank_skip_threshold': ctc_blank_skip_threshold,
# 'length_penalty': length_penalty,
# 'nbest': nbest,
'blank_penalty': blank_penalty,
'alpha': alpha,
# 'do_opt': do_opt,
# 'rescore': rescore,
# 'top_candidates_to_augment': top_candidates_to_augment,
# 'score_penalty_percent': score_penalty_percent,
# 'specific_word_bias': specific_word_bias,
}
r.xadd('remote_lm_update_params', entry_dict)
time.sleep(0.001)
remote_lm_done_updating = []
while len(remote_lm_done_updating) == 0:
remote_lm_done_updating = r.xread(
{'remote_lm_done_updating_params': remote_lm_done_updating_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_done_updating) == 0:
print(f'Still waiting for remote lm to update parameters from ts {remote_lm_done_updating_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_done_updating[0][1]:
remote_lm_done_updating_lastEntrySeen = entry_id
# print('Remote language model params updated.')
return remote_lm_done_updating_lastEntrySeen
def send_logits_to_remote_lm(
r,
remote_lm_input_stream,
remote_lm_output_partial_stream,
remote_lm_output_partial_lastEntrySeen,
logits,
):
# put logits into remote lm and get partial output
r.xadd(remote_lm_input_stream, {'logits': np.float32(logits).tobytes()})
remote_lm_output = []
while len(remote_lm_output) == 0:
remote_lm_output = r.xread(
{remote_lm_output_partial_stream: remote_lm_output_partial_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_output) == 0:
print(f'Still waiting for remote lm partial output from ts {remote_lm_output_partial_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_output[0][1]:
remote_lm_output_partial_lastEntrySeen = entry_id
decoded = entry_data[b'lm_response_partial'].decode()
return remote_lm_output_partial_lastEntrySeen, decoded
def finalize_remote_lm(
r,
remote_lm_output_final_stream,
remote_lm_output_final_lastEntrySeen,
):
# finalize remote lm
r.xadd('remote_lm_finalize', {'done': 0})
time.sleep(0.005)
remote_lm_output = []
while len(remote_lm_output) == 0:
remote_lm_output = r.xread(
{remote_lm_output_final_stream: remote_lm_output_final_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_output) == 0:
print(f'Still waiting for remote lm final output from ts {remote_lm_output_final_lastEntrySeen}...')
# print('Received remote lm final output.')
for entry_id, entry_data in remote_lm_output[0][1]:
remote_lm_output_final_lastEntrySeen = entry_id
candidate_sentences = [str(c) for c in entry_data[b'scoring'].decode().split(';')[::5]]
candidate_acoustic_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[1::5]]
candidate_ngram_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[2::5]]
candidate_llm_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[3::5]]
candidate_total_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[4::5]]
# account for a weird edge case where there are no candidate sentences
if len(candidate_sentences) == 0 or len(candidate_total_scores) == 0:
print('No candidate sentences were received from the language model.')
candidate_sentences = ['']
candidate_acoustic_scores = [0]
candidate_ngram_scores = [0]
candidate_llm_scores = [0]
candidate_total_scores = [0]
else:
# sort candidate sentences by total score (higher is better)
sort_order = np.argsort(candidate_total_scores)[::-1]
candidate_sentences = [candidate_sentences[i] for i in sort_order]
candidate_acoustic_scores = [candidate_acoustic_scores[i] for i in sort_order]
candidate_ngram_scores = [candidate_ngram_scores[i] for i in sort_order]
candidate_llm_scores = [candidate_llm_scores[i] for i in sort_order]
candidate_total_scores = [candidate_total_scores[i] for i in sort_order]
# loop through candidates backwards and remove any duplicates
for i in range(len(candidate_sentences)-1, 0, -1):
if candidate_sentences[i] in candidate_sentences[:i]:
candidate_sentences.pop(i)
candidate_acoustic_scores.pop(i)
candidate_ngram_scores.pop(i)
candidate_llm_scores.pop(i)
candidate_total_scores.pop(i)
lm_out = {
'candidate_sentences': candidate_sentences,
'candidate_acoustic_scores': candidate_acoustic_scores,
'candidate_ngram_scores': candidate_ngram_scores,
'candidate_llm_scores': candidate_llm_scores,
'candidate_total_scores': candidate_total_scores,
}
return remote_lm_output_final_lastEntrySeen, lm_out

View File

@@ -0,0 +1,124 @@
# ====================
# Cell 4: Step-by-step debugging of full-model compilation
# ====================
# Run this cell if the test in cell 3 passed.
print("🔧 Testing the full TripleGRUDecoder model step by step...")

# Import the full model
import sys
sys.path.append('.')  # make sure local modules can be imported

try:
    from rnn_model import TripleGRUDecoder
    print("✅ TripleGRUDecoder imported successfully")
except ImportError as e:
    print(f"❌ Model import failed: {e}")
    print("Make sure rnn_model.py is in the current directory")

# Staged testing
def test_model_compilation_stages():
    """Test model compilation in stages. Each stage runs in its own try/except so that a
    failure in one stage does not prevent the later stages from running."""
    device = xm.xla_device()
    results = []
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    # Stage 1: compile NoiseModel on its own
    print("\n🔬 Stage 1: testing NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced parameter count
            n_days=4,
            patch_size=8   # reduced patch size
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ NoiseModel compiled successfully! Took: {compile_time:.2f}s")
        print(f"   Parameter count: {sum(p.numel() for p in noise_model.parameters()):,}")
        results.append(True)
    except Exception as e:
        print(f"❌ NoiseModel compilation failed: {e}")
        results.append(False)

    # Stage 2: test CleanSpeechModel
    print("\n🔬 Stage 2: testing CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ CleanSpeechModel compiled successfully! Took: {compile_time:.2f}s")
        results.append(True)
    except Exception as e:
        print(f"❌ CleanSpeechModel compilation failed: {e}")
        results.append(False)

    # Stage 3: test the full TripleGRUDecoder
    print("\n🔬 Stage 3: testing TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer days
            n_classes=41,
            patch_size=8   # reduced patch size
        ).to(device)
        print(f"📊 Full model parameters: {sum(p.numel() for p in model.parameters()):,}")
        # Start the compilation monitor
        compilation_monitor.start_monitoring()
        start_time = time.time()
        with torch.no_grad():
            # Test inference mode
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time
        compilation_monitor.complete_monitoring()
        print(f"✅ TripleGRUDecoder compiled successfully! Took: {compile_time:.2f}s")
        print(f"📤 Output shape: {logits.shape}")
        results.append(True)
    except Exception as e:
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder compilation failed: {e}")
        results.append(False)

    return all(results)

# Run the staged tests
stage_results = test_model_compilation_stages()

if stage_results:
    print(f"\n🎉 All compilation tests passed!")
    print("💡 Next steps to try:")
    print("   1. Train with the simplified configuration")
    print("   2. Gradually increase model complexity")
    print("   3. Monitor TPU resource usage")
else:
    print(f"\n⚠️ The compilation tests found problems")
    print("💡 Suggestions:")
    print("   1. Reduce the model parameters further")
    print("   2. Check memory usage")
    print("   3. Debug in CPU mode")

View File

@@ -0,0 +1,131 @@
# ====================
# Cell 2: XLA compilation progress monitor
# ====================
import torch
import torch.nn as nn
import time
import threading
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Import XLA (the environment variables were already set in cell 1)
print("🚀 Importing PyTorch XLA...")
import torch_xla.core.xla_model as xm
print(f"✅ XLA imported successfully!")
print(f"   TPU device: {xm.xla_device()}")
print(f"   World size: {xm.xrt_world_size()}")

# Compilation progress monitor
class JupyterCompilationMonitor:

    def __init__(self):
        self.start_time = None
        self.is_monitoring = False
        # Output widget
        self.output_widget = widgets.Output()
        # Progress bar
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA compile:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal'
        )
        # Status label
        self.status_label = widgets.HTML(
            value="<b>Waiting for compilation to start...</b>"
        )
        # CPU and memory usage displays
        self.cpu_label = widgets.HTML(
            value="CPU: ---%"
        )
        self.memory_label = widgets.HTML(
            value="Memory: ---%"
        )
        # Assemble the UI
        self.monitor_box = widgets.VBox([
            widgets.HTML("<h3>🔄 XLA compilation monitor</h3>"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget
        ])

    def start_monitoring(self):
        """Start monitoring."""
        self.start_time = time.time()
        self.is_monitoring = True
        display(self.monitor_box)
        # Start the monitoring thread
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _monitor_loop(self):
        """Monitoring loop."""
        while self.is_monitoring:
            try:
                elapsed = time.time() - self.start_time
                minutes = int(elapsed // 60)
                seconds = int(elapsed % 60)
                # Update the progress bar (simulated progress)
                progress = min(int(elapsed / 10 * 100), 95)  # reaches 95% within 10 seconds
                self.progress_bar.value = progress
                # Query system resources
                cpu_percent = psutil.cpu_percent(interval=0.1)
                memory_percent = psutil.virtual_memory().percent
                # Update the display
                self.status_label.value = f"<b>Compilation in progress... ⏱️ {minutes:02d}:{seconds:02d}</b>"
                self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
                self.memory_label.value = f"<b>💾 Memory: {memory_percent:5.1f}%</b>"
                # Heuristic completion check (CPU usage drops sharply once compilation ends)
                if elapsed > 10 and cpu_percent < 20:  # compilation usually keeps CPU usage high
                    self.complete_monitoring()
                    break
                time.sleep(1)
            except Exception as e:
                with self.output_widget:
                    print(f"Monitoring error: {e}")
                break

    def complete_monitoring(self):
        """Finish monitoring."""
        if self.is_monitoring:
            self.is_monitoring = False
            elapsed = time.time() - self.start_time
            self.progress_bar.value = 100
            self.progress_bar.bar_style = 'success'
            self.status_label.value = f"<b style='color: green'>✅ Compilation finished! Total time: {elapsed:.2f}s</b>"
            with self.output_widget:
                print(f"\n🎉 XLA compilation finished successfully!")
                print(f"⏱️ Total time: {elapsed:.2f}s")
                if elapsed < 60:
                    print("✅ Compilation speed is normal")
                elif elapsed < 300:
                    print("⚠️ Compilation is a bit slow, but acceptable")
                else:
                    print("❌ Compilation is too slow; check the settings")

# Create a global monitor instance
compilation_monitor = JupyterCompilationMonitor()
print("✅ Compilation monitor is ready!")
print("💡 Run the next cell to start the XLA compilation test")

View File

@@ -0,0 +1,45 @@
# ====================
# Cell 1: Environment setup (must be run first!)
# ====================
import os
import time
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# ⚠️ Important: set the environment variables BEFORE importing torch_xla
print("🔧 Setting XLA environment variables...")

# Number of CPU cores
cpu_count = os.cpu_count()
print(f"Detected {cpu_count} CPU cores")

# XLA compilation optimization environment variables
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
os.environ['XLA_USE_BF16'] = '1'

# Show the resulting settings
print("✅ XLA environment variables set:")
print(f"   CPU cores: {cpu_count}")
print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

# System resource check
memory_info = psutil.virtual_memory()
print(f"\n💾 System memory info:")
print(f"   Total memory: {memory_info.total / (1024**3):.1f} GB")
print(f"   Available memory: {memory_info.available / (1024**3):.1f} GB")
print(f"   Usage: {memory_info.percent:.1f}%")

if memory_info.available < 8 * (1024**3):  # less than 8 GB
    print("⚠️ Warning: less than 8 GB of memory available; this may slow down XLA compilation")
else:
    print("✅ Sufficient memory")

print("\n🎯 Environment setup complete! You can now run the next cell")

View File

@@ -0,0 +1,78 @@
# ====================
# Cell 3: quick XLA compilation test
# ====================

# A small test model
class QuickTestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 128)
        self.gru = nn.GRU(128, 64, batch_first=True)
        self.linear2 = nn.Linear(64, 41)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)
        x = self.linear2(x)
        return x

print("🧪 Starting the quick XLA compilation test...")

# Start the monitor
compilation_monitor.start_monitoring()

try:
    # Get the TPU device
    device = xm.xla_device()

    # Build the small model
    model = QuickTestModel().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"📊 Test model parameters: {param_count:,}")

    # Build test data (a very small batch)
    x = torch.randn(2, 20, 512, device=device)
    print(f"📥 Input shape: {x.shape}")

    print("🔄 Running the first forward pass (this triggers XLA compilation)...")
    # First forward pass - this triggers XLA compilation
    with torch.no_grad():
        start_compile = time.time()
        output = model(x)
        # XLA executes lazily; mark_step() forces the pending graph to compile
        # and run so the measured time is meaningful
        xm.mark_step()
        compile_time = time.time() - start_compile

    print("✅ XLA compilation finished!")
    print(f"📤 Output shape: {output.shape}")

    # Stop the monitor
    compilation_monitor.complete_monitoring()

    # Measure post-compilation execution speed
    print("\n🚀 Measuring execution speed after compilation...")
    with torch.no_grad():
        start_exec = time.time()
        for _ in range(10):
            output = model(x)
            xm.mark_step()  # force execution on every iteration
        avg_exec_time = (time.time() - start_exec) / 10
    print(f"⚡ Average execution time: {avg_exec_time*1000:.2f} ms")

    # Evaluate the result
    if compile_time < 30:
        print("✅ Compilation is fast! You can try the full model")
        test_result = "excellent"
    elif compile_time < 120:
        print("✅ Compilation speed is acceptable; consider the simplified config")
        test_result = "good"
    else:
        print("⚠️ Compilation is slow; further optimization is recommended")
        test_result = "slow"

except Exception as e:
    compilation_monitor.complete_monitoring()
    print(f"❌ Test failed: {e}")
    test_result = "failed"

print(f"\n📋 Test result: {test_result}")
print("💡 If the test passed, run the next cell to start full training")

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model
This script provides easy TPU training setup using Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.
Usage:
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""
import argparse
import yaml
import os
import sys
from pathlib import Path
def update_config_for_tpu(config_path, num_cores=8):
"""
Update configuration file to enable TPU training
"""
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# Enable TPU settings
config['use_tpu'] = True
config['num_tpu_cores'] = num_cores
config['dataloader_num_workers'] = 0 # Required for TPU
config['use_amp'] = True # Enable mixed precision with bfloat16
# Adjust batch size and gradient accumulation for multi-core TPU
if num_cores > 1:
# Distribute batch size across cores
original_batch_size = config['dataset']['batch_size']
config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
# Save updated config
tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
with open(tpu_config_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False)
print(f"TPU configuration saved to: {tpu_config_path}")
return tpu_config_path
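
# Example of the batch-size adjustment above: with the shipped rnn_args.yaml
# (batch_size: 32) and num_cores=8, the per-core batch size written to
# rnn_args_tpu.yaml is 32 // 8 = 4, so the eight cores together still process
# 32 trials per optimizer step (before gradient accumulation).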
def check_tpu_environment():
"""
Check if TPU environment is properly set up
"""
try:
import torch_xla
import torch_xla.core.xla_model as xm
# Check if TPUs are available
device = xm.xla_device()
print(f"TPU device available: {device}")
print(f"TPU ordinal: {xm.get_ordinal()}")
print(f"TPU world size: {xm.xrt_world_size()}")
return True
except ImportError:
print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
return False
except Exception as e:
print(f"ERROR: TPU not available - {e}")
return False
def run_tpu_training(config_path, num_cores=8):
"""
Launch TPU training using accelerate
"""
# Check TPU environment
if not check_tpu_environment():
sys.exit(1)
# Update config for TPU
tpu_config_path = update_config_for_tpu(config_path, num_cores)
# Set TPU environment variables BEFORE launching training
os.environ['TPU_CORES'] = str(num_cores)
os.environ['XLA_USE_BF16'] = '1' # Enable bfloat16
# Critical XLA multi-threading settings - must be set before torch_xla import
cpu_count = os.cpu_count()
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
print(f"Set XLA compilation to use {cpu_count} CPU threads")
print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
# Launch training with accelerate using subprocess to ensure environment variables are passed
cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"
print(f"Launching TPU training with command:")
print(f" {cmd}")
print(f"Using {num_cores} TPU cores")
print("-" * 60)
# Use subprocess to ensure environment variables are properly inherited
import subprocess
# Create environment with our XLA settings
env = os.environ.copy()
env.update({
'TPU_CORES': str(num_cores),
'XLA_USE_BF16': '1',
'XLA_FLAGS': (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={cpu_count}'
),
'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
})
print(f"Environment variables set for subprocess:")
print(f" XLA_FLAGS: {env['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
print("-" * 60)
# Execute training with proper environment
result = subprocess.run(cmd.split(), env=env)
return result.returncode
def main():
parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
parser.add_argument('--config', default='rnn_args.yaml',
help='Path to configuration file (default: rnn_args.yaml)')
parser.add_argument('--num_cores', type=int, default=8,
help='Number of TPU cores to use (default: 8)')
parser.add_argument('--check_only', action='store_true',
help='Only check TPU environment, do not launch training')
args = parser.parse_args()
# Verify config file exists
if not os.path.exists(args.config):
print(f"ERROR: Configuration file {args.config} not found")
sys.exit(1)
if args.check_only:
check_tpu_environment()
return
# Run TPU training
run_tpu_training(args.config, args.num_cores)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""
XLA compilation progress monitor
"""
import os
import time
import threading
import psutil
from contextlib import contextmanager


def monitor_compilation_progress(stop_event):
    """Monitor XLA compilation progress until stop_event is set"""
    print("🔍 XLA compilation monitor started...")
    start_time = time.time()
    dots = 0
    while not stop_event.is_set():
        elapsed = time.time() - start_time
        minutes = int(elapsed // 60)
        seconds = int(elapsed % 60)
        # CPU / memory usage
        cpu_percent = psutil.cpu_percent(interval=1)
        memory_percent = psutil.virtual_memory().percent
        # Animated status line
        dots = (dots + 1) % 4
        dot_str = "." * dots + " " * (3 - dots)
        print(f"\r🔄 Compiling XLA{dot_str} "
              f"⏱️ {minutes:02d}:{seconds:02d} "
              f"🖥️ CPU: {cpu_percent:5.1f}% "
              f"💾 Memory: {memory_percent:5.1f}%", end="", flush=True)
        time.sleep(1)


@contextmanager
def compilation_monitor():
    """Context manager that monitors XLA compilation while its body runs"""
    print("🚀 Starting XLA compilation monitoring...")
    # Launch the monitoring thread; the stop event lets us shut it down cleanly
    stop_event = threading.Event()
    monitor_thread = threading.Thread(target=monitor_compilation_progress, args=(stop_event,), daemon=True)
    monitor_thread.start()
    start_time = time.time()
    try:
        yield
    finally:
        stop_event.set()
        elapsed = time.time() - start_time
        print(f"\n✅ XLA compilation finished! Total time: {elapsed:.2f}s")


def add_compilation_monitoring_to_trainer():
    """Example showing how to add compilation monitoring to the trainer"""
    example_code = '''
    # Add this to the train method in rnn_trainer.py
    def train(self):
        from monitor_xla_compilation import compilation_monitor

        self.model.train()
        train_losses = []
        # ... other initialization code ...

        self.logger.info("Starting training loop - XLA compilation monitoring enabled...")

        # Wrap the first batch in the compilation monitor
        with compilation_monitor():
            for i, batch in enumerate(self.train_loader):
                # The first batch triggers XLA compilation;
                # the monitor shows compilation progress
                # ... training code ...
                # Monitoring stops automatically once compilation is done
                break  # only the first batch is needed to trigger compilation

        # Continue with the normal training loop
        for i, batch in enumerate(self.train_loader):
            # ... normal training code ...
    '''
    print("📝 How to use the compilation monitor in the trainer:")
    print(example_code)


if __name__ == "__main__":
    print("🧪 XLA compilation monitoring tool")
    print("=" * 50)
    # Usage instructions
    print("📖 How to use:")
    print("1. Import this file into your training script")
    print("2. Wrap the first model call in the compilation_monitor() context manager")
    print("3. It will show compilation progress and system resource usage in real time")
    add_compilation_monitoring_to_trainer()

View File

@@ -0,0 +1,181 @@
model:
n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
n_units: 768 # number of units per GRU layer
rnn_dropout: 0.4 # dropout rate for the GRU layers
rnn_trainable: true # whether the GRU layers are trainable
n_layers: 5 # number of GRU layers
patch_size: 14 # size of the input patches (14 time steps)
patch_stride: 4 # stride for the input patches (4 time steps)
input_network:
n_input_layers: 1 # number of input layers per network (one network for each day)
input_layer_sizes:
- 512 # size of the input layer (number of input features)
input_trainable: true # whether the input layer is trainable
input_layer_dropout: 0.2 # dropout rate for the input layer
mode: train
use_amp: true # whether to use automatic mixed precision (AMP) for training with bfloat16 on TPU
# TPU distributed training settings
use_tpu: true # TPU training enabled
num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8 or v4-8)
gradient_accumulation_steps: 2 # gradient accumulation steps for distributed training (effective batch size per core: 2 x 32 = 64)
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
init_from_checkpoint: false # whether to initialize the model from a checkpoint
init_checkpoint_path: null # path to the checkpoint to initialize the model from, if any
save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
save_all_val_steps: false # whether to save checkpoints at all validation steps
save_final_model: false # whether to save the final model after training
save_val_metrics: true # whether to save validation metrics during training
early_stopping: false # whether to use early stopping based on validation metrics
early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen
num_training_batches: 120000 # number of training batches to run
lr_scheduler_type: cosine # type of learning rate scheduler to use
lr_max: 0.005 # maximum learning rate for the main model
lr_min: 0.0001 # minimum learning rate for the main model
lr_decay_steps: 120000 # number of steps for the learning rate decay
lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
lr_max_day: 0.005 # maximum learning rate for the day specific input layers
lr_min_day: 0.0001 # minimum learning rate for the day specific input layers
lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day specific input layers
lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day specific input layers
beta0: 0.9 # beta0 parameter for the Adam optimizer
beta1: 0.999 # beta1 parameter for the Adam optimizer
epsilon: 0.1 # epsilon parameter for the Adam optimizer
weight_decay: 0.001 # weight decay for the main model
weight_decay_day: 0 # weight decay for the day specific input layers
seed: 10 # random seed for reproducibility
grad_norm_clip_value: 10 # gradient norm clipping value
batches_per_train_log: 200 # number of batches per training log
batches_per_val_step: 2000 # number of batches per validation step
batches_per_save: 0 # number of batches per save
log_individual_day_val_PER: true # whether to log individual day validation performance
log_val_skip_logs: false # whether to skip logging validation metrics
save_val_logits: true # whether to save validation logits
save_val_data: false # whether to save validation data
dataset:
data_transforms:
white_noise_std: 1.0 # standard deviation of the white noise added to the data
constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
random_walk_std: 0.0 # standard deviation of the random walk added to the data
random_walk_axis: -1 # axis along which the random walk is applied
static_gain_std: 0.0 # standard deviation of the static gain applied to the data
random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
smooth_data: true # whether to smooth the data
smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
neural_dim: 512 # dimensionality of the neural data
batch_size: 32 # batch size for training (reduced for TPU memory constraints)
n_classes: 41 # number of classes (phonemes) in the dataset
max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
days_per_batch: 4 # number of randomly-selected days to include in each batch
seed: 1 # random seed for reproducibility
num_dataloader_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
loader_shuffle: false # whether to shuffle the data loader
must_include_days: null # specific days to include in the dataset
test_percentage: 0.1 # percentage of data to use for testing
feature_subset: null # specific features to include in the dataset
dataset_dir: ../data/hdf5_data_final # directory containing the dataset
bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
sessions: # list of sessions to include in the dataset
- t15.2023.08.11
- t15.2023.08.13
- t15.2023.08.18
- t15.2023.08.20
- t15.2023.08.25
- t15.2023.08.27
- t15.2023.09.01
- t15.2023.09.03
- t15.2023.09.24
- t15.2023.09.29
- t15.2023.10.01
- t15.2023.10.06
- t15.2023.10.08
- t15.2023.10.13
- t15.2023.10.15
- t15.2023.10.20
- t15.2023.10.22
- t15.2023.11.03
- t15.2023.11.04
- t15.2023.11.17
- t15.2023.11.19
- t15.2023.11.26
- t15.2023.12.03
- t15.2023.12.08
- t15.2023.12.10
- t15.2023.12.17
- t15.2023.12.29
- t15.2024.02.25
- t15.2024.03.03
- t15.2024.03.08
- t15.2024.03.15
- t15.2024.03.17
- t15.2024.04.25
- t15.2024.04.28
- t15.2024.05.10
- t15.2024.06.14
- t15.2024.07.19
- t15.2024.07.21
- t15.2024.07.28
- t15.2025.01.10
- t15.2025.01.12
- t15.2025.03.14
- t15.2025.03.16
- t15.2025.03.30
- t15.2025.04.13
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
- 0 # no val or test data from this day
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 0 # no val or test data from this day
- 1
- 1
- 1
- 0 # no val or test data from this day
- 0 # no val or test data from this day
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1

View File

@@ -0,0 +1,94 @@
# Simplified TPU training configuration - compiles faster
model:
  n_input_features: 512
  n_units: 384 # reduced from 768 to 384
  rnn_dropout: 0.2 # reduced dropout
  rnn_trainable: true
  n_layers: 3 # reduced from 5 to 3 layers
  patch_size: 8 # reduced from 14 to 8
  patch_stride: 4
  input_network:
    n_input_layers: 1
    input_layer_sizes:
      - 512
    input_trainable: true
    input_layer_dropout: 0.1 # reduced dropout
mode: train
use_amp: true
# TPU distributed training settings
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 4 # more gradient accumulation to compensate for the smaller batch
output_dir: trained_models/simple_rnn
checkpoint_dir: trained_models/simple_rnn/checkpoint
init_from_checkpoint: false
save_best_checkpoint: true
save_val_metrics: true
num_training_batches: 1000 # start by testing with 1000 batches
lr_scheduler_type: cosine
lr_max: 0.003 # slightly lower learning rate
lr_min: 0.0001
lr_decay_steps: 1000
lr_warmup_steps: 100
lr_max_day: 0.003
lr_min_day: 0.0001
lr_decay_steps_day: 1000
lr_warmup_steps_day: 100
beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 5 # smaller gradient clipping threshold
batches_per_train_log: 50 # log more frequently
batches_per_val_step: 200
log_individual_day_val_PER: true
# Disable adversarial training for the quick test
adversarial:
  enabled: false # disabled for the first quick run
dataset:
  data_transforms:
    white_noise_std: 0.5 # less data augmentation
    constant_offset_std: 0.1
    random_walk_std: 0.0
    random_walk_axis: -1
    static_gain_std: 0.0
    random_cut: 1 # less random cropping
    smooth_kernel_size: 50 # smaller smoothing kernel
    smooth_data: true
    smooth_kernel_std: 1
  neural_dim: 512
  batch_size: 16 # batch size reduced from 32 to 16
  n_classes: 41
  max_seq_elements: 300 # shorter sequences
  days_per_batch: 2 # fewer days per batch
  seed: 1
  num_dataloader_workers: 0
  loader_shuffle: false
  test_percentage: 0.1
  dataset_dir: ../data/hdf5_data_final
  # Use only a few sessions for the quick test
  sessions:
    - t15.2023.08.11
    - t15.2023.08.13
    - t15.2023.08.18
    - t15.2023.08.20
  dataset_probability_val:
    - 0
    - 1
    - 1
    - 1

File diff suppressed because it is too large

View File

@@ -0,0 +1,580 @@
import torch
from torch import nn
from typing import cast
class GradientReversalFn(torch.autograd.Function):
"""
Gradient Reversal Layer (GRL)
Forward: identity
Backward: multiply incoming gradient by -lambda
"""
@staticmethod
def forward(ctx, x, lambd: float):
ctx.lambd = lambd
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return -ctx.lambd * grad_output, None
def gradient_reverse(x, lambd: float = 1.0):
return GradientReversalFn.apply(x, lambd)
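
# Illustrative check of the GRL semantics (not part of the training pipeline):
# the forward pass is an identity, and gradients flowing back are scaled by -lambd.
#
#   feats = torch.randn(4, 10, requires_grad=True)
#   gradient_reverse(feats, lambd=0.5).sum().backward()
#   # feats.grad is now -0.5 * torch.ones(4, 10)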
class NoiseModel(nn.Module):
'''
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
'''
def __init__(self,
neural_dim,
n_units,
n_days,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(NoiseModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
# Let Accelerator handle dtype automatically for TPU compatibility
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 2-layer GRU for noise estimation
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.input_size, # Output same dimension as input
num_layers=2,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Learnable initial hidden state - let Accelerator handle dtype
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
def forward(self, x, day_idx, states=None):
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
batch_size = x.size(0)
# Stack all day weights and biases upfront for static indexing
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
# XLA-friendly conditional dropout
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization - avoid dynamic allocation
if states is None:
states = self.h0.expand(2, batch_size, self.input_size).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# Disable autocast for GRU to avoid dtype mismatches on XLA
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
return output, hidden_states
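
# Note on the NoiseModel output (follows from the layer sizes above): the GRU uses
# hidden_size == input_size (neural_dim * patch_size when patching is enabled), so
# `output` is a per-timestep noise estimate in the same preprocessed space as the
# input; TripleGRUDecoder subtracts it from the preprocessed signal to build the
# residual "denoised" input for the CleanSpeechModel.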
class CleanSpeechModel(nn.Module):
'''
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(CleanSpeechModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.n_classes = n_classes
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Day-specific input layers
self.day_layer_activation = nn.Softsign()
# Let Accelerator handle dtype automatically for TPU compatibility
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
self.day_layer_dropout = nn.Dropout(input_dropout)
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 3-layer GRU for clean speech recognition
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.n_units,
num_layers=3,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Output classification layer
self.out = nn.Linear(self.n_units, self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
def forward(self, x, day_idx, states=None, return_state=False):
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
batch_size = x.size(0)
# Stack all day weights and biases upfront for static indexing
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization
if states is None:
states = self.h0.expand(3, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
class NoisySpeechModel(nn.Module):
'''
Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0):
super(NoisySpeechModel, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_days = n_days
self.n_classes = n_classes
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Calculate input size after patching
self.input_size = self.neural_dim
if self.patch_size > 0:
self.input_size *= self.patch_size
# 2-layer GRU for noisy speech recognition
self.gru = nn.GRU(
input_size=self.input_size,
hidden_size=self.n_units,
num_layers=2,
dropout=self.rnn_dropout,
batch_first=True,
bidirectional=False,
)
# Initialize GRU parameters
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Output classification layer
self.out = nn.Linear(self.n_units, self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden state
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
def forward(self, x, states=None, return_state=False):
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
batch_size = x.size(0)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization
if states is None:
states = self.h0.expand(2, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
class TripleGRUDecoder(nn.Module):
'''
Three-model adversarial architecture for neural speech decoding
Combines:
- NoiseModel: estimates noise in neural data
- CleanSpeechModel: processes denoised signal for recognition
- NoisySpeechModel: processes noise signal for recognition
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=0,
patch_stride=0,
):
'''
neural_dim (int) - number of channels in a single timestep (e.g. 512)
n_units (int) - number of hidden units in each recurrent layer
n_days (int) - number of days in the dataset
n_classes (int) - number of classes (phonemes)
rnn_dropout (float) - percentage of units to dropout during training
input_dropout (float) - percentage of input units to dropout during training
patch_size (int) - number of timesteps to concat on initial input layer
patch_stride(int) - number of timesteps to stride over when concatenating initial input
'''
super(TripleGRUDecoder, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_classes = n_classes
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
# Create the three models
self.noise_model = NoiseModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
self.clean_speech_model = CleanSpeechModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
n_classes=n_classes,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
self.noisy_speech_model = NoisySpeechModel(
neural_dim=neural_dim,
n_units=n_units,
n_days=n_days,
n_classes=n_classes,
rnn_dropout=rnn_dropout,
input_dropout=input_dropout,
patch_size=patch_size,
patch_stride=patch_stride
)
# Training mode flag
self.training_mode = 'full' # 'full', 'inference'
def _apply_preprocessing(self, x, day_idx):
'''XLA-friendly preprocessing with static operations'''
batch_size = x.size(0)
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
# XLA-friendly gather operation
day_weights = torch.index_select(all_day_weights, 0, day_idx)
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
# Use bmm (batch matrix multiply) which is highly optimized in XLA
# Ensure dtype consistency for mixed precision training
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x_processed.dtype # Preserve original dtype for XLA/TPU compatibility
x_processed = x_processed.unsqueeze(1)
x_processed = x_processed.permute(0, 3, 1, 2)
x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x_processed = x_processed.to(original_dtype)
return x_processed
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
batch_size = x_processed.size(0)
clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
if x_processed.dtype != clean_gru_dtype:
x_processed = x_processed.to(clean_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
if states.dtype != clean_gru_dtype:
states = states.to(clean_gru_dtype)
# GRU forward pass (skip preprocessing since input is already processed)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
# Classification
logits = self.clean_speech_model.out(output)
return logits
def _noisy_forward_with_processed_input(self, x_processed, states=None):
'''Forward pass for NoisySpeechModel with already processed input'''
batch_size = x_processed.size(0)
noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
if x_processed.dtype != noisy_gru_dtype:
x_processed = x_processed.to(noisy_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency
if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
if states.dtype != noisy_gru_dtype:
states = states.to(noisy_gru_dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
# Classification
logits = self.noisy_speech_model.out(output)
return logits
def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
'''
Three-model adversarial forward pass
x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
day_idx (tensor) - tensor of day indices for each example in the batch
states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
'''
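        # Shape sketch (assuming patch_size > 0): for x of shape (B, T, neural_dim),
        # the patched length is T' = (T - patch_size) // patch_stride + 1, so
        # noise_output has shape (B, T', neural_dim * patch_size) while
        # clean_logits and noisy_logits have shape (B, T', n_classes).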
if mode == 'full':
# Training mode: run all three models
# 1. Noise model estimates noise in the data
noise_output, noise_hidden = self.noise_model(x, day_idx,
states['noise'] if states else None)
# 2. For residual connection, we need x in the same space as noise_output
# Apply the same preprocessing that the models use internally
x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency between processed input and noise output
if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
# 3. Clean speech model processes denoised signal
denoised_input = x_processed - noise_output # Residual connection in processed space
            # The residual already lives in the preprocessed (day-transformed, patched) space,
            # so feed it through a forward path that bypasses CleanSpeechModel's own preprocessing
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None)
# 4. Noisy speech model processes noise signal directly (no day layers needed)
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
noisy_input = cast(torch.Tensor, noisy_input)
noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
if noisy_input.dtype != noisy_dtype:
noisy_input = noisy_input.to(noisy_dtype)
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
states['noisy'] if states else None)
# XLA-friendly return - use tuple instead of dict for better compilation
if return_state:
return (clean_logits, noisy_logits, noise_output), noise_hidden
return clean_logits, noisy_logits, noise_output
elif mode == 'inference':
# Inference mode: only noise model + clean speech model
# 1. Estimate noise
noise_output, noise_hidden = self.noise_model(x, day_idx,
states['noise'] if states else None)
# 2. For residual connection, we need x in the same space as noise_output
x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency for mixed precision residual connection
if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
denoised_input = x_processed - noise_output
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None)
# XLA-friendly return - use tuple for consistency
if return_state:
return clean_logits, noise_hidden
return clean_logits
else:
raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
'''
Apply combined gradients to noise model parameters
clean_grad (tensor) - gradients from clean speech model output layer
noisy_grad (tensor) - gradients from noisy speech model output layer
'''
# Combine gradients: negative from clean model, positive from noisy model
combined_grad = -clean_grad + noisy_grad
# Apply gradients to noise model parameters
# This is a simplified implementation - in practice you'd want more sophisticated update rules
with torch.no_grad():
for param in self.noise_model.parameters():
if param.grad is not None:
# Scale the combined gradient appropriately
# This is a placeholder - you'd need to implement proper gradient mapping
param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)
def set_mode(self, mode):
'''Set the operating mode'''
self.training_mode = mode
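

# Minimal usage sketch (illustrative values, not the training configuration):
#
#   model = TripleGRUDecoder(neural_dim=512, n_units=768, n_days=45, n_classes=41,
#                            patch_size=14, patch_stride=4)
#   x = torch.randn(2, 140, 512)          # (batch, time, neural_dim)
#   day_idx = torch.tensor([0, 3])
#   clean_logits = model(x, day_idx, mode='inference')
#   clean, noisy, noise = model(x, day_idx, mode='full', grl_lambda=0.5)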

View File

@@ -0,0 +1,952 @@
import os
# XLA multi-threading optimization - MUST be set before importing torch_xla
# Set these environment variables early to ensure they take effect
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
# Enable XLA multi-threading for compilation speedup
os.environ.setdefault('XLA_FLAGS',
'--xla_cpu_multi_thread_eigen=true ' +
'--xla_cpu_enable_fast_math=true ' +
f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
# Set PyTorch XLA threading
os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
print(f"Set XLA compilation threads to {os.cpu_count()}")
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from contextlib import nullcontext
from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth
import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf
# Import Accelerate for TPU support
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed
# Import XLA after setting environment variables
import torch_xla.core.xla_model as xm
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64
from rnn_model import TripleGRUDecoder
class BrainToTextDecoder_Trainer:
"""
This class will initialize and train a brain-to-text phoneme decoder
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
"""
def __init__(self, args):
'''
args : dictionary of training arguments
'''
# Configure DataLoader behavior for TPU compatibility
dataloader_config = DataLoaderConfiguration(
even_batches=False # Required for batch_size=None DataLoaders on TPU
)
# Initialize Accelerator for TPU/multi-device support
self.use_xla = bool(xm.get_xla_supported_devices())
self.amp_requested = args.get('use_amp', True)
mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
self.accelerator = Accelerator(
mixed_precision=mixed_precision_mode,
gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
log_with=None, # We'll use our own logging
project_dir=args.get('output_dir', './output'),
dataloader_config=dataloader_config,
)
# Trainer fields
self.args = args
self.logger = None
self.device = self.accelerator.device # Use accelerator device instead of manual device selection
self.model = None
self.optimizer = None
self.learning_rate_scheduler = None
self.ctc_loss = None
self.best_val_PER = torch.inf # track best PER for checkpointing
self.best_val_loss = torch.inf # track best loss for checkpointing
self.train_dataset = None
self.val_dataset = None
self.train_loader = None
self.val_loader = None
self.transform_args = self.args['dataset']['data_transforms']
# Adversarial training config (safe defaults if not provided)
adv_cfg = self.args.get('adversarial', {})
self.adv_enabled = adv_cfg.get('enabled', False)
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging
self.logger = logging.getLogger(__name__)
for handler in self.logger.handlers[:]: # make a copy of the list
self.logger.removeHandler(handler)
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
if args['mode']=='train':
# During training, save logs to file in output directory
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
fh.setFormatter(formatter)
self.logger.addHandler(fh)
# Always print logs to stdout
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
self.logger.addHandler(sh)
# Log device information (managed by Accelerator)
self.logger.info(f'Using device: {self.device}')
self.logger.info(f'Accelerator state: {self.accelerator.state}')
if self.accelerator.num_processes > 1:
self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
if self.use_xla and self.amp_requested:
self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
# Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
if self.args['seed'] != -1:
set_seed(self.args['seed'])
# Initialize the model
self.model = TripleGRUDecoder(
neural_dim = self.args['model']['n_input_features'],
n_units = self.args['model']['n_units'],
n_days = len(self.args['dataset']['sessions']),
n_classes = self.args['dataset']['n_classes'],
rnn_dropout = self.args['model']['rnn_dropout'],
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
patch_size = self.args['model']['patch_size'],
patch_stride = self.args['model']['patch_stride'],
)
if self.use_xla and self.amp_requested:
self.model = self.model.to(torch.bfloat16)
self.logger.info('Converted model parameters to bfloat16 for TPU training.')
self.model_dtype = next(self.model.parameters()).dtype
# Temporarily disable torch.compile for compatibility with new model architecture
# TODO: Re-enable torch.compile once model is stable
# self.logger.info("Using torch.compile")
# self.model = torch.compile(self.model)
self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
self.logger.info(f"Initialized RNN decoding model")
self.logger.info(self.model)
# Log how many parameters are in the model
total_params = sum(p.numel() for p in self.model.parameters())
self.logger.info(f"Model has {total_params:,} parameters")
# Determine how many day-specific parameters are in the model
day_params = 0
for name, param in self.model.named_parameters():
if 'day' in name:
day_params += param.numel()
self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
# Create datasets and dataloaders
train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
# Ensure that there are no duplicate days
if len(set(train_file_paths)) != len(train_file_paths):
raise ValueError("There are duplicate sessions listed in the train dataset")
if len(set(val_file_paths)) != len(val_file_paths):
raise ValueError("There are duplicate sessions listed in the val dataset")
# Split trials into train and test sets
train_trials, _ = train_test_split_indicies(
file_paths = train_file_paths,
test_percentage = 0,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
_, val_trials = train_test_split_indicies(
file_paths = val_file_paths,
test_percentage = 1,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
# Save dictionaries to output directory to know which trials were train vs val
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
json.dump({'train' : train_trials, 'val': val_trials}, f)
# Determine if a only a subset of neural features should be used
feature_subset = None
if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None:
feature_subset = self.args['dataset']['feature_subset']
self.logger.info(f'Using only a subset of features: {feature_subset}')
# train dataset and dataloader
self.train_dataset = BrainToTextDataset(
trial_indicies = train_trials,
split = 'train',
days_per_batch = self.args['dataset']['days_per_batch'],
n_batches = self.args['num_training_batches'],
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True,
collate_fn = collate_fn
)
# val dataset and dataloader
self.val_dataset = BrainToTextDataset(
trial_indicies = val_trials,
split = 'test',
days_per_batch = None,
n_batches = None,
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Validation DataLoader with same collate function
self.val_loader = DataLoader(
self.val_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True,
collate_fn = collate_fn # Use same collate function
)
self.logger.info("Successfully initialized datasets")
# Create optimizer, learning rate scheduler, and loss
self.optimizer = self.create_optimizer()
if self.args['lr_scheduler_type'] == 'linear':
self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer = self.optimizer,
start_factor = 1.0,
end_factor = self.args['lr_min'] / self.args['lr_max'],
total_iters = self.args['lr_decay_steps'],
)
elif self.args['lr_scheduler_type'] == 'cosine':
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
else:
raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
# If a checkpoint is provided, then load from checkpoint
if self.args['init_from_checkpoint']:
self.load_model_checkpoint(self.args['init_checkpoint_path'])
# Set rnn and/or input layers to not trainable if specified
for name, param in self.model.named_parameters():
if not self.args['model']['rnn_trainable'] and 'gru' in name:
param.requires_grad = False
elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
param.requires_grad = False
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
# Let Accelerator handle everything automatically for both GPU and TPU
(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
)
self.model_dtype = next(self.model.parameters()).dtype
self.logger.info("Prepared model and dataloaders with Accelerator")
if self.adv_enabled:
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
def autocast_context(self):
"""Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
if self.device.type == 'xla':
return nullcontext()
return self.accelerator.autocast()
def create_optimizer(self):
'''
Create the optimizer with special param groups
Biases and day weights should not be decayed
Day weights should have a separate learning rate
'''
bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
if len(day_params) != 0:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
{'params' : other_params, 'group_type' : 'other'}
]
else:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : other_params, 'group_type' : 'other'}
]
optim = torch.optim.AdamW(
param_groups,
lr = self.args['lr_max'],
betas = (self.args['beta0'], self.args['beta1']),
eps = self.args['epsilon'],
weight_decay = self.args['weight_decay'],
fused = True
)
return optim
def create_cosine_lr_scheduler(self, optim):
lr_max = self.args['lr_max']
lr_min = self.args['lr_min']
lr_decay_steps = self.args['lr_decay_steps']
lr_max_day = self.args['lr_max_day']
lr_min_day = self.args['lr_min_day']
lr_decay_steps_day = self.args['lr_decay_steps_day']
lr_warmup_steps = self.args['lr_warmup_steps']
lr_warmup_steps_day = self.args['lr_warmup_steps_day']
def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
'''
Create lr lambdas for each param group that implement cosine decay
Different lr lambda decaying for day params vs rest of the model
'''
# Warmup phase
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
# Cosine decay phase
if current_step < decay_steps:
progress = float(current_step - warmup_steps) / float(
max(1, decay_steps - warmup_steps)
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
# Scale from 1.0 to min_lr_ratio
return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
# After cosine decay is complete, maintain min_lr_ratio
return min_lr_ratio
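
        # Worked example with the rnn_args.yaml values (lr_max=0.005, lr_min=0.0001,
        # i.e. min_lr_ratio=0.02, warmup_steps=1000, decay_steps=120000):
        #   step 0        -> factor 0.00 (warmup start)
        #   step 500      -> factor 0.50 (halfway through warmup)
        #   step 1000     -> factor 1.00 (cosine decay starts at lr_max)
        #   step >=120000 -> factor 0.02 (learning rate held at lr_min)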
if len(optim.param_groups) == 3:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min_day / lr_max_day,
lr_decay_steps_day,
lr_warmup_steps_day,
), # day params
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
elif len(optim.param_groups) == 2:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
else:
raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
return LambdaLR(optim, lr_lambdas, -1)
def load_model_checkpoint(self, load_path):
'''
Load a training checkpoint for distributed training
'''
# Load checkpoint on CPU first to avoid OOM issues
checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
# Get unwrapped model for loading state dict
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
# Device handling is managed by Accelerator, no need to manually move to device
self.logger.info("Loaded model from checkpoint: " + load_path)
def save_model_checkpoint(self, save_path, PER, loss):
'''
Save a training checkpoint using Accelerator for distributed training
'''
# Only save on main process to avoid conflicts
if self.accelerator.is_main_process:
# Unwrap model to get base model for saving
unwrapped_model = self.accelerator.unwrap_model(self.model)
checkpoint = {
'model_state_dict' : unwrapped_model.state_dict(),
'optimizer_state_dict' : self.optimizer.state_dict(),
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
'val_PER' : PER,
'val_loss' : loss
}
torch.save(checkpoint, save_path)
self.logger.info("Saved model to checkpoint: " + save_path)
# Save the args file alongside the checkpoint
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
# Wait for all processes to complete checkpoint saving
self.accelerator.wait_for_everyone()
def create_attention_mask(self, sequence_lengths):
max_length = torch.max(sequence_lengths).item()
batch_size = sequence_lengths.size(0)
# Create a mask for valid key positions (columns)
# Shape: [batch_size, max_length]
key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
key_mask = key_mask < sequence_lengths.unsqueeze(1)
# Expand key_mask to [batch_size, 1, 1, max_length]
# This will be broadcast across all query positions
key_mask = key_mask.unsqueeze(1).unsqueeze(1)
# Create the attention mask of shape [batch_size, 1, max_length, max_length]
# by broadcasting key_mask across all query positions
attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
        # Return the mask as a boolean tensor: True marks valid key positions,
        # False marks padded key positions that should not contribute to attention
        attention_mask_float = attention_mask
return attention_mask_float
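
        # Sketch of an additive float mask, should a downstream attention
        # implementation require one instead of the boolean mask returned above:
        #   float_mask = torch.zeros(attention_mask.shape, device=attention_mask.device)
        #   float_mask.masked_fill_(~attention_mask, float('-inf'))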
def transform_data(self, features, n_time_steps, mode = 'train'):
'''
Apply various augmentations and smoothing to data
Performing augmentations is much faster on GPU than CPU
'''
# TPU and GPU should now handle data consistently with our improved DataLoader configuration
data_shape = features.shape
batch_size = data_shape[0]
channels = data_shape[-1]
# We only apply these augmentations in training
if mode == 'train':
# add static gain noise
if self.transform_args['static_gain_std'] > 0:
warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1))
warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
features = torch.matmul(features, warp_mat)
# add white noise
if self.transform_args['white_noise_std'] > 0:
features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
# add constant offset noise
if self.transform_args['constant_offset_std'] > 0:
features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
# add random walk noise
if self.transform_args['random_walk_std'] > 0:
features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
# randomly cutoff part of the data timecourse
if self.transform_args['random_cut'] > 0:
cut = np.random.randint(0, self.transform_args['random_cut'])
features = features[:, cut:, :]
n_time_steps = n_time_steps - cut
# Apply Gaussian smoothing to data
# This is done in both training and validation
if self.transform_args['smooth_data']:
features = gauss_smooth(
inputs = features,
device = self.device,
smooth_kernel_std = self.transform_args['smooth_kernel_std'],
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
)
if hasattr(self, 'model_dtype'):
features = features.to(self.model_dtype)
return features, n_time_steps
def train(self):
'''
Train the model
'''
# Set model to train mode (specificially to make sure dropout layers are engaged)
self.model.train()
# create vars to track performance
train_losses = []
val_losses = []
val_PERs = []
val_results = []
val_steps_since_improvement = 0
# training params
save_best_checkpoint = self.args.get('save_best_checkpoint', True)
early_stopping = self.args.get('early_stopping', True)
early_stopping_val_steps = self.args['early_stopping_val_steps']
train_start_time = time.time()
# train for specified number of batches
self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
for i, batch in enumerate(self.train_loader):
self.model.train()
self.optimizer.zero_grad()
# Train step
start_time = time.time()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
with self.autocast_context():
# Apply augmentations to the data
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
# Ensure proper dtype handling for TPU mixed precision
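# Patched sequence length for CTC: floor((n_time_steps - patch_size) / patch_stride) + 1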
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Get phoneme predictions: 'full' adversarial mode when enabled and past warmup,
# otherwise plain 'inference' mode where only the clean logits feed the CTC loss
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
if use_full:
clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
else:
logits = self.model(features, day_indicies, None, False, 'inference')
# Calculate CTC Loss
if use_full:
# Clean CTC loss
clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
clean_loss = self.ctc_loss(
clean_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
clean_loss = torch.mean(clean_loss)
# Noisy branch CTC loss (keeps the noisy branch decodable, but through the GRL it acts adversarially on the NoiseModel)
noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
noisy_loss = self.ctc_loss(
noisy_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
noisy_loss = torch.mean(noisy_loss)
# Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
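# Combined objective: clean CTC + weighted noisy-branch CTC (adversarial via the GRL) + optional L2 penalty on the estimated noise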
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
log_probs=log_probs,
targets=labels,
input_lengths=adjusted_lens,
target_lengths=phone_seq_lens
)
loss = torch.mean(loss) # take mean loss over batches
# Use Accelerator's backward for distributed training
self.accelerator.backward(loss)
# Clip gradient using Accelerator's clip_grad_norm_
# (grad_norm defaults to 0.0 so the training log below stays valid when clipping is disabled)
grad_norm = 0.0
if self.args['grad_norm_clip_value'] > 0:
grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
max_norm = self.args['grad_norm_clip_value'])
self.optimizer.step()
self.learning_rate_scheduler.step()
# Save training metrics
train_step_duration = time.time() - start_time
train_losses.append(loss.detach().item())
# Incrementally log training progress
if i % self.args['batches_per_train_log'] == 0:
self.logger.info(f'Train batch {i}: ' +
f'loss: {(loss.detach().item()):.2f} ' +
f'grad norm: {grad_norm:.2f} ' +
f'time: {train_step_duration:.3f}')
# Incrementally run a test step
if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
self.logger.info(f"Running test after training batch: {i}")
# Calculate metrics on val data
start_time = time.time()
val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
val_step_duration = time.time() - start_time
# Log info
self.logger.info(f'Val batch {i}: ' +
f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
f'time: {val_step_duration:.3f}')
if self.args['log_individual_day_val_PER']:
for day in val_metrics['day_PERs'].keys():
self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
# Save metrics
val_PERs.append(val_metrics['avg_PER'])
val_losses.append(val_metrics['avg_loss'])
val_results.append(val_metrics)
# Determine if this is a new best checkpoint: lower PER, or equal PER with lower loss
new_best = False
if val_metrics['avg_PER'] < self.best_val_PER:
self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
self.best_val_PER = val_metrics['avg_PER']
self.best_val_loss = val_metrics['avg_loss']
new_best = True
elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
self.best_val_loss = val_metrics['avg_loss']
new_best = True
if new_best:
# Checkpoint if metrics have improved
if save_best_checkpoint:
self.logger.info(f"Checkpointing model")
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
# save validation metrics to pickle file
if self.args['save_val_metrics']:
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
pickle.dump(val_metrics, f)
val_steps_since_improvement = 0
else:
val_steps_since_improvement +=1
# Optionally save this validation checkpoint, regardless of performance
if self.args['save_all_val_steps']:
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
# Early stopping
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
break
# Log final training steps
training_duration = time.time() - train_start_time
self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
# Save final model
if self.args['save_final_model']:
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
train_stats = {}
train_stats['train_losses'] = train_losses
train_stats['val_losses'] = val_losses
train_stats['val_PERs'] = val_PERs
train_stats['val_metrics'] = val_results
return train_stats
def validation(self, loader, return_logits = False, return_data = False):
'''
Calculate metrics on the validation dataset
'''
self.model.eval()
metrics = {}
# Record metrics
if return_logits:
metrics['logits'] = []
metrics['n_time_steps'] = []
if return_data:
metrics['input_features'] = []
metrics['decoded_seqs'] = []
metrics['true_seq'] = []
metrics['phone_seq_lens'] = []
metrics['transcription'] = []
metrics['losses'] = []
metrics['block_nums'] = []
metrics['trial_nums'] = []
metrics['day_indicies'] = []
total_edit_distance = 0
total_seq_length = 0
# Calculate PER for each specific day
day_per = {}
for d in range(len(self.args['dataset']['sessions'])):
if self.args['dataset']['dataset_probability_val'][d] == 1:
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
for i, batch in enumerate(loader):
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Determine if we should perform validation on this batch
day = day_indicies[0].item()
if self.args['dataset']['dataset_probability_val'][day] == 0:
if self.args['log_val_skip_logs']:
self.logger.info(f"Skipping validation on day {day}")
continue
with torch.no_grad():
with self.autocast_context():
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
logits = self.model(features, day_indicies, None, False, 'inference')
val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
val_log_probs,
labels,
adjusted_lens,
phone_seq_lens,
)
loss = torch.mean(loss)
metrics['losses'].append(loss.cpu().detach().numpy())
# Calculate PER per day and also avg over entire validation set
batch_edit_distance = 0
decoded_seqs = []
for iterIdx in range(logits.shape[0]):
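# Greedy CTC decode: frame-wise argmax, collapse consecutive repeats, then drop the CTC blank (index 0)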
decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
decoded_seq = decoded_seq.cpu().detach().numpy()
decoded_seq = np.array([i for i in decoded_seq if i != 0])
trueSeq = np.array(
labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
)
batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
decoded_seqs.append(decoded_seq)
day = batch['day_indicies'][0].item()
day_per[day]['total_edit_distance'] += batch_edit_distance
day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
total_edit_distance += batch_edit_distance
total_seq_length += torch.sum(phone_seq_lens)
# Record metrics
if return_logits:
metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
if return_data:
metrics['input_features'].append(batch['input_features'].cpu().numpy())
metrics['decoded_seqs'].append(decoded_seqs)
metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
metrics['losses'].append(loss.detach().item())
metrics['block_nums'].append(batch['block_nums'].numpy())
metrics['trial_nums'].append(batch['trial_nums'].numpy())
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
if isinstance(total_seq_length, torch.Tensor):
total_length_value = float(total_seq_length.item())
else:
total_length_value = float(total_seq_length)
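# Phoneme error rate = total edit distance / total reference phoneme count (guarded against empty validation sets)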
avg_PER = total_edit_distance / max(total_length_value, 1e-6)
metrics['day_PERs'] = day_per
metrics['avg_PER'] = avg_PER
metrics['avg_loss'] = float(np.mean(metrics['losses']))
return metrics
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
'''
TPU-compatible inference method for generating phoneme logits
'''
self.model.eval()
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits
def inference_batch(self, batch, mode='inference'):
'''
Inference method for processing a full batch
'''
self.model.eval()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
day_indicies = batch['day_indicies']
n_time_steps = batch['n_time_steps']
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Calculate adjusted sequence lengths for CTC with proper dtype handling
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits, adjusted_lens

View File

@@ -0,0 +1,27 @@
#!/bin/bash
# TPU XLA Multi-threading Environment Setup
# Set these BEFORE starting Python to ensure they take effect
echo "Setting up XLA multi-threading environment..."
# Get CPU core count
CPU_CORES=$(nproc)
echo "Detected $CPU_CORES CPU cores"
# Set XLA compilation flags
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES
# Additional XLA optimizations
export XLA_USE_BF16=1
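# Note: newer torch_xla releases may deprecate XLA_USE_BF16 in favor of explicit bf16 casts/autocast; keep whichever matches the installed version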
export TPU_CORES=8
# Print current settings
echo "XLA_FLAGS: $XLA_FLAGS"
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
echo "XLA_USE_BF16: $XLA_USE_BF16"
# Start training
echo "Starting TPU training..."
python train_model.py --config_path rnn_args.yaml

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
Simplified model test script - verifies that XLA compilation works correctly
"""
import os
import time
import torch
import torch.nn as nn
# Set XLA environment variables (this must happen before importing torch_xla)
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'
print(f"🔧 XLA环境变量设置:")
print(f" CPU核心数: {os.cpu_count()}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
import torch_xla.core.xla_model as xm
class SimpleModel(nn.Module):
"""简化的测试模型"""
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(512, 256)
self.gru = nn.GRU(256, 128, batch_first=True)
self.linear2 = nn.Linear(128, 41) # 41 phoneme classes
def forward(self, x):
x = torch.relu(self.linear1(x))
x, _ = self.gru(x)
x = self.linear2(x)
return x
def test_xla_compilation():
"""测试XLA编译速度"""
print("\n🚀 开始简化模型XLA编译测试...")
# 检查TPU设备
device = xm.xla_device()
print(f"📱 TPU设备: {device}")
print(f"🌍 TPU World Size: {xm.xrt_world_size()}")
# Create the simplified model
model = SimpleModel().to(device)
print(f"📊 Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
# Create test data
batch_size = 8 # small batch
seq_len = 100 # short sequence
x = torch.randn(batch_size, seq_len, 512, device=device)
print(f"📥 Input shape: {x.shape}")
# First forward pass - triggers XLA compilation
print(f"🔄 Starting first forward pass (XLA compilation)...")
start_time = time.time()
with torch.no_grad():
output = model(x)
compile_time = time.time() - start_time
print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}")
print(f"📤 输出形状: {output.shape}")
# 再次前向传播 - 使用编译后的图
print(f"🔄 第二次前向传播 (使用编译后的图)...")
start_time = time.time()
with torch.no_grad():
output2 = model(x)
execution_time = time.time() - start_time
print(f"⚡ 执行完成! 耗时: {execution_time:.4f}")
# 性能对比
speedup = compile_time / execution_time if execution_time > 0 else float('inf')
print(f"\n📈 性能分析:")
print(f" 编译时间: {compile_time:.2f}")
print(f" 执行时间: {execution_time:.4f}")
print(f" 加速比: {speedup:.1f}x")
if compile_time < 60: # 1分钟内编译完成
print("✅ XLA编译正常!")
return True
else:
print("❌ XLA编译过慢可能有问题")
return False
def test_training_step():
"""测试训练步骤"""
print("\n🎯 测试简化训练步骤...")
device = xm.xla_device()
model = SimpleModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Create training data
x = torch.randn(4, 50, 512, device=device)
labels = torch.randint(0, 41, (4, 50), device=device)
print(f"🔄 开始训练步骤 (包含反向传播)...")
start_time = time.time()
# 前向传播
outputs = model(x)
# 计算损失
loss = criterion(outputs.view(-1, 41), labels.view(-1))
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
step_time = time.time() - start_time
print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")
return step_time < 120 # 2分钟内完成
def main():
print("=" * 60)
print("🧪 XLA编译快速测试")
print("=" * 60)
try:
# Test 1: simple model compilation
compilation_ok = test_xla_compilation()
if compilation_ok:
# Test 2: training step
training_ok = test_training_step()
if training_ok:
print("\n✅ 所有测试通过! 可以尝试完整模型训练")
print("💡 建议:")
print(" 1. 确保有足够内存 (32GB+)")
print(" 2. 减小batch_size (比如从32改为16)")
print(" 3. 使用gradient_accumulation_steps补偿")
else:
print("\n⚠️ 训练步骤较慢,建议优化")
else:
print("\n❌ XLA编译有问题需要检查环境")
except Exception as e:
print(f"\n💥 测试失败: {e}")
print("💡 可能的问题:")
print(" - TPU资源不可用")
print(" - PyTorch XLA安装问题")
print(" - 内存不足")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,25 @@
import argparse
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
def main():
parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
parser.add_argument('--config_path', default='rnn_args.yaml',
help='Path to configuration file (default: rnn_args.yaml)')
args = parser.parse_args()
# Load configuration
config = OmegaConf.load(args.config_path)
# Initialize trainer
trainer = BrainToTextDecoder_Trainer(config)
# Start training
trainer.train()
print("Training completed successfully!")
print(f"Best validation PER: {trainer.best_val_PER:.5f}")
if __name__ == "__main__":
main()