tpu
This commit is contained in:
79
model_training_nnn_tpu/README.md
Normal file
79
model_training_nnn_tpu/README.md
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# TPU-Optimized Brain-to-Text Model Training
|
||||||
|
|
||||||
|
This directory contains TPU-optimized code for training the brain-to-text RNN model with advanced adversarial training architecture. The model is based on "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), enhanced with three-model adversarial training and comprehensive XLA optimizations for efficient TPU training.
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
- **Triple-Model Adversarial Architecture**: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
|
||||||
|
- **XLA/TPU Optimizations**: Comprehensive optimizations for fast compilation and efficient TPU utilization
|
||||||
|
- **Mixed Precision Training**: bfloat16 support with full dtype consistency
|
||||||
|
- **Distributed Training**: 8-core TPU support with Accelerate library integration
|
||||||
|
- **687M Parameters**: Large-scale model with patch processing and day-specific adaptations
|
||||||
|
|
||||||
|
For detailed technical documentation, see [TPU_MODEL_SUMMARY.md](TPU_MODEL_SUMMARY.md).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
|
||||||
|
|
||||||
|
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. See the main [README.md](../README.md) file for more details on the included datasets and the proper `data` directory structure.
|
||||||
|
|
||||||
|
## TPU Training
|
||||||
|
|
||||||
|
### Triple-Model Adversarial Architecture
|
||||||
|
This implementation features an advanced three-model adversarial training system:
|
||||||
|
- **NoiseModel**: 2-layer GRU that estimates noise in neural data
|
||||||
|
- **CleanSpeechModel**: 3-layer GRU that processes denoised signals for speech recognition
|
||||||
|
- **NoisySpeechModel**: 2-layer GRU that processes noise signals for adversarial training
|
||||||
|
|
||||||
|
The architecture uses residual connections and gradient reversal layers (GRL) to improve robustness. All models include day-specific input layers (512x512 linear with softsign activation), patch processing (14 timesteps), and are optimized for XLA compilation on TPU.
|
||||||
|
|
||||||
|
### Training Methods
|
||||||
|
|
||||||
|
#### Option 1: Direct Training
|
||||||
|
```bash
|
||||||
|
conda activate b2txt25
|
||||||
|
python train_model.py --config_path rnn_args.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Option 2: Launcher Script (Recommended)
|
||||||
|
```bash
|
||||||
|
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Option 3: Accelerate
|
||||||
|
```bash
|
||||||
|
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The model trains for 120,000 mini-batches with mixed precision (bfloat16) and distributed training across 8 TPU cores. Expected training time varies based on TPU type and configuration. All hyperparameters are specified in [`rnn_args.yaml`](rnn_args.yaml).
|
||||||
|
|
||||||
|
## Model Configuration
|
||||||
|
|
||||||
|
### Key Configuration Files
|
||||||
|
- **`rnn_args.yaml`**: Main training configuration with adversarial training settings
|
||||||
|
- **`accelerate_config_tpu.yaml`**: Accelerate library configuration for TPU
|
||||||
|
- **`launch_tpu_training.py`**: Convenient TPU training launcher
|
||||||
|
|
||||||
|
### Adversarial Training Settings
|
||||||
|
```yaml
|
||||||
|
adversarial:
|
||||||
|
enabled: true
|
||||||
|
grl_lambda: 0.5 # Gradient Reversal Layer strength
|
||||||
|
noisy_loss_weight: 0.2 # Weight for noisy branch CTC loss
|
||||||
|
noise_l2_weight: 0.0 # L2 regularization on noise output
|
||||||
|
warmup_steps: 0 # Steps before enabling adversarial training
|
||||||
|
```
|
||||||
|
|
||||||
|
### TPU-Specific Settings
|
||||||
|
```yaml
|
||||||
|
use_tpu: true
|
||||||
|
num_tpu_cores: 8
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
use_amp: true # bfloat16 mixed precision
|
||||||
|
batch_size: 32 # Per-core batch size
|
||||||
|
num_dataloader_workers: 0 # Required for TPU
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
Model evaluation using the trained TripleGRUDecoder requires the language model pipeline. Please refer to the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are currently being adapted for TPU compatibility.
|
183
model_training_nnn_tpu/TPU_MODEL_SUMMARY.md
Normal file
183
model_training_nnn_tpu/TPU_MODEL_SUMMARY.md
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# TPU优化的Brain-to-Text模型代码总结
|
||||||
|
|
||||||
|
## 项目概述
|
||||||
|
|
||||||
|
这个目录包含了专门为TPU训练优化的Brain-to-Text RNN模型代码,基于发表在《新英格兰医学杂志》(2024)的"An Accurate and Rapidly Calibrating Speech Neuroprosthesis"论文。该模型将大脑语音运动皮层的神经信号转换为文本,使用RNN模型和n-gram语言模型。
|
||||||
|
|
||||||
|
## 核心架构改进
|
||||||
|
|
||||||
|
### 三模型对抗训练架构 (TripleGRUDecoder)
|
||||||
|
|
||||||
|
替代原来的单一GRU模型,新架构包含三个协同工作的子模型:
|
||||||
|
|
||||||
|
1. **NoiseModel** (2层GRU)
|
||||||
|
- 估计神经数据中的噪声
|
||||||
|
- 输入维度:512 → 输出维度:与输入相同
|
||||||
|
- 作用:从原始信号中分离噪声成分
|
||||||
|
|
||||||
|
2. **CleanSpeechModel** (3层GRU + 分类层)
|
||||||
|
- 处理去噪后的信号进行语音识别
|
||||||
|
- 包含day-specific输入层
|
||||||
|
- 输出:41类音素的logits
|
||||||
|
|
||||||
|
3. **NoisySpeechModel** (2层GRU + 分类层)
|
||||||
|
- 直接处理噪声信号进行语音识别
|
||||||
|
- 用于对抗训练,提高NoiseModel的鲁棒性
|
||||||
|
- 输出:41类音素的logits
|
||||||
|
|
||||||
|
### 对抗训练机制
|
||||||
|
|
||||||
|
- **残差连接**: `denoised_input = x_processed - noise_output`
|
||||||
|
- **梯度反转层(GRL)**: 在训练时对噪声输出应用梯度反转
|
||||||
|
- **多目标损失**: 结合clean和noisy分支的CTC损失
|
||||||
|
|
||||||
|
## TPU/XLA优化特性
|
||||||
|
|
||||||
|
### 1. XLA友好的操作设计
|
||||||
|
|
||||||
|
**静态张量操作替代动态操作**:
|
||||||
|
```python
|
||||||
|
# 优化前 (XLA不友好):
|
||||||
|
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
|
||||||
|
|
||||||
|
# 优化后 (XLA友好):
|
||||||
|
all_day_weights = torch.stack(list(self.day_weights), dim=0)
|
||||||
|
day_weights = torch.index_select(all_day_weights, 0, day_idx)
|
||||||
|
```
|
||||||
|
|
||||||
|
**XLA原语操作**:
|
||||||
|
```python
|
||||||
|
# 使用batch matrix multiplication (bmm)替代einsum
|
||||||
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 混合精度训练的数据类型一致性
|
||||||
|
|
||||||
|
**全面的dtype一致性处理**:
|
||||||
|
- 基础操作中的dtype转换
|
||||||
|
- 补丁处理过程中的dtype保持
|
||||||
|
- 对抗训练残差连接的dtype匹配
|
||||||
|
- 梯度反转层的dtype处理
|
||||||
|
- 隐藏状态初始化的dtype一致性
|
||||||
|
|
||||||
|
### 3. 内存和编译优化
|
||||||
|
|
||||||
|
- **禁用autocast**: 在GRU操作中禁用自动混合精度以避免dtype冲突
|
||||||
|
- **静态形状**: 避免动态批次大小分配
|
||||||
|
- **元组返回**: 使用元组替代字典以获得更好的XLA编译性能
|
||||||
|
|
||||||
|
## 关键文件结构
|
||||||
|
|
||||||
|
### 核心训练文件
|
||||||
|
|
||||||
|
- **`rnn_model.py`**: 包含TripleGRUDecoder和三个子模型的完整实现,具有XLA优化
|
||||||
|
- **`rnn_trainer.py`**: TPU训练器,集成Accelerate库,支持分布式训练
|
||||||
|
- **`train_model.py`**: 简洁的训练启动脚本
|
||||||
|
- **`rnn_args.yaml`**: TPU训练配置文件
|
||||||
|
|
||||||
|
### TPU特定文件
|
||||||
|
|
||||||
|
- **`accelerate_config_tpu.yaml`**: Accelerate库的TPU配置
|
||||||
|
- **`launch_tpu_training.py`**: TPU训练的便捷启动脚本
|
||||||
|
- **`TPU_SETUP_GUIDE.md`**: TPU环境设置指南
|
||||||
|
|
||||||
|
### 辅助文件
|
||||||
|
|
||||||
|
- **`dataset.py`**: 数据集加载和批处理
|
||||||
|
- **`data_augmentations.py`**: 数据增强工具
|
||||||
|
- **`evaluate_model_helpers.py`**: 评估工具函数
|
||||||
|
|
||||||
|
## 训练配置亮点
|
||||||
|
|
||||||
|
### TPU特定设置
|
||||||
|
```yaml
|
||||||
|
# TPU分布式训练设置
|
||||||
|
use_tpu: true
|
||||||
|
num_tpu_cores: 8
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
use_amp: true # bfloat16混合精度
|
||||||
|
|
||||||
|
# 优化的批次配置
|
||||||
|
batch_size: 32 # 每个TPU核心的批次大小
|
||||||
|
num_dataloader_workers: 0 # TPU上设为0避免多进程问题
|
||||||
|
```
|
||||||
|
|
||||||
|
### 对抗训练配置
|
||||||
|
```yaml
|
||||||
|
adversarial:
|
||||||
|
enabled: true
|
||||||
|
grl_lambda: 0.5 # 梯度反转强度
|
||||||
|
noisy_loss_weight: 0.2 # 噪声分支损失权重
|
||||||
|
noise_l2_weight: 0.0 # 噪声输出L2正则化
|
||||||
|
warmup_steps: 0 # 对抗训练预热步数
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型规模
|
||||||
|
|
||||||
|
- **总参数**: ~687M个参数
|
||||||
|
- **神经输入**: 512特征 (每个电极2个特征 × 256个电极)
|
||||||
|
- **GRU隐藏单元**: 768个/层
|
||||||
|
- **输出类别**: 41个音素
|
||||||
|
- **补丁处理**: 14个时间步的补丁,步长为4
|
||||||
|
|
||||||
|
## 数据流
|
||||||
|
|
||||||
|
1. **输入**: 512维神经特征,20ms分辨率
|
||||||
|
2. **Day-specific变换**: 每日特定的线性变换和softsign激活
|
||||||
|
3. **补丁处理**: 将14个时间步连接为更大的输入向量
|
||||||
|
4. **三模型处理**:
|
||||||
|
- NoiseModel估计噪声
|
||||||
|
- CleanSpeechModel处理去噪信号
|
||||||
|
- NoisySpeechModel处理噪声信号(仅训练时)
|
||||||
|
5. **输出**: CTC兼容的音素logits
|
||||||
|
|
||||||
|
## 训练流程
|
||||||
|
|
||||||
|
### 推理模式 (`mode='inference'`):
|
||||||
|
- 只使用NoiseModel + CleanSpeechModel
|
||||||
|
- 计算: `clean_logits = CleanSpeechModel(x - NoiseModel(x))`
|
||||||
|
|
||||||
|
### 完整模式 (`mode='full'`, 训练时):
|
||||||
|
- 使用所有三个模型
|
||||||
|
- 对抗训练与梯度反转
|
||||||
|
- 多目标CTC损失
|
||||||
|
|
||||||
|
## 性能特点
|
||||||
|
|
||||||
|
- **编译优化**: XLA优化实现更快的TPU编译
|
||||||
|
- **内存效率**: bfloat16混合精度减少内存使用
|
||||||
|
- **分布式训练**: 支持8核心TPU并行训练
|
||||||
|
- **数据增强**: 高斯平滑、白噪声、时间抖动等
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 基本训练
|
||||||
|
```bash
|
||||||
|
python train_model.py --config_path rnn_args.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用启动脚本
|
||||||
|
```bash
|
||||||
|
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用Accelerate
|
||||||
|
```bash
|
||||||
|
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与原始模型的兼容性
|
||||||
|
|
||||||
|
- 保持相同的数学运算和模型架构
|
||||||
|
- 保留所有原始接口
|
||||||
|
- 支持'inference'和'full'两种模式
|
||||||
|
- 向后兼容现有训练脚本
|
||||||
|
|
||||||
|
## 技术创新点
|
||||||
|
|
||||||
|
1. **三模型对抗架构**: 创新的噪声估计和去噪方法
|
||||||
|
2. **XLA优化**: 全面的TPU编译优化
|
||||||
|
3. **混合精度一致性**: 解决了复杂对抗训练中的dtype冲突
|
||||||
|
4. **分布式训练集成**: 无缝的多核心TPU支持
|
||||||
|
|
||||||
|
这个TPU优化版本保持了原始模型的准确性,同时显著提高了训练效率和可扩展性,特别适合大规模神经解码任务的训练。
|
204
model_training_nnn_tpu/TPU_SETUP_GUIDE.md
Normal file
204
model_training_nnn_tpu/TPU_SETUP_GUIDE.md
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# TPU Training Setup Guide for Brain-to-Text RNN
|
||||||
|
|
||||||
|
This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
### 1. Install PyTorch XLA for TPU Support
|
||||||
|
```bash
|
||||||
|
# Install PyTorch XLA (adjust version as needed)
|
||||||
|
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
|
|
||||||
|
# Or for specific PyTorch version:
|
||||||
|
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Install Accelerate Library
|
||||||
|
```bash
|
||||||
|
pip install accelerate
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Verify TPU Access
|
||||||
|
```bash
|
||||||
|
# Check if TPU is available
|
||||||
|
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Setup
|
||||||
|
|
||||||
|
### 1. Enable TPU in Configuration File
|
||||||
|
|
||||||
|
Update your `rnn_args.yaml` file with TPU settings:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# TPU and distributed training settings
|
||||||
|
use_tpu: true # Enable TPU training
|
||||||
|
num_tpu_cores: 8 # Number of TPU cores (8 for v3-8 or v4-8)
|
||||||
|
gradient_accumulation_steps: 1 # Gradient accumulation for large effective batch size
|
||||||
|
dataloader_num_workers: 0 # Must be 0 for TPU to avoid multiprocessing issues
|
||||||
|
use_amp: true # Enable mixed precision (bfloat16)
|
||||||
|
|
||||||
|
# Adjust batch size for multi-core TPU
|
||||||
|
dataset:
|
||||||
|
batch_size: 8 # Per-core batch size (total = 8 cores × 8 = 64)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. TPU-Optimized Hyperparameters
|
||||||
|
|
||||||
|
Recommended adjustments for TPU training:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Learning rate scaling for distributed training
|
||||||
|
lr_max: 0.005 # May need to scale with number of cores
|
||||||
|
lr_max_day: 0.005
|
||||||
|
|
||||||
|
# Batch size considerations
|
||||||
|
dataset:
|
||||||
|
batch_size: 8 # Per-core batch size
|
||||||
|
days_per_batch: 4 # Keep consistent across cores
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Launch Options
|
||||||
|
|
||||||
|
### Method 1: Using the TPU Launch Script (Recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic TPU training with 8 cores
|
||||||
|
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||||
|
|
||||||
|
# Check TPU environment only
|
||||||
|
python launch_tpu_training.py --check_only
|
||||||
|
|
||||||
|
# Custom configuration file
|
||||||
|
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Method 2: Direct Accelerate Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Configure accelerate (one-time setup)
|
||||||
|
accelerate config
|
||||||
|
|
||||||
|
# Or use provided TPU config
|
||||||
|
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
|
||||||
|
|
||||||
|
# Launch training
|
||||||
|
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Method 3: Manual XLA Launch (Advanced)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Set TPU environment variables
|
||||||
|
export TPU_CORES=8
|
||||||
|
export XLA_USE_BF16=1
|
||||||
|
|
||||||
|
# Launch with PyTorch XLA
|
||||||
|
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key TPU Features Implemented
|
||||||
|
|
||||||
|
### 1. Distributed Training Support
|
||||||
|
- Automatic model parallelization across 8 TPU cores
|
||||||
|
- Synchronized gradient updates across all cores
|
||||||
|
- Proper checkpoint saving/loading for distributed training
|
||||||
|
|
||||||
|
### 2. Mixed Precision Training
|
||||||
|
- Automatic bfloat16 precision for TPU optimization
|
||||||
|
- Faster training with maintained numerical stability
|
||||||
|
- Reduced memory usage
|
||||||
|
|
||||||
|
### 3. TPU-Optimized Data Loading
|
||||||
|
- Single-threaded data loading (num_workers=0) for TPU compatibility
|
||||||
|
- Automatic data distribution across TPU cores
|
||||||
|
- Efficient batch processing
|
||||||
|
|
||||||
|
### 4. Inference Support
|
||||||
|
- TPU-compatible inference methods added to trainer class
|
||||||
|
- `inference()` and `inference_batch()` methods for production use
|
||||||
|
- Automatic mixed precision during inference
|
||||||
|
|
||||||
|
## Performance Optimization Tips
|
||||||
|
|
||||||
|
### 1. Batch Size Tuning
|
||||||
|
- Start with total batch size = 64 (8 cores × 8 per core)
|
||||||
|
- Increase gradually if memory allows
|
||||||
|
- Monitor TPU utilization with `top` command
|
||||||
|
|
||||||
|
### 2. Gradient Accumulation
|
||||||
|
- Use `gradient_accumulation_steps` to simulate larger batch sizes
|
||||||
|
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
|
||||||
|
|
||||||
|
### 3. Learning Rate Scaling
|
||||||
|
- Consider scaling learning rate with number of cores
|
||||||
|
- Linear scaling: `lr_new = lr_base × num_cores`
|
||||||
|
- May need warmup adjustment for large batch training
|
||||||
|
|
||||||
|
### 4. Memory Management
|
||||||
|
- TPU v3-8: 128GB HBM memory total
|
||||||
|
- TPU v4-8: 512GB HBM memory total
|
||||||
|
- Monitor memory usage to avoid OOM errors
|
||||||
|
|
||||||
|
## Monitoring and Debugging
|
||||||
|
|
||||||
|
### 1. TPU Utilization
|
||||||
|
```bash
|
||||||
|
# Monitor TPU usage
|
||||||
|
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Training Logs
|
||||||
|
- Training logs include device information and core count
|
||||||
|
- Monitor validation metrics across all cores
|
||||||
|
- Check for synchronization issues in distributed training
|
||||||
|
|
||||||
|
### 3. Common Issues and Solutions
|
||||||
|
|
||||||
|
**Issue**: "No TPU devices found"
|
||||||
|
- **Solution**: Verify TPU runtime is started and accessible
|
||||||
|
|
||||||
|
**Issue**: "DataLoader workers > 0 causes hangs"
|
||||||
|
- **Solution**: Set `dataloader_num_workers: 0` in config
|
||||||
|
|
||||||
|
**Issue**: "Mixed precision errors"
|
||||||
|
- **Solution**: Ensure `use_amp: true` and PyTorch XLA supports bfloat16
|
||||||
|
|
||||||
|
**Issue**: "Gradient synchronization timeouts"
|
||||||
|
- **Solution**: Check network connectivity between TPU cores
|
||||||
|
|
||||||
|
## Example Training Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Complete TPU training example
|
||||||
|
cd model_training_nnn
|
||||||
|
|
||||||
|
# 1. Update config for TPU
|
||||||
|
vim rnn_args.yaml # Set use_tpu: true, num_tpu_cores: 8
|
||||||
|
|
||||||
|
# 2. Launch TPU training
|
||||||
|
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||||
|
|
||||||
|
# 3. Monitor training progress
|
||||||
|
tail -f trained_models/baseline_rnn/training_log
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
### Required TPU Settings
|
||||||
|
```yaml
|
||||||
|
use_tpu: true
|
||||||
|
num_tpu_cores: 8
|
||||||
|
dataloader_num_workers: 0
|
||||||
|
use_amp: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optional TPU Optimizations
|
||||||
|
```yaml
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
dataset:
|
||||||
|
batch_size: 8 # Per-core batch size
|
||||||
|
mixed_precision: bf16
|
||||||
|
```
|
||||||
|
|
||||||
|
This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.
|
26
model_training_nnn_tpu/accelerate_config_tpu.yaml
Normal file
26
model_training_nnn_tpu/accelerate_config_tpu.yaml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Accelerate Configuration for TPU Training
|
||||||
|
# This file configures Accelerate library for 8-core TPU training
|
||||||
|
# with mixed precision (bfloat16) support
|
||||||
|
|
||||||
|
compute_environment: TPU
|
||||||
|
distributed_type: TPU
|
||||||
|
tpu_name: null # Will use default TPU
|
||||||
|
tpu_zone: null # Will use default zone
|
||||||
|
|
||||||
|
# Mixed precision settings (use bfloat16 for TPU)
|
||||||
|
mixed_precision: bf16
|
||||||
|
|
||||||
|
# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
|
||||||
|
num_processes: 8
|
||||||
|
|
||||||
|
# Enable TPU debugging (set to false for production)
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
|
||||||
|
# Logging settings
|
||||||
|
main_process_port: null
|
||||||
|
machine_rank: 0
|
||||||
|
num_machines: 1
|
||||||
|
|
||||||
|
# Enable automatic optimization
|
||||||
|
use_cpu: false
|
148
model_training_nnn_tpu/check_xla_threads.py
Normal file
148
model_training_nnn_tpu/check_xla_threads.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
XLA Multi-threading Diagnostic Script
|
||||||
|
检查XLA编译是否正确使用多CPU核心
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
def set_xla_environment():
|
||||||
|
"""设置XLA环境变量"""
|
||||||
|
cpu_count = os.cpu_count()
|
||||||
|
|
||||||
|
# 设置XLA环境变量
|
||||||
|
os.environ['XLA_FLAGS'] = (
|
||||||
|
'--xla_cpu_multi_thread_eigen=true '
|
||||||
|
'--xla_cpu_enable_fast_math=true '
|
||||||
|
f'--xla_force_host_platform_device_count={cpu_count}'
|
||||||
|
)
|
||||||
|
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
|
||||||
|
|
||||||
|
print(f"🔧 设置XLA环境变量:")
|
||||||
|
print(f" CPU核心数: {cpu_count}")
|
||||||
|
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
|
||||||
|
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
def monitor_cpu_usage(duration=30, interval=1):
|
||||||
|
"""监控CPU使用情况"""
|
||||||
|
print(f"🔍 监控CPU使用情况 {duration}秒...")
|
||||||
|
|
||||||
|
cpu_usage_data = []
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while time.time() - start_time < duration:
|
||||||
|
# 获取每个CPU核心的使用率
|
||||||
|
cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
|
||||||
|
cpu_usage_data.append(cpu_percent_per_core)
|
||||||
|
|
||||||
|
# 实时显示
|
||||||
|
active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
|
||||||
|
print(f"活跃核心数: {active_cores}/{len(cpu_percent_per_core)}, "
|
||||||
|
f"平均使用率: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
|
||||||
|
end='\r')
|
||||||
|
|
||||||
|
print() # 换行
|
||||||
|
|
||||||
|
# 分析结果
|
||||||
|
if cpu_usage_data:
|
||||||
|
avg_usage_per_core = [
|
||||||
|
sum(core_data) / len(cpu_usage_data)
|
||||||
|
for core_data in zip(*cpu_usage_data)
|
||||||
|
]
|
||||||
|
|
||||||
|
active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
|
||||||
|
max_usage = max(avg_usage_per_core)
|
||||||
|
|
||||||
|
print(f"📊 CPU使用分析:")
|
||||||
|
print(f" 活跃的CPU核心: {active_cores}/{len(avg_usage_per_core)}")
|
||||||
|
print(f" 最高平均使用率: {max_usage:.1f}%")
|
||||||
|
|
||||||
|
for i, usage in enumerate(avg_usage_per_core):
|
||||||
|
status = "🟢" if usage > 10 else "🔴" if usage > 5 else "⚫"
|
||||||
|
print(f" CPU核心 {i}: {usage:.1f}% {status}")
|
||||||
|
|
||||||
|
return active_cores > 1
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_xla_compilation():
|
||||||
|
"""测试XLA编译"""
|
||||||
|
print(f"🚀 开始XLA编译测试...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
print(f"✅ PyTorch XLA导入成功")
|
||||||
|
print(f" XLA设备: {xm.xla_device()}")
|
||||||
|
print(f" XLA world size: {xm.xrt_world_size()}")
|
||||||
|
|
||||||
|
# 创建一个简单的计算图进行编译
|
||||||
|
device = xm.xla_device()
|
||||||
|
|
||||||
|
print(f"🔄 创建测试计算图...")
|
||||||
|
x = torch.randn(100, 100, device=device)
|
||||||
|
y = torch.randn(100, 100, device=device)
|
||||||
|
|
||||||
|
print(f"🔄 执行矩阵运算 (将触发XLA编译)...")
|
||||||
|
|
||||||
|
# 启动CPU监控
|
||||||
|
monitor_thread = threading.Thread(
|
||||||
|
target=lambda: monitor_cpu_usage(20, 0.5),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
monitor_thread.start()
|
||||||
|
|
||||||
|
# 执行计算,触发编译
|
||||||
|
for i in range(10):
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
z = torch.relu(z)
|
||||||
|
z = torch.matmul(z, x.T)
|
||||||
|
if i == 0:
|
||||||
|
print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
|
||||||
|
elif i == 5:
|
||||||
|
print(f"🔄 第6次计算完成...")
|
||||||
|
|
||||||
|
# 等待监控完成
|
||||||
|
monitor_thread.join(timeout=25)
|
||||||
|
|
||||||
|
print(f"✅ XLA测试完成")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ PyTorch XLA导入失败: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ XLA测试失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=" * 60)
|
||||||
|
print("🧪 XLA多线程编译诊断工具")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. 设置环境
|
||||||
|
set_xla_environment()
|
||||||
|
|
||||||
|
# 2. 测试XLA编译并监控CPU
|
||||||
|
success = test_xla_compilation()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
if success:
|
||||||
|
print("✅ 诊断完成")
|
||||||
|
print("💡 如果看到多个CPU核心被激活,说明XLA多线程工作正常")
|
||||||
|
print("💡 如果只有1-2个核心活跃,可能需要其他优化方法")
|
||||||
|
else:
|
||||||
|
print("❌ 诊断失败")
|
||||||
|
print("💡 请检查PyTorch XLA安装和TPU环境")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
37
model_training_nnn_tpu/data_augmentations.py
Normal file
37
model_training_nnn_tpu/data_augmentations.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from scipy.ndimage import gaussian_filter1d
|
||||||
|
|
||||||
|
def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
|
||||||
|
"""
|
||||||
|
Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis.
|
||||||
|
Args:
|
||||||
|
inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N.
|
||||||
|
Assumed to already be on the correct device (e.g., GPU).
|
||||||
|
kernelSD (float): Standard deviation of the Gaussian smoothing kernel.
|
||||||
|
padding (str): Padding mode, either 'same' or 'valid'.
|
||||||
|
device (str): Device to use for computation (e.g., 'cuda' or 'cpu').
|
||||||
|
Returns:
|
||||||
|
smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N.
|
||||||
|
"""
|
||||||
|
# Get Gaussian kernel
|
||||||
|
inp = np.zeros(smooth_kernel_size, dtype=np.float32)
|
||||||
|
inp[smooth_kernel_size // 2] = 1
|
||||||
|
gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
|
||||||
|
validIdx = np.argwhere(gaussKernel > 0.01)
|
||||||
|
gaussKernel = gaussKernel[validIdx]
|
||||||
|
gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))
|
||||||
|
|
||||||
|
# Convert to tensor
|
||||||
|
gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
|
||||||
|
gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]
|
||||||
|
|
||||||
|
# Prepare convolution
|
||||||
|
B, T, C = inputs.shape
|
||||||
|
inputs = inputs.permute(0, 2, 1) # [B, C, T]
|
||||||
|
gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]
|
||||||
|
|
||||||
|
# Perform convolution
|
||||||
|
smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
|
||||||
|
return smoothed.permute(0, 2, 1) # [B, T, C]
|
336
model_training_nnn_tpu/dataset.py
Normal file
336
model_training_nnn_tpu/dataset.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
import math
|
||||||
|
|
||||||
|
class BrainToTextDataset(Dataset):
|
||||||
|
'''
|
||||||
|
Dataset for brain-to-text data
|
||||||
|
|
||||||
|
Returns an entire batch of data instead of a single example
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
trial_indicies,
|
||||||
|
n_batches,
|
||||||
|
split = 'train',
|
||||||
|
batch_size = 64,
|
||||||
|
days_per_batch = 1,
|
||||||
|
random_seed = -1,
|
||||||
|
must_include_days = None,
|
||||||
|
feature_subset = None
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
|
||||||
|
n_batches: (int) - number of random training batches to create
|
||||||
|
split: (string) - string specifying if this is a train or test dataset
|
||||||
|
batch_size: (int) - number of examples to include in batch returned from __getitem_()
|
||||||
|
days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
|
||||||
|
to individual day layers in the GRU are not excesively noisy. Validation data will always have 1 day per batch
|
||||||
|
random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
|
||||||
|
must_include_days ([int]) - list of days that must be included in every batch
|
||||||
|
feature_subset ([int]) - list of neural feature indicies that should be the only features included in the neural data
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
if random_seed != -1:
|
||||||
|
np.random.seed(random_seed)
|
||||||
|
torch.manual_seed(random_seed)
|
||||||
|
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
# Ensure the split is valid
|
||||||
|
if self.split not in ['train', 'test']:
|
||||||
|
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
|
||||||
|
|
||||||
|
self.days_per_batch = days_per_batch
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
self.n_batches = n_batches
|
||||||
|
|
||||||
|
self.days = {}
|
||||||
|
self.n_trials = 0
|
||||||
|
self.trial_indicies = trial_indicies
|
||||||
|
self.n_days = len(trial_indicies.keys())
|
||||||
|
|
||||||
|
self.feature_subset = feature_subset
|
||||||
|
|
||||||
|
# Calculate total number of trials in the dataset
|
||||||
|
for d in trial_indicies:
|
||||||
|
self.n_trials += len(trial_indicies[d]['trials'])
|
||||||
|
|
||||||
|
if must_include_days is not None and len(must_include_days) > days_per_batch:
|
||||||
|
raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}')
|
||||||
|
|
||||||
|
if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
|
||||||
|
raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} and but only {self.n_days} in the dataset')
|
||||||
|
|
||||||
|
if must_include_days is not None:
|
||||||
|
# Map must_include_days to correct indicies if they are negative
|
||||||
|
for i, d in enumerate(must_include_days):
|
||||||
|
if d < 0:
|
||||||
|
must_include_days[i] = self.n_days + d
|
||||||
|
|
||||||
|
self.must_include_days = must_include_days
|
||||||
|
|
||||||
|
# Ensure that the days_per_batch is not greater than the number of days in the dataset. Raise error
|
||||||
|
if self.split == 'train' and self.days_per_batch > self.n_days:
|
||||||
|
raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
|
||||||
|
|
||||||
|
|
||||||
|
if self.split == 'train':
|
||||||
|
self.batch_index = self.create_batch_index_train()
|
||||||
|
else:
|
||||||
|
self.batch_index = self.create_batch_index_test()
|
||||||
|
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
'''
|
||||||
|
How many batches are in this dataset.
|
||||||
|
Because training data is sampled randomly, there is no fixed dataset length,
|
||||||
|
however this method is required for DataLoader to work
|
||||||
|
'''
|
||||||
|
return self.n_batches if self.n_batches is not None else 0
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
'''
|
||||||
|
Gets an entire batch of data from the dataset, not just a single item
|
||||||
|
'''
|
||||||
|
batch = {
|
||||||
|
'input_features' : [],
|
||||||
|
'seq_class_ids' : [],
|
||||||
|
'n_time_steps' : [],
|
||||||
|
'phone_seq_lens' : [],
|
||||||
|
'day_indicies' : [],
|
||||||
|
'transcriptions' : [],
|
||||||
|
'block_nums' : [],
|
||||||
|
'trial_nums' : [],
|
||||||
|
}
|
||||||
|
|
||||||
|
index = self.batch_index[idx]
|
||||||
|
|
||||||
|
# Iterate through each day in the index
|
||||||
|
for d in index.keys():
|
||||||
|
|
||||||
|
# Open the hdf5 file for that day
|
||||||
|
with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
|
||||||
|
|
||||||
|
# For each trial in the selected trials in that day
|
||||||
|
for t in index[d]:
|
||||||
|
|
||||||
|
try:
|
||||||
|
g = f[f'trial_{t:04d}']
|
||||||
|
|
||||||
|
# Remove features is neccessary
|
||||||
|
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
|
||||||
|
if self.feature_subset:
|
||||||
|
input_features = input_features[:,self.feature_subset]
|
||||||
|
|
||||||
|
batch['input_features'].append(input_features)
|
||||||
|
|
||||||
|
batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
|
||||||
|
batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
|
||||||
|
batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
|
||||||
|
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
|
||||||
|
batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
|
||||||
|
batch['block_nums'].append(g.attrs['block_num'])
|
||||||
|
batch['trial_nums'].append(g.attrs['trial_num'])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
|
||||||
|
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
|
||||||
|
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
|
||||||
|
|
||||||
|
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
|
||||||
|
batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
|
||||||
|
batch['day_indicies'] = torch.tensor(batch['day_indicies'])
|
||||||
|
batch['transcriptions'] = torch.stack(batch['transcriptions'])
|
||||||
|
batch['block_nums'] = torch.tensor(batch['block_nums'])
|
||||||
|
batch['trial_nums'] = torch.tensor(batch['trial_nums'])
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def create_batch_index_train(self):
|
||||||
|
'''
|
||||||
|
Create an index that maps a batch_number to batch_size number of trials
|
||||||
|
|
||||||
|
Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
|
||||||
|
(or as even as possible if batch_size is not divisible by days_per_batch)
|
||||||
|
'''
|
||||||
|
|
||||||
|
batch_index = {}
|
||||||
|
|
||||||
|
# Precompute the days that are not in must_include_days
|
||||||
|
if self.must_include_days is not None:
|
||||||
|
non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
|
||||||
|
|
||||||
|
for batch_idx in range(self.n_batches):
|
||||||
|
batch = {}
|
||||||
|
|
||||||
|
# Which days will be used for this batch. Picked randomly without replacement
|
||||||
|
# TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
|
||||||
|
|
||||||
|
# If must_include_days is not empty, we will use those days and then randomly sample the rest
|
||||||
|
if self.must_include_days is not None and len(self.must_include_days) > 0:
|
||||||
|
|
||||||
|
days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
|
||||||
|
|
||||||
|
# Otherwise we will select random days without replacement
|
||||||
|
else:
|
||||||
|
days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
|
||||||
|
|
||||||
|
# How many trials will be sampled from each day
|
||||||
|
num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
|
||||||
|
|
||||||
|
for d in days:
|
||||||
|
|
||||||
|
# Trials are sampled with replacement, so if a day has less than (self.batch_size / days_per_batch trials) trials, it won't be a problem
|
||||||
|
trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
|
||||||
|
batch[d] = trial_idxs
|
||||||
|
|
||||||
|
# Remove extra trials
|
||||||
|
extra_trials = (num_trials * len(days)) - self.batch_size
|
||||||
|
|
||||||
|
# While we still have extra trials, remove the last trial from a random day
|
||||||
|
while extra_trials > 0:
|
||||||
|
d = np.random.choice(days)
|
||||||
|
batch[d] = batch[d][:-1]
|
||||||
|
extra_trials -= 1
|
||||||
|
|
||||||
|
batch_index[batch_idx] = batch
|
||||||
|
|
||||||
|
return batch_index
|
||||||
|
|
||||||
|
def create_batch_index_test(self):
|
||||||
|
'''
|
||||||
|
Create an index that is all validation/testing data in batches of up to self.batch_size
|
||||||
|
|
||||||
|
If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size
|
||||||
|
|
||||||
|
This index will ensures that every trial in the validation set is seen once and only once
|
||||||
|
'''
|
||||||
|
batch_index = {}
|
||||||
|
batch_idx = 0
|
||||||
|
|
||||||
|
for d in self.trial_indicies.keys():
|
||||||
|
|
||||||
|
# Calculate how many batches we need for this day
|
||||||
|
num_trials = len(self.trial_indicies[d]['trials'])
|
||||||
|
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
|
||||||
|
|
||||||
|
# Create batches for this day
|
||||||
|
for i in range(num_batches):
|
||||||
|
start_idx = i * self.batch_size
|
||||||
|
end_idx = min((i + 1) * self.batch_size, num_trials)
|
||||||
|
|
||||||
|
# Get the trial indices for this batch
|
||||||
|
batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]
|
||||||
|
|
||||||
|
# Add to batch_index
|
||||||
|
batch_index[batch_idx] = {d : batch_trials}
|
||||||
|
batch_idx += 1
|
||||||
|
|
||||||
|
return batch_index
|
||||||
|
|
||||||
|
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
|
||||||
|
'''
|
||||||
|
Split data from file_paths into train and test splits
|
||||||
|
Returns two dictionaries that detail which trials in each day will be a part of that split:
|
||||||
|
Example:
|
||||||
|
{
|
||||||
|
0: trials[1,2,3], session_path: 'path'
|
||||||
|
1: trials[2,5,6], session_path: 'path'
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_paths (list): List of file paths to the hdf5 files containing the data
|
||||||
|
test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
|
||||||
|
seed (int): Seed for reproducibility. If set to -1, the split will be random
|
||||||
|
bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
|
||||||
|
{
|
||||||
|
'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
|
||||||
|
'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
|
||||||
|
...
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# Set seed for reporoducibility
|
||||||
|
if seed != -1:
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
# Get trials in each day
|
||||||
|
trials_per_day = {}
|
||||||
|
for i, path in enumerate(file_paths):
|
||||||
|
# Handle both Windows and Unix path separators
|
||||||
|
path_parts = path.replace('\\', '/').split('/')
|
||||||
|
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
|
||||||
|
|
||||||
|
good_trial_indices = []
|
||||||
|
|
||||||
|
if os.path.exists(path):
|
||||||
|
with h5py.File(path, 'r') as f:
|
||||||
|
num_trials = len(list(f.keys()))
|
||||||
|
for t in range(num_trials):
|
||||||
|
key = f'trial_{t:04d}'
|
||||||
|
|
||||||
|
block_num = f[key].attrs['block_num']
|
||||||
|
trial_num = f[key].attrs['trial_num']
|
||||||
|
|
||||||
|
if (
|
||||||
|
bad_trials_dict is not None
|
||||||
|
and session in bad_trials_dict
|
||||||
|
and str(block_num) in bad_trials_dict[session]
|
||||||
|
and trial_num in bad_trials_dict[session][str(block_num)]
|
||||||
|
):
|
||||||
|
# print(f'Bad trial: {session}_{block_num}_{trial_num}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
good_trial_indices.append(t)
|
||||||
|
|
||||||
|
trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}
|
||||||
|
|
||||||
|
# Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
|
||||||
|
train_trials = {}
|
||||||
|
test_trials = {}
|
||||||
|
|
||||||
|
for day in trials_per_day.keys():
|
||||||
|
|
||||||
|
num_trials = trials_per_day[day]['num_trials']
|
||||||
|
|
||||||
|
# Generate all trial indices for this day (assuming 0-indexed)
|
||||||
|
all_trial_indices = trials_per_day[day]['trial_indices']
|
||||||
|
|
||||||
|
# If test_percentage is 0 or 1, we can just assign all trials to either train or test
|
||||||
|
if test_percentage == 0:
|
||||||
|
train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif test_percentage == 1:
|
||||||
|
train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Calculate how many trials to use for testing
|
||||||
|
num_test = max(1, int(num_trials * test_percentage))
|
||||||
|
|
||||||
|
# Randomly select indices for testing
|
||||||
|
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
|
||||||
|
|
||||||
|
# Remaining indices go to training
|
||||||
|
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
|
||||||
|
|
||||||
|
# Store the split indices
|
||||||
|
train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}
|
||||||
|
|
||||||
|
return train_trials, test_trials
|
304
model_training_nnn_tpu/evaluate_model.py
Normal file
304
model_training_nnn_tpu/evaluate_model.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import redis
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
import editdistance
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from rnn_model import GRUDecoder
|
||||||
|
from evaluate_model_helpers import *
|
||||||
|
|
||||||
|
# argument parser for command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
|
||||||
|
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
|
||||||
|
help='Path to the pretrained model directory (relative to the current working directory).')
|
||||||
|
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
|
||||||
|
help='Path to the dataset directory (relative to the current working directory).')
|
||||||
|
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
|
||||||
|
help='Evaluation type: "val" for validation set, "test" for test set. '
|
||||||
|
'If "test", ground truth is not available.')
|
||||||
|
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
|
||||||
|
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
|
||||||
|
parser.add_argument('--gpu_number', type=int, default=-1,
|
||||||
|
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# paths to model and data directories
|
||||||
|
# Note: these paths are relative to the current working directory
|
||||||
|
model_path = args.model_path
|
||||||
|
data_dir = args.data_dir
|
||||||
|
|
||||||
|
# define evaluation type
|
||||||
|
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
|
||||||
|
|
||||||
|
# load csv file
|
||||||
|
b2txt_csv_df = pd.read_csv(args.csv_path)
|
||||||
|
|
||||||
|
# load model args
|
||||||
|
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
|
||||||
|
|
||||||
|
# set up gpu device
|
||||||
|
gpu_number = args.gpu_number
|
||||||
|
if torch.cuda.is_available() and gpu_number >= 0:
|
||||||
|
if gpu_number >= torch.cuda.device_count():
|
||||||
|
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
|
||||||
|
device = f'cuda:{gpu_number}'
|
||||||
|
device = torch.device(device)
|
||||||
|
print(f'Using {device} for model inference.')
|
||||||
|
else:
|
||||||
|
if gpu_number >= 0:
|
||||||
|
print(f'GPU number {gpu_number} requested but not available.')
|
||||||
|
print('Using CPU for model inference.')
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
|
# define model
|
||||||
|
model = GRUDecoder(
|
||||||
|
neural_dim = model_args['model']['n_input_features'],
|
||||||
|
n_units = model_args['model']['n_units'],
|
||||||
|
n_days = len(model_args['dataset']['sessions']),
|
||||||
|
n_classes = model_args['dataset']['n_classes'],
|
||||||
|
rnn_dropout = model_args['model']['rnn_dropout'],
|
||||||
|
input_dropout = model_args['model']['input_network']['input_layer_dropout'],
|
||||||
|
n_layers = model_args['model']['n_layers'],
|
||||||
|
patch_size = model_args['model']['patch_size'],
|
||||||
|
patch_stride = model_args['model']['patch_stride'],
|
||||||
|
)
|
||||||
|
|
||||||
|
# load model weights
|
||||||
|
checkpoint = torch.load(
|
||||||
|
os.path.join(model_path, 'checkpoint/best_checkpoint'),
|
||||||
|
map_location=device,
|
||||||
|
weights_only=False,
|
||||||
|
)
|
||||||
|
# rename keys to not start with "module." (happens if model was saved with DataParallel)
|
||||||
|
for key in list(checkpoint['model_state_dict'].keys()):
|
||||||
|
checkpoint['model_state_dict'][key.replace("module.", "")] = checkpoint['model_state_dict'].pop(key)
|
||||||
|
checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = checkpoint['model_state_dict'].pop(key)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
# add model to device
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# set model to eval mode
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# load data for each session
|
||||||
|
test_data = {}
|
||||||
|
total_test_trials = 0
|
||||||
|
for session in model_args['dataset']['sessions']:
|
||||||
|
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
|
||||||
|
if f'data_{eval_type}.hdf5' in files:
|
||||||
|
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
|
||||||
|
|
||||||
|
data = load_h5py_file(eval_file, b2txt_csv_df)
|
||||||
|
test_data[session] = data
|
||||||
|
|
||||||
|
total_test_trials += len(test_data[session]["neural_features"])
|
||||||
|
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
|
||||||
|
print(f'Total number of {eval_type} trials: {total_test_trials}')
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# put neural data through the pretrained model to get phoneme predictions (logits)
|
||||||
|
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
|
||||||
|
for session, data in test_data.items():
|
||||||
|
|
||||||
|
data['logits'] = []
|
||||||
|
data['pred_seq'] = []
|
||||||
|
input_layer = model_args['dataset']['sessions'].index(session)
|
||||||
|
|
||||||
|
for trial in range(len(data['neural_features'])):
|
||||||
|
# get neural input for the trial
|
||||||
|
neural_input = data['neural_features'][trial]
|
||||||
|
|
||||||
|
# add batch dimension
|
||||||
|
neural_input = np.expand_dims(neural_input, axis=0)
|
||||||
|
|
||||||
|
# convert to torch tensor
|
||||||
|
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# run decoding step
|
||||||
|
logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
|
||||||
|
data['logits'].append(logits)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
|
||||||
|
# convert logits to phoneme sequences and print them out
|
||||||
|
for session, data in test_data.items():
|
||||||
|
data['pred_seq'] = []
|
||||||
|
for trial in range(len(data['logits'])):
|
||||||
|
logits = data['logits'][trial][0]
|
||||||
|
pred_seq = np.argmax(logits, axis=-1)
|
||||||
|
# remove blanks (0)
|
||||||
|
pred_seq = [int(p) for p in pred_seq if p != 0]
|
||||||
|
# remove consecutive duplicates
|
||||||
|
pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
|
||||||
|
# convert to phonemes
|
||||||
|
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
|
||||||
|
# add to data
|
||||||
|
data['pred_seq'].append(pred_seq)
|
||||||
|
|
||||||
|
# print out the predicted sequences
|
||||||
|
block_num = data['block_num'][trial]
|
||||||
|
trial_num = data['trial_num'][trial]
|
||||||
|
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
|
||||||
|
if eval_type == 'val':
|
||||||
|
sentence_label = data['sentence_label'][trial]
|
||||||
|
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
|
||||||
|
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
|
||||||
|
|
||||||
|
print(f'Sentence label: {sentence_label}')
|
||||||
|
print(f'True sequence: {" ".join(true_seq)}')
|
||||||
|
print(f'Predicted Sequence: {" ".join(pred_seq)}')
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# language model inference via redis
|
||||||
|
# make sure that the standalone language model is running on the localhost redis ip
|
||||||
|
# see README.md for instructions on how to run the language model
|
||||||
|
|
||||||
|
def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
|
||||||
|
"""Connect to Redis with retry logic"""
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
|
||||||
|
r = redis.Redis(host=host, port=port, db=db, password=password)
|
||||||
|
r.ping() # Test the connection
|
||||||
|
print(f"Successfully connected to Redis at {host}:{port}")
|
||||||
|
return r
|
||||||
|
except redis.exceptions.ConnectionError as e:
|
||||||
|
print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f"Retrying in {retry_delay} seconds...")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
print("Max retries reached. Could not connect to Redis.")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error connecting to Redis: {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f"Retrying in {retry_delay} seconds...")
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
r = connect_to_redis_with_retry('hs.zchens.cn', 6379, 'admin01')
|
||||||
|
r.flushall() # clear all streams in redis
|
||||||
|
|
||||||
|
# define redis streams for the remote language model
|
||||||
|
remote_lm_input_stream = 'remote_lm_input'
|
||||||
|
remote_lm_output_partial_stream = 'remote_lm_output_partial'
|
||||||
|
remote_lm_output_final_stream = 'remote_lm_output_final'
|
||||||
|
|
||||||
|
# set timestamps for last entries seen in the redis streams
|
||||||
|
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
|
||||||
|
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
|
||||||
|
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
|
||||||
|
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
|
||||||
|
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
|
||||||
|
|
||||||
|
lm_results = {
|
||||||
|
'session': [],
|
||||||
|
'block': [],
|
||||||
|
'trial': [],
|
||||||
|
'true_sentence': [],
|
||||||
|
'pred_sentence': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# loop through all trials and put logits into the remote language model to get text predictions
|
||||||
|
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
|
||||||
|
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
|
||||||
|
for session in test_data.keys():
|
||||||
|
for trial in range(len(test_data[session]['logits'])):
|
||||||
|
# get trial logits and rearrange them for the LM
|
||||||
|
logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
|
||||||
|
|
||||||
|
# reset language model
|
||||||
|
remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
|
||||||
|
|
||||||
|
'''
|
||||||
|
# update language model parameters
|
||||||
|
remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
|
||||||
|
r,
|
||||||
|
remote_lm_done_updating_lastEntrySeen,
|
||||||
|
acoustic_scale=0.35,
|
||||||
|
blank_penalty=90.0,
|
||||||
|
alpha=0.55,
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# put logits into LM
|
||||||
|
remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
|
||||||
|
r,
|
||||||
|
remote_lm_input_stream,
|
||||||
|
remote_lm_output_partial_stream,
|
||||||
|
remote_lm_output_partial_lastEntrySeen,
|
||||||
|
logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
# finalize remote LM
|
||||||
|
remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
|
||||||
|
r,
|
||||||
|
remote_lm_output_final_stream,
|
||||||
|
remote_lm_output_final_lastEntrySeen,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the best candidate sentence
|
||||||
|
best_candidate_sentence = lm_out['candidate_sentences'][0]
|
||||||
|
|
||||||
|
# store results
|
||||||
|
lm_results['session'].append(session)
|
||||||
|
lm_results['block'].append(test_data[session]['block_num'][trial])
|
||||||
|
lm_results['trial'].append(test_data[session]['trial_num'][trial])
|
||||||
|
if eval_type == 'val':
|
||||||
|
lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
|
||||||
|
else:
|
||||||
|
lm_results['true_sentence'].append(None)
|
||||||
|
lm_results['pred_sentence'].append(best_candidate_sentence)
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
|
pbar.update(1)
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
|
||||||
|
# if using the validation set, lets calculate the aggregate word error rate (WER)
|
||||||
|
if eval_type == 'val':
|
||||||
|
total_true_length = 0
|
||||||
|
total_edit_distance = 0
|
||||||
|
|
||||||
|
lm_results['edit_distance'] = []
|
||||||
|
lm_results['num_words'] = []
|
||||||
|
|
||||||
|
for i in range(len(lm_results['pred_sentence'])):
|
||||||
|
true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
|
||||||
|
pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
|
||||||
|
ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
|
||||||
|
|
||||||
|
total_true_length += len(true_sentence.split())
|
||||||
|
total_edit_distance += ed
|
||||||
|
|
||||||
|
lm_results['edit_distance'].append(ed)
|
||||||
|
lm_results['num_words'].append(len(true_sentence.split()))
|
||||||
|
|
||||||
|
print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
|
||||||
|
print(f'True sentence: {true_sentence}')
|
||||||
|
print(f'Predicted sentence: {pred_sentence}')
|
||||||
|
print(f'WER: {ed} / {100 * len(true_sentence.split())} = {ed / len(true_sentence.split()):.2f}%')
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f'Total true sentence length: {total_true_length}')
|
||||||
|
print(f'Total edit distance: {total_edit_distance}')
|
||||||
|
print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
|
||||||
|
|
||||||
|
|
||||||
|
# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
|
||||||
|
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
|
||||||
|
ids = [i for i in range(len(lm_results['pred_sentence']))]
|
||||||
|
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
|
||||||
|
df_out.to_csv(output_file, index=False)
|
297
model_training_nnn_tpu/evaluate_model_helpers.py
Normal file
297
model_training_nnn_tpu/evaluate_model_helpers.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
from data_augmentations import gauss_smooth
|
||||||
|
|
||||||
|
LOGIT_TO_PHONEME = [
|
||||||
|
'BLANK',
|
||||||
|
'AA', 'AE', 'AH', 'AO', 'AW',
|
||||||
|
'AY', 'B', 'CH', 'D', 'DH',
|
||||||
|
'EH', 'ER', 'EY', 'F', 'G',
|
||||||
|
'HH', 'IH', 'IY', 'JH', 'K',
|
||||||
|
'L', 'M', 'N', 'NG', 'OW',
|
||||||
|
'OY', 'P', 'R', 'S', 'SH',
|
||||||
|
'T', 'TH', 'UH', 'UW', 'V',
|
||||||
|
'W', 'Y', 'Z', 'ZH',
|
||||||
|
' | ',
|
||||||
|
]
|
||||||
|
|
||||||
|
def _extract_transcription(input):
|
||||||
|
endIdx = np.argwhere(input == 0)[0, 0]
|
||||||
|
trans = ''
|
||||||
|
for c in range(endIdx):
|
||||||
|
trans += chr(input[c])
|
||||||
|
return trans
|
||||||
|
|
||||||
|
def load_h5py_file(file_path, b2txt_csv_df):
|
||||||
|
data = {
|
||||||
|
'neural_features': [],
|
||||||
|
'n_time_steps': [],
|
||||||
|
'seq_class_ids': [],
|
||||||
|
'seq_len': [],
|
||||||
|
'transcriptions': [],
|
||||||
|
'sentence_label': [],
|
||||||
|
'session': [],
|
||||||
|
'block_num': [],
|
||||||
|
'trial_num': [],
|
||||||
|
'corpus': [],
|
||||||
|
}
|
||||||
|
# Open the hdf5 file for that day
|
||||||
|
with h5py.File(file_path, 'r') as f:
|
||||||
|
|
||||||
|
keys = list(f.keys())
|
||||||
|
|
||||||
|
# For each trial in the selected trials in that day
|
||||||
|
for key in keys:
|
||||||
|
g = f[key]
|
||||||
|
|
||||||
|
neural_features = g['input_features'][:]
|
||||||
|
n_time_steps = g.attrs['n_time_steps']
|
||||||
|
seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
|
||||||
|
seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
|
||||||
|
transcription = g['transcription'][:] if 'transcription' in g else None
|
||||||
|
sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
|
||||||
|
session = g.attrs['session']
|
||||||
|
block_num = g.attrs['block_num']
|
||||||
|
trial_num = g.attrs['trial_num']
|
||||||
|
|
||||||
|
# match this trial up with the csv to get the corpus name
|
||||||
|
year, month, day = session.split('.')[1:]
|
||||||
|
date = f'{year}-{month}-{day}'
|
||||||
|
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
|
||||||
|
corpus_name = row['Corpus'].values[0]
|
||||||
|
|
||||||
|
data['neural_features'].append(neural_features)
|
||||||
|
data['n_time_steps'].append(n_time_steps)
|
||||||
|
data['seq_class_ids'].append(seq_class_ids)
|
||||||
|
data['seq_len'].append(seq_len)
|
||||||
|
data['transcriptions'].append(transcription)
|
||||||
|
data['sentence_label'].append(sentence_label)
|
||||||
|
data['session'].append(session)
|
||||||
|
data['block_num'].append(block_num)
|
||||||
|
data['trial_num'].append(trial_num)
|
||||||
|
data['corpus'].append(corpus_name)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def rearrange_speech_logits_pt(logits):
|
||||||
|
# original order is [BLANK, phonemes..., SIL]
|
||||||
|
# rearrange so the order is [BLANK, SIL, phonemes...]
|
||||||
|
logits = np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
# single decoding step function.
|
||||||
|
# smooths data and puts it through the model.
|
||||||
|
def runSingleDecodingStep(x, input_layer, model, model_args, device):
|
||||||
|
|
||||||
|
# Use autocast for efficiency
|
||||||
|
with torch.autocast(device_type = "cuda", enabled = model_args['use_amp'], dtype = torch.bfloat16):
|
||||||
|
|
||||||
|
x = gauss_smooth(
|
||||||
|
inputs = x,
|
||||||
|
device = device,
|
||||||
|
smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],
|
||||||
|
smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],
|
||||||
|
padding = 'valid',
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits, _ = model(
|
||||||
|
x = x,
|
||||||
|
day_idx = torch.tensor([input_layer], device=device),
|
||||||
|
states = None, # no initial states
|
||||||
|
return_state = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert logits from bfloat16 to float32
|
||||||
|
logits = logits.float().cpu().numpy()
|
||||||
|
|
||||||
|
# # original order is [BLANK, phonemes..., SIL]
|
||||||
|
# # rearrange so the order is [BLANK, SIL, phonemes...]
|
||||||
|
# logits = rearrange_speech_logits_pt(logits)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def remove_punctuation(sentence):
|
||||||
|
# Remove punctuation
|
||||||
|
sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
|
||||||
|
sentence = sentence.replace('- ', ' ').lower()
|
||||||
|
sentence = sentence.replace('--', '').lower()
|
||||||
|
sentence = sentence.replace(" '", "'").lower()
|
||||||
|
|
||||||
|
sentence = sentence.strip()
|
||||||
|
sentence = ' '.join([word for word in sentence.split() if word != ''])
|
||||||
|
|
||||||
|
return sentence
|
||||||
|
|
||||||
|
def get_current_redis_time_ms(redis_conn):
|
||||||
|
t = redis_conn.time()
|
||||||
|
return int(t[0]*1000 + t[1]/1000)
|
||||||
|
|
||||||
|
|
||||||
|
######### language model helper functions ##########
|
||||||
|
|
||||||
|
def reset_remote_language_model(
|
||||||
|
r,
|
||||||
|
remote_lm_done_resetting_lastEntrySeen,
|
||||||
|
):
|
||||||
|
|
||||||
|
r.xadd('remote_lm_reset', {'done': 0})
|
||||||
|
time.sleep(0.001)
|
||||||
|
# print('Resetting remote language model before continuing...')
|
||||||
|
remote_lm_done_resetting = []
|
||||||
|
while len(remote_lm_done_resetting) == 0:
|
||||||
|
remote_lm_done_resetting = r.xread(
|
||||||
|
{'remote_lm_done_resetting': remote_lm_done_resetting_lastEntrySeen},
|
||||||
|
count=1,
|
||||||
|
block=10000,
|
||||||
|
)
|
||||||
|
if len(remote_lm_done_resetting) == 0:
|
||||||
|
print(f'Still waiting for remote lm reset from ts {remote_lm_done_resetting_lastEntrySeen}...')
|
||||||
|
for entry_id, entry_data in remote_lm_done_resetting[0][1]:
|
||||||
|
remote_lm_done_resetting_lastEntrySeen = entry_id
|
||||||
|
# print('Remote language model reset.')
|
||||||
|
|
||||||
|
return remote_lm_done_resetting_lastEntrySeen
|
||||||
|
|
||||||
|
|
||||||
|
def update_remote_lm_params(
|
||||||
|
r,
|
||||||
|
remote_lm_done_updating_lastEntrySeen,
|
||||||
|
acoustic_scale=0.35,
|
||||||
|
blank_penalty=90.0,
|
||||||
|
alpha=0.55,
|
||||||
|
):
|
||||||
|
|
||||||
|
# update remote lm params
|
||||||
|
entry_dict = {
|
||||||
|
# 'max_active': max_active,
|
||||||
|
# 'min_active': min_active,
|
||||||
|
# 'beam': beam,
|
||||||
|
# 'lattice_beam': lattice_beam,
|
||||||
|
'acoustic_scale': acoustic_scale,
|
||||||
|
# 'ctc_blank_skip_threshold': ctc_blank_skip_threshold,
|
||||||
|
# 'length_penalty': length_penalty,
|
||||||
|
# 'nbest': nbest,
|
||||||
|
'blank_penalty': blank_penalty,
|
||||||
|
'alpha': alpha,
|
||||||
|
# 'do_opt': do_opt,
|
||||||
|
# 'rescore': rescore,
|
||||||
|
# 'top_candidates_to_augment': top_candidates_to_augment,
|
||||||
|
# 'score_penalty_percent': score_penalty_percent,
|
||||||
|
# 'specific_word_bias': specific_word_bias,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.xadd('remote_lm_update_params', entry_dict)
|
||||||
|
time.sleep(0.001)
|
||||||
|
remote_lm_done_updating = []
|
||||||
|
while len(remote_lm_done_updating) == 0:
|
||||||
|
remote_lm_done_updating = r.xread(
|
||||||
|
{'remote_lm_done_updating_params': remote_lm_done_updating_lastEntrySeen},
|
||||||
|
block=10000,
|
||||||
|
count=1,
|
||||||
|
)
|
||||||
|
if len(remote_lm_done_updating) == 0:
|
||||||
|
print(f'Still waiting for remote lm to update parameters from ts {remote_lm_done_updating_lastEntrySeen}...')
|
||||||
|
for entry_id, entry_data in remote_lm_done_updating[0][1]:
|
||||||
|
remote_lm_done_updating_lastEntrySeen = entry_id
|
||||||
|
# print('Remote language model params updated.')
|
||||||
|
|
||||||
|
return remote_lm_done_updating_lastEntrySeen
|
||||||
|
|
||||||
|
|
||||||
|
def send_logits_to_remote_lm(
|
||||||
|
r,
|
||||||
|
remote_lm_input_stream,
|
||||||
|
remote_lm_output_partial_stream,
|
||||||
|
remote_lm_output_partial_lastEntrySeen,
|
||||||
|
logits,
|
||||||
|
):
|
||||||
|
|
||||||
|
# put logits into remote lm and get partial output
|
||||||
|
r.xadd(remote_lm_input_stream, {'logits': np.float32(logits).tobytes()})
|
||||||
|
remote_lm_output = []
|
||||||
|
while len(remote_lm_output) == 0:
|
||||||
|
remote_lm_output = r.xread(
|
||||||
|
{remote_lm_output_partial_stream: remote_lm_output_partial_lastEntrySeen},
|
||||||
|
block=10000,
|
||||||
|
count=1,
|
||||||
|
)
|
||||||
|
if len(remote_lm_output) == 0:
|
||||||
|
print(f'Still waiting for remote lm partial output from ts {remote_lm_output_partial_lastEntrySeen}...')
|
||||||
|
for entry_id, entry_data in remote_lm_output[0][1]:
|
||||||
|
remote_lm_output_partial_lastEntrySeen = entry_id
|
||||||
|
decoded = entry_data[b'lm_response_partial'].decode()
|
||||||
|
|
||||||
|
return remote_lm_output_partial_lastEntrySeen, decoded
|
||||||
|
|
||||||
|
|
||||||
|
def finalize_remote_lm(
|
||||||
|
r,
|
||||||
|
remote_lm_output_final_stream,
|
||||||
|
remote_lm_output_final_lastEntrySeen,
|
||||||
|
):
|
||||||
|
|
||||||
|
# finalize remote lm
|
||||||
|
r.xadd('remote_lm_finalize', {'done': 0})
|
||||||
|
time.sleep(0.005)
|
||||||
|
remote_lm_output = []
|
||||||
|
while len(remote_lm_output) == 0:
|
||||||
|
remote_lm_output = r.xread(
|
||||||
|
{remote_lm_output_final_stream: remote_lm_output_final_lastEntrySeen},
|
||||||
|
block=10000,
|
||||||
|
count=1,
|
||||||
|
)
|
||||||
|
if len(remote_lm_output) == 0:
|
||||||
|
print(f'Still waiting for remote lm final output from ts {remote_lm_output_final_lastEntrySeen}...')
|
||||||
|
# print('Received remote lm final output.')
|
||||||
|
|
||||||
|
for entry_id, entry_data in remote_lm_output[0][1]:
|
||||||
|
remote_lm_output_final_lastEntrySeen = entry_id
|
||||||
|
|
||||||
|
candidate_sentences = [str(c) for c in entry_data[b'scoring'].decode().split(';')[::5]]
|
||||||
|
candidate_acoustic_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[1::5]]
|
||||||
|
candidate_ngram_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[2::5]]
|
||||||
|
candidate_llm_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[3::5]]
|
||||||
|
candidate_total_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[4::5]]
|
||||||
|
|
||||||
|
|
||||||
|
# account for a weird edge case where there are no candidate sentences
|
||||||
|
if len(candidate_sentences) == 0 or len(candidate_total_scores) == 0:
|
||||||
|
print('No candidate sentences were received from the language model.')
|
||||||
|
candidate_sentences = ['']
|
||||||
|
candidate_acoustic_scores = [0]
|
||||||
|
candidate_ngram_scores = [0]
|
||||||
|
candidate_llm_scores = [0]
|
||||||
|
candidate_total_scores = [0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# sort candidate sentences by total score (higher is better)
|
||||||
|
sort_order = np.argsort(candidate_total_scores)[::-1]
|
||||||
|
|
||||||
|
candidate_sentences = [candidate_sentences[i] for i in sort_order]
|
||||||
|
candidate_acoustic_scores = [candidate_acoustic_scores[i] for i in sort_order]
|
||||||
|
candidate_ngram_scores = [candidate_ngram_scores[i] for i in sort_order]
|
||||||
|
candidate_llm_scores = [candidate_llm_scores[i] for i in sort_order]
|
||||||
|
candidate_total_scores = [candidate_total_scores[i] for i in sort_order]
|
||||||
|
|
||||||
|
# loop through candidates backwards and remove any duplicates
|
||||||
|
for i in range(len(candidate_sentences)-1, 0, -1):
|
||||||
|
if candidate_sentences[i] in candidate_sentences[:i]:
|
||||||
|
candidate_sentences.pop(i)
|
||||||
|
candidate_acoustic_scores.pop(i)
|
||||||
|
candidate_ngram_scores.pop(i)
|
||||||
|
candidate_llm_scores.pop(i)
|
||||||
|
candidate_total_scores.pop(i)
|
||||||
|
|
||||||
|
lm_out = {
|
||||||
|
'candidate_sentences': candidate_sentences,
|
||||||
|
'candidate_acoustic_scores': candidate_acoustic_scores,
|
||||||
|
'candidate_ngram_scores': candidate_ngram_scores,
|
||||||
|
'candidate_llm_scores': candidate_llm_scores,
|
||||||
|
'candidate_total_scores': candidate_total_scores,
|
||||||
|
}
|
||||||
|
|
||||||
|
return remote_lm_output_final_lastEntrySeen, lm_out
|
124
model_training_nnn_tpu/jupyter_debug_full_model.py
Normal file
124
model_training_nnn_tpu/jupyter_debug_full_model.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# ====================
|
||||||
|
# 单元格4: 逐步调试完整模型编译
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
# 如果单元格3测试通过,运行这个单元格
|
||||||
|
print("🔧 逐步测试完整TripleGRUDecoder模型...")
|
||||||
|
|
||||||
|
# 导入完整模型
|
||||||
|
import sys
|
||||||
|
sys.path.append('.') # 确保能导入本地模块
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rnn_model import TripleGRUDecoder
|
||||||
|
print("✅ TripleGRUDecoder导入成功")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ 模型导入失败: {e}")
|
||||||
|
print("请确保rnn_model.py在当前目录中")
|
||||||
|
|
||||||
|
# 分阶段测试
|
||||||
|
def test_model_compilation_stages():
|
||||||
|
"""分阶段测试模型编译"""
|
||||||
|
device = xm.xla_device()
|
||||||
|
|
||||||
|
# 阶段1: 测试NoiseModel单独编译
|
||||||
|
print("\n🔬 阶段1: 测试NoiseModel...")
|
||||||
|
try:
|
||||||
|
from rnn_model import NoiseModel
|
||||||
|
noise_model = NoiseModel(
|
||||||
|
neural_dim=512,
|
||||||
|
n_units=384, # 减小参数
|
||||||
|
n_days=4,
|
||||||
|
patch_size=8 # 减小patch size
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
x = torch.randn(2, 20, 512, device=device)
|
||||||
|
day_idx = torch.tensor([0, 1], device=device)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
output, states = noise_model(x, day_idx)
|
||||||
|
compile_time = time.time() - start_time
|
||||||
|
|
||||||
|
print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒")
|
||||||
|
print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
|
||||||
|
|
||||||
|
return True, compile_time
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ NoiseModel编译失败: {e}")
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 阶段2: 测试CleanSpeechModel
|
||||||
|
print("\n🔬 阶段2: 测试CleanSpeechModel...")
|
||||||
|
try:
|
||||||
|
from rnn_model import CleanSpeechModel
|
||||||
|
clean_model = CleanSpeechModel(
|
||||||
|
neural_dim=512,
|
||||||
|
n_units=384,
|
||||||
|
n_days=4,
|
||||||
|
n_classes=41,
|
||||||
|
patch_size=8
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
output = clean_model(x, day_idx)
|
||||||
|
compile_time = time.time() - start_time
|
||||||
|
|
||||||
|
print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}秒")
|
||||||
|
return True, compile_time
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ CleanSpeechModel编译失败: {e}")
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 阶段3: 测试完整TripleGRUDecoder
|
||||||
|
print("\n🔬 阶段3: 测试TripleGRUDecoder...")
|
||||||
|
try:
|
||||||
|
model = TripleGRUDecoder(
|
||||||
|
neural_dim=512,
|
||||||
|
n_units=384, # 比原来的768小
|
||||||
|
n_days=4, # 减少天数
|
||||||
|
n_classes=41,
|
||||||
|
patch_size=8 # 减小patch size
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# 启动编译监控
|
||||||
|
compilation_monitor.start_monitoring()
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
# 测试inference模式
|
||||||
|
logits = model(x, day_idx, None, False, 'inference')
|
||||||
|
compile_time = time.time() - start_time
|
||||||
|
|
||||||
|
compilation_monitor.complete_monitoring()
|
||||||
|
|
||||||
|
print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}秒")
|
||||||
|
print(f"📤 输出形状: {logits.shape}")
|
||||||
|
|
||||||
|
return True, compile_time
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
compilation_monitor.complete_monitoring()
|
||||||
|
print(f"❌ TripleGRUDecoder编译失败: {e}")
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 运行分阶段测试
|
||||||
|
stage_results = test_model_compilation_stages()
|
||||||
|
|
||||||
|
if stage_results:
|
||||||
|
print(f"\n🎉 所有编译测试完成!")
|
||||||
|
print("💡 下一步可以尝试:")
|
||||||
|
print(" 1. 使用简化配置进行训练")
|
||||||
|
print(" 2. 逐步增加模型复杂度")
|
||||||
|
print(" 3. 监控TPU资源使用情况")
|
||||||
|
else:
|
||||||
|
print(f"\n⚠️ 编译测试发现问题")
|
||||||
|
print("💡 建议:")
|
||||||
|
print(" 1. 进一步减小模型参数")
|
||||||
|
print(" 2. 检查内存使用情况")
|
||||||
|
print(" 3. 使用CPU模式进行调试")
|
131
model_training_nnn_tpu/jupyter_xla_monitor.py
Normal file
131
model_training_nnn_tpu/jupyter_xla_monitor.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# ====================
|
||||||
|
# 单元格2: XLA编译进度监控
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from IPython.display import display, HTML, clear_output
|
||||||
|
import ipywidgets as widgets
|
||||||
|
|
||||||
|
# 导入XLA (环境变量已在单元格1中设置)
|
||||||
|
print("🚀 导入PyTorch XLA...")
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
print(f"✅ XLA导入成功!")
|
||||||
|
print(f" TPU设备: {xm.xla_device()}")
|
||||||
|
print(f" World Size: {xm.xrt_world_size()}")
|
||||||
|
|
||||||
|
# 创建编译进度监控器
|
||||||
|
class JupyterCompilationMonitor:
|
||||||
|
def __init__(self):
|
||||||
|
self.start_time = None
|
||||||
|
self.is_monitoring = False
|
||||||
|
|
||||||
|
# 创建输出widget
|
||||||
|
self.output_widget = widgets.Output()
|
||||||
|
|
||||||
|
# 创建进度条
|
||||||
|
self.progress_bar = widgets.IntProgress(
|
||||||
|
value=0,
|
||||||
|
min=0,
|
||||||
|
max=100,
|
||||||
|
description='XLA编译:',
|
||||||
|
bar_style='info',
|
||||||
|
style={'bar_color': '#1f77b4'},
|
||||||
|
orientation='horizontal'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建状态标签
|
||||||
|
self.status_label = widgets.HTML(
|
||||||
|
value="<b>准备开始编译...</b>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建CPU使用率显示
|
||||||
|
self.cpu_label = widgets.HTML(
|
||||||
|
value="CPU: ---%"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.memory_label = widgets.HTML(
|
||||||
|
value="内存: ---%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 组合界面
|
||||||
|
self.monitor_box = widgets.VBox([
|
||||||
|
widgets.HTML("<h3>🔄 XLA编译监控</h3>"),
|
||||||
|
self.progress_bar,
|
||||||
|
self.status_label,
|
||||||
|
widgets.HBox([self.cpu_label, self.memory_label]),
|
||||||
|
self.output_widget
|
||||||
|
])
|
||||||
|
|
||||||
|
def start_monitoring(self):
|
||||||
|
"""开始监控"""
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.is_monitoring = True
|
||||||
|
|
||||||
|
display(self.monitor_box)
|
||||||
|
|
||||||
|
# 启动监控线程
|
||||||
|
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||||
|
self.monitor_thread.start()
|
||||||
|
|
||||||
|
def _monitor_loop(self):
|
||||||
|
"""监控循环"""
|
||||||
|
while self.is_monitoring:
|
||||||
|
try:
|
||||||
|
elapsed = time.time() - self.start_time
|
||||||
|
minutes = int(elapsed // 60)
|
||||||
|
seconds = int(elapsed % 60)
|
||||||
|
|
||||||
|
# 更新进度条 (模拟进度)
|
||||||
|
progress = min(int(elapsed / 10 * 100), 95) # 10秒内达到95%
|
||||||
|
self.progress_bar.value = progress
|
||||||
|
|
||||||
|
# 获取系统资源
|
||||||
|
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||||
|
memory_percent = psutil.virtual_memory().percent
|
||||||
|
|
||||||
|
# 更新显示
|
||||||
|
self.status_label.value = f"<b>编译进行中... ⏱️ {minutes:02d}:{seconds:02d}</b>"
|
||||||
|
self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
|
||||||
|
self.memory_label.value = f"<b>💾 内存: {memory_percent:5.1f}%</b>"
|
||||||
|
|
||||||
|
# 检测是否编译完成 (CPU使用率突然下降)
|
||||||
|
if elapsed > 10 and cpu_percent < 20: # 编译通常CPU使用率很高
|
||||||
|
self.complete_monitoring()
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
with self.output_widget:
|
||||||
|
print(f"监控错误: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
def complete_monitoring(self):
|
||||||
|
"""完成监控"""
|
||||||
|
if self.is_monitoring:
|
||||||
|
self.is_monitoring = False
|
||||||
|
elapsed = time.time() - self.start_time
|
||||||
|
|
||||||
|
self.progress_bar.value = 100
|
||||||
|
self.progress_bar.bar_style = 'success'
|
||||||
|
self.status_label.value = f"<b style='color: green'>✅ 编译完成! 总耗时: {elapsed:.2f}秒</b>"
|
||||||
|
|
||||||
|
with self.output_widget:
|
||||||
|
print(f"\n🎉 XLA编译成功完成!")
|
||||||
|
print(f"⏱️ 总耗时: {elapsed:.2f}秒")
|
||||||
|
if elapsed < 60:
|
||||||
|
print("✅ 编译速度正常")
|
||||||
|
elif elapsed < 300:
|
||||||
|
print("⚠️ 编译稍慢,但可接受")
|
||||||
|
else:
|
||||||
|
print("❌ 编译过慢,建议检查设置")
|
||||||
|
|
||||||
|
# 创建全局监控器
|
||||||
|
compilation_monitor = JupyterCompilationMonitor()
|
||||||
|
|
||||||
|
print("✅ 编译监控器已准备就绪!")
|
||||||
|
print("💡 运行下一个单元格开始XLA编译测试")
|
45
model_training_nnn_tpu/jupyter_xla_setup.py
Normal file
45
model_training_nnn_tpu/jupyter_xla_setup.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# ====================
|
||||||
|
# 单元格1: 环境设置 (必须第一个运行!)
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
from IPython.display import display, HTML, clear_output
|
||||||
|
import ipywidgets as widgets
|
||||||
|
|
||||||
|
# ⚠️ 关键: 在导入torch_xla之前设置环境变量
|
||||||
|
print("🔧 设置XLA环境变量...")
|
||||||
|
|
||||||
|
# 获取CPU核心数
|
||||||
|
cpu_count = os.cpu_count()
|
||||||
|
print(f"检测到 {cpu_count} 个CPU核心")
|
||||||
|
|
||||||
|
# 设置XLA编译优化环境变量
|
||||||
|
os.environ['XLA_FLAGS'] = (
|
||||||
|
'--xla_cpu_multi_thread_eigen=true '
|
||||||
|
'--xla_cpu_enable_fast_math=true '
|
||||||
|
f'--xla_force_host_platform_device_count={cpu_count}'
|
||||||
|
)
|
||||||
|
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
|
||||||
|
os.environ['XLA_USE_BF16'] = '1'
|
||||||
|
|
||||||
|
# 显示设置结果
|
||||||
|
print("✅ XLA环境变量已设置:")
|
||||||
|
print(f" CPU核心数: {cpu_count}")
|
||||||
|
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
|
||||||
|
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||||||
|
|
||||||
|
# 系统资源检查
|
||||||
|
memory_info = psutil.virtual_memory()
|
||||||
|
print(f"\n💾 系统内存信息:")
|
||||||
|
print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
|
||||||
|
print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
|
||||||
|
print(f" 使用率: {memory_info.percent:.1f}%")
|
||||||
|
|
||||||
|
if memory_info.available < 8 * (1024**3): # 小于8GB
|
||||||
|
print("⚠️ 警告: 可用内存不足8GB,可能影响XLA编译速度")
|
||||||
|
else:
|
||||||
|
print("✅ 内存充足")
|
||||||
|
|
||||||
|
print("\n🎯 环境设置完成! 现在可以运行下一个单元格")
|
78
model_training_nnn_tpu/jupyter_xla_test.py
Normal file
78
model_training_nnn_tpu/jupyter_xla_test.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# ====================
|
||||||
|
# 单元格3: 快速XLA编译测试
|
||||||
|
# ====================
|
||||||
|
|
||||||
|
# 简化测试模型
|
||||||
|
class QuickTestModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(512, 128)
|
||||||
|
self.gru = nn.GRU(128, 64, batch_first=True)
|
||||||
|
self.linear2 = nn.Linear(64, 41)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.relu(self.linear1(x))
|
||||||
|
x, _ = self.gru(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
print("🧪 开始XLA编译快速测试...")
|
||||||
|
|
||||||
|
# 启动监控
|
||||||
|
compilation_monitor.start_monitoring()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取TPU设备
|
||||||
|
device = xm.xla_device()
|
||||||
|
|
||||||
|
# 创建小模型
|
||||||
|
model = QuickTestModel().to(device)
|
||||||
|
param_count = sum(p.numel() for p in model.parameters())
|
||||||
|
print(f"📊 测试模型参数: {param_count:,}")
|
||||||
|
|
||||||
|
# 创建测试数据 (很小的batch)
|
||||||
|
x = torch.randn(2, 20, 512, device=device)
|
||||||
|
print(f"📥 输入数据形状: {x.shape}")
|
||||||
|
|
||||||
|
print("🔄 开始首次前向传播 (触发XLA编译)...")
|
||||||
|
|
||||||
|
# 首次前向传播 - 这会触发XLA编译
|
||||||
|
with torch.no_grad():
|
||||||
|
start_compile = time.time()
|
||||||
|
output = model(x)
|
||||||
|
compile_time = time.time() - start_compile
|
||||||
|
|
||||||
|
print(f"✅ XLA编译完成!")
|
||||||
|
print(f"📤 输出形状: {output.shape}")
|
||||||
|
|
||||||
|
# 完成监控
|
||||||
|
compilation_monitor.complete_monitoring()
|
||||||
|
|
||||||
|
# 测试编译后的性能
|
||||||
|
print("\n🚀 测试编译后的执行速度...")
|
||||||
|
with torch.no_grad():
|
||||||
|
start_exec = time.time()
|
||||||
|
for _ in range(10):
|
||||||
|
output = model(x)
|
||||||
|
avg_exec_time = (time.time() - start_exec) / 10
|
||||||
|
|
||||||
|
print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")
|
||||||
|
|
||||||
|
# 性能评估
|
||||||
|
if compile_time < 30:
|
||||||
|
print("✅ 编译速度优秀! 可以尝试完整模型")
|
||||||
|
test_result = "excellent"
|
||||||
|
elif compile_time < 120:
|
||||||
|
print("✅ 编译速度良好! 建议使用简化配置")
|
||||||
|
test_result = "good"
|
||||||
|
else:
|
||||||
|
print("⚠️ 编译速度较慢,建议进一步优化")
|
||||||
|
test_result = "slow"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
compilation_monitor.complete_monitoring()
|
||||||
|
print(f"❌ 测试失败: {e}")
|
||||||
|
test_result = "failed"
|
||||||
|
|
||||||
|
print(f"\n📋 测试结果: {test_result}")
|
||||||
|
print("💡 如果测试通过,可以运行下一个单元格进行完整训练")
|
161
model_training_nnn_tpu/launch_tpu_training.py
Normal file
161
model_training_nnn_tpu/launch_tpu_training.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
TPU Training Launch Script for Brain-to-Text RNN Model
|
||||||
|
|
||||||
|
This script provides easy TPU training setup using Accelerate library.
|
||||||
|
Supports both single TPU core and multi-core (8 cores) training.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- PyTorch XLA installed
|
||||||
|
- Accelerate library installed
|
||||||
|
- TPU runtime available
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def update_config_for_tpu(config_path, num_cores=8):
|
||||||
|
"""
|
||||||
|
Update configuration file to enable TPU training
|
||||||
|
"""
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Enable TPU settings
|
||||||
|
config['use_tpu'] = True
|
||||||
|
config['num_tpu_cores'] = num_cores
|
||||||
|
config['dataloader_num_workers'] = 0 # Required for TPU
|
||||||
|
config['use_amp'] = True # Enable mixed precision with bfloat16
|
||||||
|
|
||||||
|
# Adjust batch size and gradient accumulation for multi-core TPU
|
||||||
|
if num_cores > 1:
|
||||||
|
# Distribute batch size across cores
|
||||||
|
original_batch_size = config['dataset']['batch_size']
|
||||||
|
config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
|
||||||
|
config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
|
||||||
|
|
||||||
|
print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
|
||||||
|
print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
|
||||||
|
|
||||||
|
# Save updated config
|
||||||
|
tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
|
||||||
|
with open(tpu_config_path, 'w') as f:
|
||||||
|
yaml.dump(config, f, default_flow_style=False)
|
||||||
|
|
||||||
|
print(f"TPU configuration saved to: {tpu_config_path}")
|
||||||
|
return tpu_config_path
|
||||||
|
|
||||||
|
def check_tpu_environment():
|
||||||
|
"""
|
||||||
|
Check if TPU environment is properly set up
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import torch_xla
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
# Check if TPUs are available
|
||||||
|
device = xm.xla_device()
|
||||||
|
print(f"TPU device available: {device}")
|
||||||
|
print(f"TPU ordinal: {xm.get_ordinal()}")
|
||||||
|
print(f"TPU world size: {xm.xrt_world_size()}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: TPU not available - {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def run_tpu_training(config_path, num_cores=8):
|
||||||
|
"""
|
||||||
|
Launch TPU training using accelerate
|
||||||
|
"""
|
||||||
|
# Check TPU environment
|
||||||
|
if not check_tpu_environment():
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Update config for TPU
|
||||||
|
tpu_config_path = update_config_for_tpu(config_path, num_cores)
|
||||||
|
|
||||||
|
# Set TPU environment variables BEFORE launching training
|
||||||
|
os.environ['TPU_CORES'] = str(num_cores)
|
||||||
|
os.environ['XLA_USE_BF16'] = '1' # Enable bfloat16
|
||||||
|
|
||||||
|
# Critical XLA multi-threading settings - must be set before torch_xla import
|
||||||
|
cpu_count = os.cpu_count()
|
||||||
|
os.environ['XLA_FLAGS'] = (
|
||||||
|
'--xla_cpu_multi_thread_eigen=true '
|
||||||
|
'--xla_cpu_enable_fast_math=true '
|
||||||
|
f'--xla_force_host_platform_device_count={cpu_count}'
|
||||||
|
)
|
||||||
|
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
|
||||||
|
|
||||||
|
print(f"Set XLA compilation to use {cpu_count} CPU threads")
|
||||||
|
print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
|
||||||
|
print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||||||
|
|
||||||
|
# Launch training with accelerate using subprocess to ensure environment variables are passed
|
||||||
|
cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"
|
||||||
|
|
||||||
|
print(f"Launching TPU training with command:")
|
||||||
|
print(f" {cmd}")
|
||||||
|
print(f"Using {num_cores} TPU cores")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# Use subprocess to ensure environment variables are properly inherited
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
# Create environment with our XLA settings
|
||||||
|
env = os.environ.copy()
|
||||||
|
env.update({
|
||||||
|
'TPU_CORES': str(num_cores),
|
||||||
|
'XLA_USE_BF16': '1',
|
||||||
|
'XLA_FLAGS': (
|
||||||
|
'--xla_cpu_multi_thread_eigen=true '
|
||||||
|
'--xla_cpu_enable_fast_math=true '
|
||||||
|
f'--xla_force_host_platform_device_count={cpu_count}'
|
||||||
|
),
|
||||||
|
'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"Environment variables set for subprocess:")
|
||||||
|
print(f" XLA_FLAGS: {env['XLA_FLAGS']}")
|
||||||
|
print(f" PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# Execute training with proper environment
|
||||||
|
result = subprocess.run(cmd.split(), env=env)
|
||||||
|
return result.returncode
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
|
||||||
|
parser.add_argument('--config', default='rnn_args.yaml',
|
||||||
|
help='Path to configuration file (default: rnn_args.yaml)')
|
||||||
|
parser.add_argument('--num_cores', type=int, default=8,
|
||||||
|
help='Number of TPU cores to use (default: 8)')
|
||||||
|
parser.add_argument('--check_only', action='store_true',
|
||||||
|
help='Only check TPU environment, do not launch training')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Verify config file exists
|
||||||
|
if not os.path.exists(args.config):
|
||||||
|
print(f"ERROR: Configuration file {args.config} not found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.check_only:
|
||||||
|
check_tpu_environment()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Run TPU training
|
||||||
|
run_tpu_training(args.config, args.num_cores)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
100
model_training_nnn_tpu/monitor_xla_compilation.py
Normal file
100
model_training_nnn_tpu/monitor_xla_compilation.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
XLA编译进度监控脚本
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import psutil
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
def monitor_compilation_progress():
|
||||||
|
"""监控XLA编译进度"""
|
||||||
|
print("🔍 XLA编译进度监控已启动...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
dots = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
minutes = int(elapsed // 60)
|
||||||
|
seconds = int(elapsed % 60)
|
||||||
|
|
||||||
|
# 获取CPU使用率
|
||||||
|
cpu_percent = psutil.cpu_percent(interval=1)
|
||||||
|
memory_percent = psutil.virtual_memory().percent
|
||||||
|
|
||||||
|
# 动态显示
|
||||||
|
dots = (dots + 1) % 4
|
||||||
|
dot_str = "." * dots + " " * (3 - dots)
|
||||||
|
|
||||||
|
print(f"\r🔄 XLA编译中{dot_str} "
|
||||||
|
f"⏱️ {minutes:02d}:{seconds:02d} "
|
||||||
|
f"🖥️ CPU: {cpu_percent:5.1f}% "
|
||||||
|
f"💾 内存: {memory_percent:5.1f}%", end="", flush=True)
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def compilation_monitor():
|
||||||
|
"""编译监控上下文管理器"""
|
||||||
|
print("🚀 开始XLA编译监控...")
|
||||||
|
|
||||||
|
# 启动监控线程
|
||||||
|
monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True)
|
||||||
|
monitor_thread.start()
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}秒")
|
||||||
|
|
||||||
|
# 修改trainer中的使用方法
|
||||||
|
def add_compilation_monitoring_to_trainer():
|
||||||
|
"""给trainer添加编译监控的示例代码"""
|
||||||
|
example_code = '''
|
||||||
|
# 在rnn_trainer.py的train方法中添加:
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
from monitor_xla_compilation import compilation_monitor
|
||||||
|
|
||||||
|
self.model.train()
|
||||||
|
train_losses = []
|
||||||
|
# ... 其他初始化代码 ...
|
||||||
|
|
||||||
|
self.logger.info("Starting training loop - XLA compilation monitoring enabled...")
|
||||||
|
|
||||||
|
# 使用编译监控
|
||||||
|
with compilation_monitor():
|
||||||
|
for i, batch in enumerate(self.train_loader):
|
||||||
|
# 第一个batch会触发XLA编译
|
||||||
|
# 监控会显示编译进度
|
||||||
|
|
||||||
|
# ... 训练代码 ...
|
||||||
|
|
||||||
|
# 编译完成后会自动停止监控
|
||||||
|
break # 只需要第一个batch来触发编译
|
||||||
|
|
||||||
|
# 继续正常训练循环
|
||||||
|
for i, batch in enumerate(self.train_loader):
|
||||||
|
# ... 正常训练代码 ...
|
||||||
|
'''
|
||||||
|
|
||||||
|
print("📝 如何在trainer中使用编译监控:")
|
||||||
|
print(example_code)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("🧪 XLA编译监控工具")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 演示如何使用
|
||||||
|
print("📖 使用方法:")
|
||||||
|
print("1. 将此文件导入到你的训练脚本中")
|
||||||
|
print("2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器")
|
||||||
|
print("3. 会实时显示编译进度和系统资源使用情况")
|
||||||
|
|
||||||
|
add_compilation_monitoring_to_trainer()
|
181
model_training_nnn_tpu/rnn_args.yaml
Normal file
181
model_training_nnn_tpu/rnn_args.yaml
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
model:
|
||||||
|
n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
|
||||||
|
n_units: 768 # number of units per GRU layer
|
||||||
|
rnn_dropout: 0.4 # dropout rate for the GRU layers
|
||||||
|
rnn_trainable: true # whether the GRU layers are trainable
|
||||||
|
n_layers: 5 # number of GRU layers
|
||||||
|
patch_size: 14 # size of the input patches (14 time steps)
|
||||||
|
patch_stride: 4 # stride for the input patches (4 time steps)
|
||||||
|
|
||||||
|
input_network:
|
||||||
|
n_input_layers: 1 # number of input layers per network (one network for each day)
|
||||||
|
input_layer_sizes:
|
||||||
|
- 512 # size of the input layer (number of input features)
|
||||||
|
input_trainable: true # whether the input layer is trainable
|
||||||
|
input_layer_dropout: 0.2 # dropout rate for the input layer
|
||||||
|
|
||||||
|
mode: train
|
||||||
|
use_amp: true # whether to use automatic mixed precision (AMP) for training with bfloat16 on TPU
|
||||||
|
|
||||||
|
# TPU distributed training settings
|
||||||
|
use_tpu: true # TPU training enabled
|
||||||
|
num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8 or v4-8)
|
||||||
|
gradient_accumulation_steps: 2 # number of gradient accumulation steps for distributed training (2x32=64 effective batch size)
|
||||||
|
|
||||||
|
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
|
||||||
|
checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
|
||||||
|
init_from_checkpoint: false # whether to initialize the model from a checkpoint
|
||||||
|
init_checkpoint_path: None # path to the checkpoint to initialize the model from, if any
|
||||||
|
save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
|
||||||
|
save_all_val_steps: false # whether to save checkpoints at all validation steps
|
||||||
|
save_final_model: false # whether to save the final model after training
|
||||||
|
save_val_metrics: true # whether to save validation metrics during training
|
||||||
|
early_stopping: false # whether to use early stopping based on validation metrics
|
||||||
|
early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen
|
||||||
|
|
||||||
|
num_training_batches: 120000 # number of training batches to run
|
||||||
|
lr_scheduler_type: cosine # type of learning rate scheduler to use
|
||||||
|
lr_max: 0.005 # maximum learning rate for the main model
|
||||||
|
lr_min: 0.0001 # minimum learning rate for the main model
|
||||||
|
lr_decay_steps: 120000 # number of steps for the learning rate decay
|
||||||
|
lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
|
||||||
|
lr_max_day: 0.005 # maximum learning rate for the day specific input layers
|
||||||
|
lr_min_day: 0.0001 # minimum learning rate for the day specific input layers
|
||||||
|
lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day specific input layers
|
||||||
|
lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day specific input layers
|
||||||
|
|
||||||
|
beta0: 0.9 # beta0 parameter for the Adam optimizer
|
||||||
|
beta1: 0.999 # beta1 parameter for the Adam optimizer
|
||||||
|
epsilon: 0.1 # epsilon parameter for the Adam optimizer
|
||||||
|
weight_decay: 0.001 # weight decay for the main model
|
||||||
|
weight_decay_day: 0 # weight decay for the day specific input layers
|
||||||
|
seed: 10 # random seed for reproducibility
|
||||||
|
grad_norm_clip_value: 10 # gradient norm clipping value
|
||||||
|
|
||||||
|
batches_per_train_log: 200 # number of batches per training log
|
||||||
|
batches_per_val_step: 2000 # number of batches per validation step
|
||||||
|
|
||||||
|
batches_per_save: 0 # number of batches per save
|
||||||
|
log_individual_day_val_PER: true # whether to log individual day validation performance
|
||||||
|
log_val_skip_logs: false # whether to skip logging validation metrics
|
||||||
|
save_val_logits: true # whether to save validation logits
|
||||||
|
save_val_data: false # whether to save validation data
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
data_transforms:
|
||||||
|
white_noise_std: 1.0 # standard deviation of the white noise added to the data
|
||||||
|
constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
|
||||||
|
random_walk_std: 0.0 # standard deviation of the random walk added to the data
|
||||||
|
random_walk_axis: -1 # axis along which the random walk is applied
|
||||||
|
static_gain_std: 0.0 # standard deviation of the static gain applied to the data
|
||||||
|
random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
|
||||||
|
smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
|
||||||
|
smooth_data: true # whether to smooth the data
|
||||||
|
smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
|
||||||
|
|
||||||
|
neural_dim: 512 # dimensionality of the neural data
|
||||||
|
batch_size: 32 # batch size for training (reduced for TPU memory constraints)
|
||||||
|
n_classes: 41 # number of classes (phonemes) in the dataset
|
||||||
|
max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
|
||||||
|
days_per_batch: 4 # number of randomly-selected days to include in each batch
|
||||||
|
seed: 1 # random seed for reproducibility
|
||||||
|
num_dataloader_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
|
||||||
|
loader_shuffle: false # whether to shuffle the data loader
|
||||||
|
must_include_days: null # specific days to include in the dataset
|
||||||
|
test_percentage: 0.1 # percentage of data to use for testing
|
||||||
|
feature_subset: null # specific features to include in the dataset
|
||||||
|
|
||||||
|
dataset_dir: ../data/hdf5_data_final # directory containing the dataset
|
||||||
|
bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
|
||||||
|
sessions: # list of sessions to include in the dataset
|
||||||
|
- t15.2023.08.11
|
||||||
|
- t15.2023.08.13
|
||||||
|
- t15.2023.08.18
|
||||||
|
- t15.2023.08.20
|
||||||
|
- t15.2023.08.25
|
||||||
|
- t15.2023.08.27
|
||||||
|
- t15.2023.09.01
|
||||||
|
- t15.2023.09.03
|
||||||
|
- t15.2023.09.24
|
||||||
|
- t15.2023.09.29
|
||||||
|
- t15.2023.10.01
|
||||||
|
- t15.2023.10.06
|
||||||
|
- t15.2023.10.08
|
||||||
|
- t15.2023.10.13
|
||||||
|
- t15.2023.10.15
|
||||||
|
- t15.2023.10.20
|
||||||
|
- t15.2023.10.22
|
||||||
|
- t15.2023.11.03
|
||||||
|
- t15.2023.11.04
|
||||||
|
- t15.2023.11.17
|
||||||
|
- t15.2023.11.19
|
||||||
|
- t15.2023.11.26
|
||||||
|
- t15.2023.12.03
|
||||||
|
- t15.2023.12.08
|
||||||
|
- t15.2023.12.10
|
||||||
|
- t15.2023.12.17
|
||||||
|
- t15.2023.12.29
|
||||||
|
- t15.2024.02.25
|
||||||
|
- t15.2024.03.03
|
||||||
|
- t15.2024.03.08
|
||||||
|
- t15.2024.03.15
|
||||||
|
- t15.2024.03.17
|
||||||
|
- t15.2024.04.25
|
||||||
|
- t15.2024.04.28
|
||||||
|
- t15.2024.05.10
|
||||||
|
- t15.2024.06.14
|
||||||
|
- t15.2024.07.19
|
||||||
|
- t15.2024.07.21
|
||||||
|
- t15.2024.07.28
|
||||||
|
- t15.2025.01.10
|
||||||
|
- t15.2025.01.12
|
||||||
|
- t15.2025.03.14
|
||||||
|
- t15.2025.03.16
|
||||||
|
- t15.2025.03.30
|
||||||
|
- t15.2025.04.13
|
||||||
|
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
|
||||||
|
- 0 # no val or test data from this day
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 0 # no val or test data from this day
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 0 # no val or test data from this day
|
||||||
|
- 0 # no val or test data from this day
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
94
model_training_nnn_tpu/rnn_args_simple.yaml
Normal file
94
model_training_nnn_tpu/rnn_args_simple.yaml
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# 简化的TPU训练配置 - 更快编译
|
||||||
|
model:
|
||||||
|
n_input_features: 512
|
||||||
|
n_units: 384 # 减少从768到384
|
||||||
|
rnn_dropout: 0.2 # 减少dropout
|
||||||
|
rnn_trainable: true
|
||||||
|
n_layers: 3 # 减少从5到3层
|
||||||
|
patch_size: 8 # 减少从14到8
|
||||||
|
patch_stride: 4
|
||||||
|
|
||||||
|
input_network:
|
||||||
|
n_input_layers: 1
|
||||||
|
input_layer_sizes:
|
||||||
|
- 512
|
||||||
|
input_trainable: true
|
||||||
|
input_layer_dropout: 0.1 # 减少dropout
|
||||||
|
|
||||||
|
mode: train
|
||||||
|
use_amp: true
|
||||||
|
|
||||||
|
# TPU分布式训练设置
|
||||||
|
use_tpu: true
|
||||||
|
num_tpu_cores: 8
|
||||||
|
gradient_accumulation_steps: 4 # 增加梯度累积补偿小batch
|
||||||
|
|
||||||
|
output_dir: trained_models/simple_rnn
|
||||||
|
checkpoint_dir: trained_models/simple_rnn/checkpoint
|
||||||
|
init_from_checkpoint: false
|
||||||
|
save_best_checkpoint: true
|
||||||
|
save_val_metrics: true
|
||||||
|
|
||||||
|
num_training_batches: 1000 # 先测试1000个batch
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
lr_max: 0.003 # 稍微降低学习率
|
||||||
|
lr_min: 0.0001
|
||||||
|
lr_decay_steps: 1000
|
||||||
|
lr_warmup_steps: 100
|
||||||
|
|
||||||
|
lr_max_day: 0.003
|
||||||
|
lr_min_day: 0.0001
|
||||||
|
lr_decay_steps_day: 1000
|
||||||
|
lr_warmup_steps_day: 100
|
||||||
|
|
||||||
|
beta0: 0.9
|
||||||
|
beta1: 0.999
|
||||||
|
epsilon: 0.1
|
||||||
|
weight_decay: 0.001
|
||||||
|
weight_decay_day: 0
|
||||||
|
seed: 10
|
||||||
|
grad_norm_clip_value: 5 # 减少梯度裁剪
|
||||||
|
|
||||||
|
batches_per_train_log: 50 # 更频繁的日志
|
||||||
|
batches_per_val_step: 200
|
||||||
|
log_individual_day_val_PER: true
|
||||||
|
|
||||||
|
# 禁用对抗训练进行快速测试
|
||||||
|
adversarial:
|
||||||
|
enabled: false # 先禁用对抗训练
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
data_transforms:
|
||||||
|
white_noise_std: 0.5 # 减少数据增强
|
||||||
|
constant_offset_std: 0.1
|
||||||
|
random_walk_std: 0.0
|
||||||
|
random_walk_axis: -1
|
||||||
|
static_gain_std: 0.0
|
||||||
|
random_cut: 1 # 减少随机裁剪
|
||||||
|
smooth_kernel_size: 50 # 减少平滑核大小
|
||||||
|
smooth_data: true
|
||||||
|
smooth_kernel_std: 1
|
||||||
|
|
||||||
|
neural_dim: 512
|
||||||
|
batch_size: 16 # 减少batch size从32到16
|
||||||
|
n_classes: 41
|
||||||
|
max_seq_elements: 300 # 减少序列长度
|
||||||
|
days_per_batch: 2 # 减少每批天数
|
||||||
|
seed: 1
|
||||||
|
num_dataloader_workers: 0
|
||||||
|
loader_shuffle: false
|
||||||
|
test_percentage: 0.1
|
||||||
|
dataset_dir: ../data/hdf5_data_final
|
||||||
|
|
||||||
|
# 只使用部分session进行快速测试
|
||||||
|
sessions:
|
||||||
|
- t15.2023.08.11
|
||||||
|
- t15.2023.08.13
|
||||||
|
- t15.2023.08.18
|
||||||
|
- t15.2023.08.20
|
||||||
|
|
||||||
|
dataset_probability_val:
|
||||||
|
- 0
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
1427
model_training_nnn_tpu/rnn_baseline_submission_file_valsplit.csv
Normal file
1427
model_training_nnn_tpu/rnn_baseline_submission_file_valsplit.csv
Normal file
File diff suppressed because it is too large
Load Diff
580
model_training_nnn_tpu/rnn_model.py
Normal file
580
model_training_nnn_tpu/rnn_model.py
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
class GradientReversalFn(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Gradient Reversal Layer (GRL)
|
||||||
|
Forward: identity
|
||||||
|
Backward: multiply incoming gradient by -lambda
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, lambd: float):
|
||||||
|
ctx.lambd = lambd
|
||||||
|
return x.view_as(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
return -ctx.lambd * grad_output, None
|
||||||
|
|
||||||
|
def gradient_reverse(x, lambd: float = 1.0):
|
||||||
|
return GradientReversalFn.apply(x, lambd)
|
||||||
|
|
||||||
|
class NoiseModel(nn.Module):
|
||||||
|
'''
|
||||||
|
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
neural_dim,
|
||||||
|
n_units,
|
||||||
|
n_days,
|
||||||
|
rnn_dropout=0.0,
|
||||||
|
input_dropout=0.0,
|
||||||
|
patch_size=0,
|
||||||
|
patch_stride=0):
|
||||||
|
super(NoiseModel, self).__init__()
|
||||||
|
|
||||||
|
self.neural_dim = neural_dim
|
||||||
|
self.n_units = n_units
|
||||||
|
self.n_days = n_days
|
||||||
|
self.rnn_dropout = rnn_dropout
|
||||||
|
self.input_dropout = input_dropout
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_stride = patch_stride
|
||||||
|
|
||||||
|
# Day-specific input layers
|
||||||
|
self.day_layer_activation = nn.Softsign()
|
||||||
|
# Let Accelerator handle dtype automatically for TPU compatibility
|
||||||
|
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||||
|
|
||||||
|
# Calculate input size after patching
|
||||||
|
self.input_size = self.neural_dim
|
||||||
|
if self.patch_size > 0:
|
||||||
|
self.input_size *= self.patch_size
|
||||||
|
|
||||||
|
# 2-layer GRU for noise estimation
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=self.input_size,
|
||||||
|
hidden_size=self.input_size, # Output same dimension as input
|
||||||
|
num_layers=2,
|
||||||
|
dropout=self.rnn_dropout,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize GRU parameters
|
||||||
|
for name, param in self.gru.named_parameters():
|
||||||
|
if "weight_hh" in name:
|
||||||
|
nn.init.orthogonal_(param)
|
||||||
|
if "weight_ih" in name:
|
||||||
|
nn.init.xavier_uniform_(param)
|
||||||
|
|
||||||
|
# Learnable initial hidden state - let Accelerator handle dtype
|
||||||
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
|
||||||
|
|
||||||
|
def forward(self, x, day_idx, states=None):
|
||||||
|
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
|
||||||
|
batch_size = x.size(0)
|
||||||
|
|
||||||
|
# Stack all day weights and biases upfront for static indexing
|
||||||
|
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
|
||||||
|
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
|
||||||
|
|
||||||
|
# XLA-friendly gather operation
|
||||||
|
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
|
||||||
|
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
|
||||||
|
|
||||||
|
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||||
|
# Ensure dtype consistency for mixed precision training
|
||||||
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
x = self.day_layer_activation(x)
|
||||||
|
|
||||||
|
# XLA-friendly conditional dropout
|
||||||
|
if self.input_dropout > 0:
|
||||||
|
x = self.day_layer_dropout(x)
|
||||||
|
|
||||||
|
# Apply patch processing if enabled with dtype preservation for mixed precision training
|
||||||
|
if self.patch_size > 0:
|
||||||
|
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
|
||||||
|
x_unfold = x_unfold.squeeze(2)
|
||||||
|
x_unfold = x_unfold.permute(0, 2, 3, 1)
|
||||||
|
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
|
||||||
|
# Ensure dtype consistency after patch processing operations
|
||||||
|
x = x.to(original_dtype)
|
||||||
|
|
||||||
|
gru_dtype = next(self.gru.parameters()).dtype
|
||||||
|
if x.dtype != gru_dtype:
|
||||||
|
x = x.to(gru_dtype)
|
||||||
|
|
||||||
|
# XLA-friendly hidden state initialization - avoid dynamic allocation
|
||||||
|
if states is None:
|
||||||
|
states = self.h0.expand(2, batch_size, self.input_size).contiguous()
|
||||||
|
if states.dtype != gru_dtype:
|
||||||
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
|
# Disable autocast for GRU to avoid dtype mismatches on XLA
|
||||||
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
|
return output, hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class CleanSpeechModel(nn.Module):
|
||||||
|
'''
|
||||||
|
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
neural_dim,
|
||||||
|
n_units,
|
||||||
|
n_days,
|
||||||
|
n_classes,
|
||||||
|
rnn_dropout=0.0,
|
||||||
|
input_dropout=0.0,
|
||||||
|
patch_size=0,
|
||||||
|
patch_stride=0):
|
||||||
|
super(CleanSpeechModel, self).__init__()
|
||||||
|
|
||||||
|
self.neural_dim = neural_dim
|
||||||
|
self.n_units = n_units
|
||||||
|
self.n_days = n_days
|
||||||
|
self.n_classes = n_classes
|
||||||
|
self.rnn_dropout = rnn_dropout
|
||||||
|
self.input_dropout = input_dropout
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_stride = patch_stride
|
||||||
|
|
||||||
|
# Day-specific input layers
|
||||||
|
self.day_layer_activation = nn.Softsign()
|
||||||
|
# Let Accelerator handle dtype automatically for TPU compatibility
|
||||||
|
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||||
|
|
||||||
|
# Calculate input size after patching
|
||||||
|
self.input_size = self.neural_dim
|
||||||
|
if self.patch_size > 0:
|
||||||
|
self.input_size *= self.patch_size
|
||||||
|
|
||||||
|
# 3-layer GRU for clean speech recognition
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=self.input_size,
|
||||||
|
hidden_size=self.n_units,
|
||||||
|
num_layers=3,
|
||||||
|
dropout=self.rnn_dropout,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize GRU parameters
|
||||||
|
for name, param in self.gru.named_parameters():
|
||||||
|
if "weight_hh" in name:
|
||||||
|
nn.init.orthogonal_(param)
|
||||||
|
if "weight_ih" in name:
|
||||||
|
nn.init.xavier_uniform_(param)
|
||||||
|
|
||||||
|
# Output classification layer
|
||||||
|
self.out = nn.Linear(self.n_units, self.n_classes)
|
||||||
|
nn.init.xavier_uniform_(self.out.weight)
|
||||||
|
|
||||||
|
# Learnable initial hidden state
|
||||||
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||||
|
|
||||||
|
def forward(self, x, day_idx, states=None, return_state=False):
|
||||||
|
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
|
||||||
|
batch_size = x.size(0)
|
||||||
|
|
||||||
|
# Stack all day weights and biases upfront for static indexing
|
||||||
|
all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
|
||||||
|
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
|
||||||
|
|
||||||
|
# XLA-friendly gather operation
|
||||||
|
day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
|
||||||
|
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
|
||||||
|
|
||||||
|
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||||
|
# Ensure dtype consistency for mixed precision training
|
||||||
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
x = self.day_layer_activation(x)
|
||||||
|
|
||||||
|
if self.input_dropout > 0:
|
||||||
|
x = self.day_layer_dropout(x)
|
||||||
|
|
||||||
|
# Apply patch processing if enabled with dtype preservation for mixed precision training
|
||||||
|
if self.patch_size > 0:
|
||||||
|
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
|
||||||
|
x_unfold = x_unfold.squeeze(2)
|
||||||
|
x_unfold = x_unfold.permute(0, 2, 3, 1)
|
||||||
|
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
|
||||||
|
# Ensure dtype consistency after patch processing operations
|
||||||
|
x = x.to(original_dtype)
|
||||||
|
|
||||||
|
gru_dtype = next(self.gru.parameters()).dtype
|
||||||
|
if x.dtype != gru_dtype:
|
||||||
|
x = x.to(gru_dtype)
|
||||||
|
|
||||||
|
# XLA-friendly hidden state initialization
|
||||||
|
if states is None:
|
||||||
|
states = self.h0.expand(3, batch_size, self.n_units).contiguous()
|
||||||
|
if states.dtype != gru_dtype:
|
||||||
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
logits = self.out(output)
|
||||||
|
|
||||||
|
if return_state:
|
||||||
|
return logits, hidden_states
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class NoisySpeechModel(nn.Module):
|
||||||
|
'''
|
||||||
|
Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
neural_dim,
|
||||||
|
n_units,
|
||||||
|
n_days,
|
||||||
|
n_classes,
|
||||||
|
rnn_dropout=0.0,
|
||||||
|
input_dropout=0.0,
|
||||||
|
patch_size=0,
|
||||||
|
patch_stride=0):
|
||||||
|
super(NoisySpeechModel, self).__init__()
|
||||||
|
|
||||||
|
self.neural_dim = neural_dim
|
||||||
|
self.n_units = n_units
|
||||||
|
self.n_days = n_days
|
||||||
|
self.n_classes = n_classes
|
||||||
|
self.rnn_dropout = rnn_dropout
|
||||||
|
self.input_dropout = input_dropout
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_stride = patch_stride
|
||||||
|
|
||||||
|
# Calculate input size after patching
|
||||||
|
self.input_size = self.neural_dim
|
||||||
|
if self.patch_size > 0:
|
||||||
|
self.input_size *= self.patch_size
|
||||||
|
|
||||||
|
# 2-layer GRU for noisy speech recognition
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=self.input_size,
|
||||||
|
hidden_size=self.n_units,
|
||||||
|
num_layers=2,
|
||||||
|
dropout=self.rnn_dropout,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize GRU parameters
|
||||||
|
for name, param in self.gru.named_parameters():
|
||||||
|
if "weight_hh" in name:
|
||||||
|
nn.init.orthogonal_(param)
|
||||||
|
if "weight_ih" in name:
|
||||||
|
nn.init.xavier_uniform_(param)
|
||||||
|
|
||||||
|
# Output classification layer
|
||||||
|
self.out = nn.Linear(self.n_units, self.n_classes)
|
||||||
|
nn.init.xavier_uniform_(self.out.weight)
|
||||||
|
|
||||||
|
# Learnable initial hidden state
|
||||||
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||||
|
|
||||||
|
def forward(self, x, states=None, return_state=False):
|
||||||
|
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
|
||||||
|
batch_size = x.size(0)
|
||||||
|
|
||||||
|
gru_dtype = next(self.gru.parameters()).dtype
|
||||||
|
if x.dtype != gru_dtype:
|
||||||
|
x = x.to(gru_dtype)
|
||||||
|
|
||||||
|
# XLA-friendly hidden state initialization
|
||||||
|
if states is None:
|
||||||
|
states = self.h0.expand(2, batch_size, self.n_units).contiguous()
|
||||||
|
if states.dtype != gru_dtype:
|
||||||
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
logits = self.out(output)
|
||||||
|
|
||||||
|
if return_state:
|
||||||
|
return logits, hidden_states
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TripleGRUDecoder(nn.Module):
|
||||||
|
'''
|
||||||
|
Three-model adversarial architecture for neural speech decoding
|
||||||
|
|
||||||
|
Combines:
|
||||||
|
- NoiseModel: estimates noise in neural data
|
||||||
|
- CleanSpeechModel: processes denoised signal for recognition
|
||||||
|
- NoisySpeechModel: processes noise signal for recognition
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
neural_dim,
|
||||||
|
n_units,
|
||||||
|
n_days,
|
||||||
|
n_classes,
|
||||||
|
rnn_dropout=0.0,
|
||||||
|
input_dropout=0.0,
|
||||||
|
patch_size=0,
|
||||||
|
patch_stride=0,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
neural_dim (int) - number of channels in a single timestep (e.g. 512)
|
||||||
|
n_units (int) - number of hidden units in each recurrent layer
|
||||||
|
n_days (int) - number of days in the dataset
|
||||||
|
n_classes (int) - number of classes (phonemes)
|
||||||
|
rnn_dropout (float) - percentage of units to dropout during training
|
||||||
|
input_dropout (float) - percentage of input units to dropout during training
|
||||||
|
patch_size (int) - number of timesteps to concat on initial input layer
|
||||||
|
patch_stride(int) - number of timesteps to stride over when concatenating initial input
|
||||||
|
'''
|
||||||
|
super(TripleGRUDecoder, self).__init__()
|
||||||
|
|
||||||
|
self.neural_dim = neural_dim
|
||||||
|
self.n_units = n_units
|
||||||
|
self.n_classes = n_classes
|
||||||
|
self.n_days = n_days
|
||||||
|
|
||||||
|
self.rnn_dropout = rnn_dropout
|
||||||
|
self.input_dropout = input_dropout
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.patch_stride = patch_stride
|
||||||
|
|
||||||
|
# Create the three models
|
||||||
|
self.noise_model = NoiseModel(
|
||||||
|
neural_dim=neural_dim,
|
||||||
|
n_units=n_units,
|
||||||
|
n_days=n_days,
|
||||||
|
rnn_dropout=rnn_dropout,
|
||||||
|
input_dropout=input_dropout,
|
||||||
|
patch_size=patch_size,
|
||||||
|
patch_stride=patch_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
self.clean_speech_model = CleanSpeechModel(
|
||||||
|
neural_dim=neural_dim,
|
||||||
|
n_units=n_units,
|
||||||
|
n_days=n_days,
|
||||||
|
n_classes=n_classes,
|
||||||
|
rnn_dropout=rnn_dropout,
|
||||||
|
input_dropout=input_dropout,
|
||||||
|
patch_size=patch_size,
|
||||||
|
patch_stride=patch_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
self.noisy_speech_model = NoisySpeechModel(
|
||||||
|
neural_dim=neural_dim,
|
||||||
|
n_units=n_units,
|
||||||
|
n_days=n_days,
|
||||||
|
n_classes=n_classes,
|
||||||
|
rnn_dropout=rnn_dropout,
|
||||||
|
input_dropout=input_dropout,
|
||||||
|
patch_size=patch_size,
|
||||||
|
patch_stride=patch_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training mode flag
|
||||||
|
self.training_mode = 'full' # 'full', 'inference'
|
||||||
|
|
||||||
|
def _apply_preprocessing(self, x, day_idx):
|
||||||
|
'''XLA-friendly preprocessing with static operations'''
|
||||||
|
batch_size = x.size(0)
|
||||||
|
|
||||||
|
# XLA-friendly day-specific transformation using gather instead of dynamic indexing
|
||||||
|
all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
|
||||||
|
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
|
||||||
|
|
||||||
|
# XLA-friendly gather operation
|
||||||
|
day_weights = torch.index_select(all_day_weights, 0, day_idx)
|
||||||
|
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
|
||||||
|
|
||||||
|
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||||
|
# Ensure dtype consistency for mixed precision training
|
||||||
|
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
|
||||||
|
|
||||||
|
# Apply patch processing if enabled with dtype preservation for mixed precision training
|
||||||
|
if self.patch_size > 0:
|
||||||
|
original_dtype = x_processed.dtype # Preserve original dtype for XLA/TPU compatibility
|
||||||
|
x_processed = x_processed.unsqueeze(1)
|
||||||
|
x_processed = x_processed.permute(0, 3, 1, 2)
|
||||||
|
x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
|
||||||
|
x_unfold = x_unfold.squeeze(2)
|
||||||
|
x_unfold = x_unfold.permute(0, 2, 3, 1)
|
||||||
|
x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
|
||||||
|
# Ensure dtype consistency after patch processing operations
|
||||||
|
x_processed = x_processed.to(original_dtype)
|
||||||
|
|
||||||
|
return x_processed
|
||||||
|
|
||||||
|
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
|
||||||
|
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
|
||||||
|
batch_size = x_processed.size(0)
|
||||||
|
|
||||||
|
clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
|
||||||
|
if x_processed.dtype != clean_gru_dtype:
|
||||||
|
x_processed = x_processed.to(clean_gru_dtype)
|
||||||
|
|
||||||
|
# XLA-friendly hidden state initialization with dtype consistency
|
||||||
|
if states is None:
|
||||||
|
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
if states.dtype != clean_gru_dtype:
|
||||||
|
states = states.to(clean_gru_dtype)
|
||||||
|
|
||||||
|
# GRU forward pass (skip preprocessing since input is already processed)
|
||||||
|
device_type = x_processed.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
logits = self.clean_speech_model.out(output)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def _noisy_forward_with_processed_input(self, x_processed, states=None):
|
||||||
|
'''Forward pass for NoisySpeechModel with already processed input'''
|
||||||
|
batch_size = x_processed.size(0)
|
||||||
|
|
||||||
|
noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
|
||||||
|
if x_processed.dtype != noisy_gru_dtype:
|
||||||
|
x_processed = x_processed.to(noisy_gru_dtype)
|
||||||
|
|
||||||
|
# XLA-friendly hidden state initialization with dtype consistency
|
||||||
|
if states is None:
|
||||||
|
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
if states.dtype != noisy_gru_dtype:
|
||||||
|
states = states.to(noisy_gru_dtype)
|
||||||
|
|
||||||
|
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
||||||
|
device_type = x_processed.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
||||||
|
|
||||||
|
# Classification
|
||||||
|
logits = self.noisy_speech_model.out(output)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
|
||||||
|
'''
|
||||||
|
Three-model adversarial forward pass
|
||||||
|
|
||||||
|
x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
|
||||||
|
day_idx (tensor) - tensor of day indices for each example in the batch
|
||||||
|
states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
|
||||||
|
mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
|
||||||
|
grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
|
||||||
|
'''
|
||||||
|
|
||||||
|
if mode == 'full':
|
||||||
|
# Training mode: run all three models
|
||||||
|
|
||||||
|
# 1. Noise model estimates noise in the data
|
||||||
|
noise_output, noise_hidden = self.noise_model(x, day_idx,
|
||||||
|
states['noise'] if states else None)
|
||||||
|
|
||||||
|
# 2. For residual connection, we need x in the same space as noise_output
|
||||||
|
# Apply the same preprocessing that the models use internally
|
||||||
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
clean_dtype = next(self.clean_speech_model.parameters()).dtype
|
||||||
|
if x_processed.dtype != clean_dtype:
|
||||||
|
x_processed = x_processed.to(clean_dtype)
|
||||||
|
|
||||||
|
# Ensure dtype consistency between processed input and noise output
|
||||||
|
if noise_output.dtype != clean_dtype:
|
||||||
|
noise_output = noise_output.to(clean_dtype)
|
||||||
|
|
||||||
|
# 3. Clean speech model processes denoised signal
|
||||||
|
denoised_input = x_processed - noise_output # Residual connection in processed space
|
||||||
|
# Clean speech model will apply its own preprocessing, so we pass the denoised processed data
|
||||||
|
# But we need to reverse the preprocessing first, then let clean model do its own
|
||||||
|
# Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
|
||||||
|
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||||
|
states['clean'] if states else None)
|
||||||
|
|
||||||
|
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
||||||
|
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
|
||||||
|
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
|
||||||
|
noisy_input = cast(torch.Tensor, noisy_input)
|
||||||
|
noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
|
||||||
|
if noisy_input.dtype != noisy_dtype:
|
||||||
|
noisy_input = noisy_input.to(noisy_dtype)
|
||||||
|
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
|
||||||
|
states['noisy'] if states else None)
|
||||||
|
|
||||||
|
# XLA-friendly return - use tuple instead of dict for better compilation
|
||||||
|
if return_state:
|
||||||
|
return (clean_logits, noisy_logits, noise_output), noise_hidden
|
||||||
|
return clean_logits, noisy_logits, noise_output
|
||||||
|
|
||||||
|
elif mode == 'inference':
|
||||||
|
# Inference mode: only noise model + clean speech model
|
||||||
|
|
||||||
|
# 1. Estimate noise
|
||||||
|
noise_output, noise_hidden = self.noise_model(x, day_idx,
|
||||||
|
states['noise'] if states else None)
|
||||||
|
|
||||||
|
# 2. For residual connection, we need x in the same space as noise_output
|
||||||
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
clean_dtype = next(self.clean_speech_model.parameters()).dtype
|
||||||
|
if x_processed.dtype != clean_dtype:
|
||||||
|
x_processed = x_processed.to(clean_dtype)
|
||||||
|
|
||||||
|
# Ensure dtype consistency for mixed precision residual connection
|
||||||
|
if noise_output.dtype != clean_dtype:
|
||||||
|
noise_output = noise_output.to(clean_dtype)
|
||||||
|
denoised_input = x_processed - noise_output
|
||||||
|
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||||
|
states['clean'] if states else None)
|
||||||
|
|
||||||
|
# XLA-friendly return - use tuple for consistency
|
||||||
|
if return_state:
|
||||||
|
return clean_logits, noise_hidden
|
||||||
|
return clean_logits
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
|
||||||
|
|
||||||
|
def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
|
||||||
|
'''
|
||||||
|
Apply combined gradients to noise model parameters
|
||||||
|
|
||||||
|
clean_grad (tensor) - gradients from clean speech model output layer
|
||||||
|
noisy_grad (tensor) - gradients from noisy speech model output layer
|
||||||
|
'''
|
||||||
|
# Combine gradients: negative from clean model, positive from noisy model
|
||||||
|
combined_grad = -clean_grad + noisy_grad
|
||||||
|
|
||||||
|
# Apply gradients to noise model parameters
|
||||||
|
# This is a simplified implementation - in practice you'd want more sophisticated update rules
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.noise_model.parameters():
|
||||||
|
if param.grad is not None:
|
||||||
|
# Scale the combined gradient appropriately
|
||||||
|
# This is a placeholder - you'd need to implement proper gradient mapping
|
||||||
|
param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)
|
||||||
|
|
||||||
|
def set_mode(self, mode):
|
||||||
|
'''Set the operating mode'''
|
||||||
|
self.training_mode = mode
|
||||||
|
|
||||||
|
|
952
model_training_nnn_tpu/rnn_trainer.py
Normal file
952
model_training_nnn_tpu/rnn_trainer.py
Normal file
@@ -0,0 +1,952 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
# XLA multi-threading optimization - MUST be set before importing torch_xla
|
||||||
|
# Set these environment variables early to ensure they take effect
|
||||||
|
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
|
||||||
|
# Enable XLA multi-threading for compilation speedup
|
||||||
|
os.environ.setdefault('XLA_FLAGS',
|
||||||
|
'--xla_cpu_multi_thread_eigen=true ' +
|
||||||
|
'--xla_cpu_enable_fast_math=true ' +
|
||||||
|
f'--xla_force_host_platform_device_count={os.cpu_count()}'
|
||||||
|
)
|
||||||
|
# Set PyTorch XLA threading
|
||||||
|
os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
|
||||||
|
print(f"Set XLA compilation threads to {os.cpu_count()}")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import pathlib
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from dataset import BrainToTextDataset, train_test_split_indicies
|
||||||
|
from data_augmentations import gauss_smooth
|
||||||
|
|
||||||
|
import torchaudio.functional as F # for edit distance
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
# Import Accelerate for TPU support
|
||||||
|
from accelerate import Accelerator, DataLoaderConfiguration
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
|
||||||
|
# Import XLA after setting environment variables
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
|
||||||
|
torch.backends.cudnn.deterministic = True # makes training more reproducible
|
||||||
|
torch._dynamo.config.cache_size_limit = 64
|
||||||
|
|
||||||
|
from rnn_model import TripleGRUDecoder
|
||||||
|
|
||||||
|
class BrainToTextDecoder_Trainer:
|
||||||
|
"""
|
||||||
|
This class will initialize and train a brain-to-text phoneme decoder
|
||||||
|
|
||||||
|
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args):
|
||||||
|
'''
|
||||||
|
args : dictionary of training arguments
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Configure DataLoader behavior for TPU compatibility
|
||||||
|
dataloader_config = DataLoaderConfiguration(
|
||||||
|
even_batches=False # Required for batch_size=None DataLoaders on TPU
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize Accelerator for TPU/multi-device support
|
||||||
|
self.use_xla = bool(xm.get_xla_supported_devices())
|
||||||
|
self.amp_requested = args.get('use_amp', True)
|
||||||
|
mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
|
||||||
|
|
||||||
|
self.accelerator = Accelerator(
|
||||||
|
mixed_precision=mixed_precision_mode,
|
||||||
|
gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
|
||||||
|
log_with=None, # We'll use our own logging
|
||||||
|
project_dir=args.get('output_dir', './output'),
|
||||||
|
dataloader_config=dataloader_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Trainer fields
|
||||||
|
self.args = args
|
||||||
|
self.logger = None
|
||||||
|
self.device = self.accelerator.device # Use accelerator device instead of manual device selection
|
||||||
|
self.model = None
|
||||||
|
self.optimizer = None
|
||||||
|
self.learning_rate_scheduler = None
|
||||||
|
self.ctc_loss = None
|
||||||
|
|
||||||
|
self.best_val_PER = torch.inf # track best PER for checkpointing
|
||||||
|
self.best_val_loss = torch.inf # track best loss for checkpointing
|
||||||
|
|
||||||
|
self.train_dataset = None
|
||||||
|
self.val_dataset = None
|
||||||
|
self.train_loader = None
|
||||||
|
self.val_loader = None
|
||||||
|
|
||||||
|
self.transform_args = self.args['dataset']['data_transforms']
|
||||||
|
|
||||||
|
# Adversarial training config (safe defaults if not provided)
|
||||||
|
adv_cfg = self.args.get('adversarial', {})
|
||||||
|
self.adv_enabled = adv_cfg.get('enabled', False)
|
||||||
|
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
|
||||||
|
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
|
||||||
|
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
|
||||||
|
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
if args['mode'] == 'train':
|
||||||
|
os.makedirs(self.args['output_dir'], exist_ok=True)
|
||||||
|
|
||||||
|
# Create checkpoint directory
|
||||||
|
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
|
||||||
|
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
for handler in self.logger.handlers[:]: # make a copy of the list
|
||||||
|
self.logger.removeHandler(handler)
|
||||||
|
self.logger.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
|
||||||
|
|
||||||
|
if args['mode']=='train':
|
||||||
|
# During training, save logs to file in output directory
|
||||||
|
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
self.logger.addHandler(fh)
|
||||||
|
|
||||||
|
# Always print logs to stdout
|
||||||
|
sh = logging.StreamHandler(sys.stdout)
|
||||||
|
sh.setFormatter(formatter)
|
||||||
|
self.logger.addHandler(sh)
|
||||||
|
|
||||||
|
# Log device information (managed by Accelerator)
|
||||||
|
self.logger.info(f'Using device: {self.device}')
|
||||||
|
self.logger.info(f'Accelerator state: {self.accelerator.state}')
|
||||||
|
if self.accelerator.num_processes > 1:
|
||||||
|
self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
|
||||||
|
if self.use_xla and self.amp_requested:
|
||||||
|
self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
|
||||||
|
|
||||||
|
# Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
|
||||||
|
if self.args['seed'] != -1:
|
||||||
|
set_seed(self.args['seed'])
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
self.model = TripleGRUDecoder(
|
||||||
|
neural_dim = self.args['model']['n_input_features'],
|
||||||
|
n_units = self.args['model']['n_units'],
|
||||||
|
n_days = len(self.args['dataset']['sessions']),
|
||||||
|
n_classes = self.args['dataset']['n_classes'],
|
||||||
|
rnn_dropout = self.args['model']['rnn_dropout'],
|
||||||
|
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
|
||||||
|
patch_size = self.args['model']['patch_size'],
|
||||||
|
patch_stride = self.args['model']['patch_stride'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_xla and self.amp_requested:
|
||||||
|
self.model = self.model.to(torch.bfloat16)
|
||||||
|
self.logger.info('Converted model parameters to bfloat16 for TPU training.')
|
||||||
|
|
||||||
|
self.model_dtype = next(self.model.parameters()).dtype
|
||||||
|
|
||||||
|
# Temporarily disable torch.compile for compatibility with new model architecture
|
||||||
|
# TODO: Re-enable torch.compile once model is stable
|
||||||
|
# self.logger.info("Using torch.compile")
|
||||||
|
# self.model = torch.compile(self.model)
|
||||||
|
self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
|
||||||
|
|
||||||
|
self.logger.info(f"Initialized RNN decoding model")
|
||||||
|
|
||||||
|
self.logger.info(self.model)
|
||||||
|
|
||||||
|
# Log how many parameters are in the model
|
||||||
|
total_params = sum(p.numel() for p in self.model.parameters())
|
||||||
|
self.logger.info(f"Model has {total_params:,} parameters")
|
||||||
|
|
||||||
|
# Determine how many day-specific parameters are in the model
|
||||||
|
day_params = 0
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
if 'day' in name:
|
||||||
|
day_params += param.numel()
|
||||||
|
|
||||||
|
self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
|
||||||
|
|
||||||
|
# Create datasets and dataloaders
|
||||||
|
train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
|
||||||
|
val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
|
||||||
|
|
||||||
|
# Ensure that there are no duplicate days
|
||||||
|
if len(set(train_file_paths)) != len(train_file_paths):
|
||||||
|
raise ValueError("There are duplicate sessions listed in the train dataset")
|
||||||
|
if len(set(val_file_paths)) != len(val_file_paths):
|
||||||
|
raise ValueError("There are duplicate sessions listed in the val dataset")
|
||||||
|
|
||||||
|
# Split trials into train and test sets
|
||||||
|
train_trials, _ = train_test_split_indicies(
|
||||||
|
file_paths = train_file_paths,
|
||||||
|
test_percentage = 0,
|
||||||
|
seed = self.args['dataset']['seed'],
|
||||||
|
bad_trials_dict = None,
|
||||||
|
)
|
||||||
|
_, val_trials = train_test_split_indicies(
|
||||||
|
file_paths = val_file_paths,
|
||||||
|
test_percentage = 1,
|
||||||
|
seed = self.args['dataset']['seed'],
|
||||||
|
bad_trials_dict = None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save dictionaries to output directory to know which trials were train vs val
|
||||||
|
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
|
||||||
|
json.dump({'train' : train_trials, 'val': val_trials}, f)
|
||||||
|
|
||||||
|
# Determine if a only a subset of neural features should be used
|
||||||
|
feature_subset = None
|
||||||
|
if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None:
|
||||||
|
feature_subset = self.args['dataset']['feature_subset']
|
||||||
|
self.logger.info(f'Using only a subset of features: {feature_subset}')
|
||||||
|
|
||||||
|
# train dataset and dataloader
|
||||||
|
self.train_dataset = BrainToTextDataset(
|
||||||
|
trial_indicies = train_trials,
|
||||||
|
split = 'train',
|
||||||
|
days_per_batch = self.args['dataset']['days_per_batch'],
|
||||||
|
n_batches = self.args['num_training_batches'],
|
||||||
|
batch_size = self.args['dataset']['batch_size'],
|
||||||
|
must_include_days = None,
|
||||||
|
random_seed = self.args['dataset']['seed'],
|
||||||
|
feature_subset = feature_subset
|
||||||
|
)
|
||||||
|
# Custom collate function that handles pre-batched data from our dataset
|
||||||
|
def collate_fn(batch):
|
||||||
|
# Our dataset returns full batches, so batch will be a list of single batch dict
|
||||||
|
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
|
||||||
|
if len(batch) == 1 and isinstance(batch[0], dict):
|
||||||
|
return batch[0]
|
||||||
|
else:
|
||||||
|
# Fallback for unexpected batch structure
|
||||||
|
return batch
|
||||||
|
|
||||||
|
# DataLoader configuration compatible with Accelerate
|
||||||
|
self.train_loader = DataLoader(
|
||||||
|
self.train_dataset,
|
||||||
|
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||||
|
shuffle = self.args['dataset']['loader_shuffle'],
|
||||||
|
num_workers = self.args['dataset']['num_dataloader_workers'],
|
||||||
|
pin_memory = True,
|
||||||
|
collate_fn = collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# val dataset and dataloader
|
||||||
|
self.val_dataset = BrainToTextDataset(
|
||||||
|
trial_indicies = val_trials,
|
||||||
|
split = 'test',
|
||||||
|
days_per_batch = None,
|
||||||
|
n_batches = None,
|
||||||
|
batch_size = self.args['dataset']['batch_size'],
|
||||||
|
must_include_days = None,
|
||||||
|
random_seed = self.args['dataset']['seed'],
|
||||||
|
feature_subset = feature_subset
|
||||||
|
)
|
||||||
|
# Validation DataLoader with same collate function
|
||||||
|
self.val_loader = DataLoader(
|
||||||
|
self.val_dataset,
|
||||||
|
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||||
|
shuffle = False,
|
||||||
|
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||||
|
pin_memory = True,
|
||||||
|
collate_fn = collate_fn # Use same collate function
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info("Successfully initialized datasets")
|
||||||
|
|
||||||
|
# Create optimizer, learning rate scheduler, and loss
|
||||||
|
self.optimizer = self.create_optimizer()
|
||||||
|
|
||||||
|
if self.args['lr_scheduler_type'] == 'linear':
|
||||||
|
self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
|
||||||
|
optimizer = self.optimizer,
|
||||||
|
start_factor = 1.0,
|
||||||
|
end_factor = self.args['lr_min'] / self.args['lr_max'],
|
||||||
|
total_iters = self.args['lr_decay_steps'],
|
||||||
|
)
|
||||||
|
elif self.args['lr_scheduler_type'] == 'cosine':
|
||||||
|
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
|
||||||
|
|
||||||
|
self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
|
||||||
|
|
||||||
|
# If a checkpoint is provided, then load from checkpoint
|
||||||
|
if self.args['init_from_checkpoint']:
|
||||||
|
self.load_model_checkpoint(self.args['init_checkpoint_path'])
|
||||||
|
|
||||||
|
# Set rnn and/or input layers to not trainable if specified
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
if not self.args['model']['rnn_trainable'] and 'gru' in name:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
|
||||||
|
# Let Accelerator handle everything automatically for both GPU and TPU
|
||||||
|
(
|
||||||
|
self.model,
|
||||||
|
self.optimizer,
|
||||||
|
self.learning_rate_scheduler,
|
||||||
|
self.train_loader,
|
||||||
|
self.val_loader,
|
||||||
|
) = self.accelerator.prepare(
|
||||||
|
self.model,
|
||||||
|
self.optimizer,
|
||||||
|
self.learning_rate_scheduler,
|
||||||
|
self.train_loader,
|
||||||
|
self.val_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model_dtype = next(self.model.parameters()).dtype
|
||||||
|
|
||||||
|
self.logger.info("Prepared model and dataloaders with Accelerator")
|
||||||
|
if self.adv_enabled:
|
||||||
|
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
|
||||||
|
|
||||||
|
def autocast_context(self):
|
||||||
|
"""Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
|
||||||
|
if self.device.type == 'xla':
|
||||||
|
return nullcontext()
|
||||||
|
return self.accelerator.autocast()
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
'''
|
||||||
|
Create the optimizer with special param groups
|
||||||
|
|
||||||
|
Biases and day weights should not be decayed
|
||||||
|
|
||||||
|
Day weights should have a separate learning rate
|
||||||
|
'''
|
||||||
|
bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
|
||||||
|
day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
|
||||||
|
other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
|
||||||
|
|
||||||
|
if len(day_params) != 0:
|
||||||
|
param_groups = [
|
||||||
|
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
|
||||||
|
{'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
|
||||||
|
{'params' : other_params, 'group_type' : 'other'}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
param_groups = [
|
||||||
|
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
|
||||||
|
{'params' : other_params, 'group_type' : 'other'}
|
||||||
|
]
|
||||||
|
|
||||||
|
optim = torch.optim.AdamW(
|
||||||
|
param_groups,
|
||||||
|
lr = self.args['lr_max'],
|
||||||
|
betas = (self.args['beta0'], self.args['beta1']),
|
||||||
|
eps = self.args['epsilon'],
|
||||||
|
weight_decay = self.args['weight_decay'],
|
||||||
|
fused = True
|
||||||
|
)
|
||||||
|
|
||||||
|
return optim
|
||||||
|
|
||||||
|
def create_cosine_lr_scheduler(self, optim):
|
||||||
|
lr_max = self.args['lr_max']
|
||||||
|
lr_min = self.args['lr_min']
|
||||||
|
lr_decay_steps = self.args['lr_decay_steps']
|
||||||
|
|
||||||
|
lr_max_day = self.args['lr_max_day']
|
||||||
|
lr_min_day = self.args['lr_min_day']
|
||||||
|
lr_decay_steps_day = self.args['lr_decay_steps_day']
|
||||||
|
|
||||||
|
lr_warmup_steps = self.args['lr_warmup_steps']
|
||||||
|
lr_warmup_steps_day = self.args['lr_warmup_steps_day']
|
||||||
|
|
||||||
|
def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
|
||||||
|
'''
|
||||||
|
Create lr lambdas for each param group that implement cosine decay
|
||||||
|
|
||||||
|
Different lr lambda decaying for day params vs rest of the model
|
||||||
|
'''
|
||||||
|
# Warmup phase
|
||||||
|
if current_step < warmup_steps:
|
||||||
|
return float(current_step) / float(max(1, warmup_steps))
|
||||||
|
|
||||||
|
# Cosine decay phase
|
||||||
|
if current_step < decay_steps:
|
||||||
|
progress = float(current_step - warmup_steps) / float(
|
||||||
|
max(1, decay_steps - warmup_steps)
|
||||||
|
)
|
||||||
|
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
|
||||||
|
# Scale from 1.0 to min_lr_ratio
|
||||||
|
return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
|
||||||
|
|
||||||
|
# After cosine decay is complete, maintain min_lr_ratio
|
||||||
|
return min_lr_ratio
|
||||||
|
|
||||||
|
if len(optim.param_groups) == 3:
|
||||||
|
lr_lambdas = [
|
||||||
|
lambda step: lr_lambda(
|
||||||
|
step,
|
||||||
|
lr_min / lr_max,
|
||||||
|
lr_decay_steps,
|
||||||
|
lr_warmup_steps), # biases
|
||||||
|
lambda step: lr_lambda(
|
||||||
|
step,
|
||||||
|
lr_min_day / lr_max_day,
|
||||||
|
lr_decay_steps_day,
|
||||||
|
lr_warmup_steps_day,
|
||||||
|
), # day params
|
||||||
|
lambda step: lr_lambda(
|
||||||
|
step,
|
||||||
|
lr_min / lr_max,
|
||||||
|
lr_decay_steps,
|
||||||
|
lr_warmup_steps), # rest of model weights
|
||||||
|
]
|
||||||
|
elif len(optim.param_groups) == 2:
|
||||||
|
lr_lambdas = [
|
||||||
|
lambda step: lr_lambda(
|
||||||
|
step,
|
||||||
|
lr_min / lr_max,
|
||||||
|
lr_decay_steps,
|
||||||
|
lr_warmup_steps), # biases
|
||||||
|
lambda step: lr_lambda(
|
||||||
|
step,
|
||||||
|
lr_min / lr_max,
|
||||||
|
lr_decay_steps,
|
||||||
|
lr_warmup_steps), # rest of model weights
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
|
||||||
|
|
||||||
|
return LambdaLR(optim, lr_lambdas, -1)
|
||||||
|
|
||||||
|
def load_model_checkpoint(self, load_path):
|
||||||
|
'''
|
||||||
|
Load a training checkpoint for distributed training
|
||||||
|
'''
|
||||||
|
# Load checkpoint on CPU first to avoid OOM issues
|
||||||
|
checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
|
||||||
|
|
||||||
|
# Get unwrapped model for loading state dict
|
||||||
|
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||||
|
unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
|
||||||
|
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
|
||||||
|
|
||||||
|
# Device handling is managed by Accelerator, no need to manually move to device
|
||||||
|
|
||||||
|
self.logger.info("Loaded model from checkpoint: " + load_path)
|
||||||
|
|
||||||
|
def save_model_checkpoint(self, save_path, PER, loss):
|
||||||
|
'''
|
||||||
|
Save a training checkpoint using Accelerator for distributed training
|
||||||
|
'''
|
||||||
|
# Only save on main process to avoid conflicts
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
# Unwrap model to get base model for saving
|
||||||
|
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
'model_state_dict' : unwrapped_model.state_dict(),
|
||||||
|
'optimizer_state_dict' : self.optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
|
||||||
|
'val_PER' : PER,
|
||||||
|
'val_loss' : loss
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(checkpoint, save_path)
|
||||||
|
|
||||||
|
self.logger.info("Saved model to checkpoint: " + save_path)
|
||||||
|
|
||||||
|
# Save the args file alongside the checkpoint
|
||||||
|
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
|
||||||
|
OmegaConf.save(config=self.args, f=f)
|
||||||
|
|
||||||
|
# Wait for all processes to complete checkpoint saving
|
||||||
|
self.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
def create_attention_mask(self, sequence_lengths):
|
||||||
|
|
||||||
|
max_length = torch.max(sequence_lengths).item()
|
||||||
|
|
||||||
|
batch_size = sequence_lengths.size(0)
|
||||||
|
|
||||||
|
# Create a mask for valid key positions (columns)
|
||||||
|
# Shape: [batch_size, max_length]
|
||||||
|
key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
|
||||||
|
key_mask = key_mask < sequence_lengths.unsqueeze(1)
|
||||||
|
|
||||||
|
# Expand key_mask to [batch_size, 1, 1, max_length]
|
||||||
|
# This will be broadcast across all query positions
|
||||||
|
key_mask = key_mask.unsqueeze(1).unsqueeze(1)
|
||||||
|
|
||||||
|
# Create the attention mask of shape [batch_size, 1, max_length, max_length]
|
||||||
|
# by broadcasting key_mask across all query positions
|
||||||
|
attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
|
||||||
|
|
||||||
|
# Convert boolean mask to float mask:
|
||||||
|
# - True (valid key positions) -> 0.0 (no change to attention scores)
|
||||||
|
# - False (padding key positions) -> -inf (will become 0 after softmax)
|
||||||
|
attention_mask_float = torch.where(attention_mask,
|
||||||
|
True,
|
||||||
|
False)
|
||||||
|
|
||||||
|
return attention_mask_float
|
||||||
|
|
||||||
|
def transform_data(self, features, n_time_steps, mode = 'train'):
|
||||||
|
'''
|
||||||
|
Apply various augmentations and smoothing to data
|
||||||
|
Performing augmentations is much faster on GPU than CPU
|
||||||
|
'''
|
||||||
|
|
||||||
|
# TPU and GPU should now handle data consistently with our improved DataLoader configuration
|
||||||
|
|
||||||
|
data_shape = features.shape
|
||||||
|
batch_size = data_shape[0]
|
||||||
|
channels = data_shape[-1]
|
||||||
|
|
||||||
|
# We only apply these augmentations in training
|
||||||
|
if mode == 'train':
|
||||||
|
# add static gain noise
|
||||||
|
if self.transform_args['static_gain_std'] > 0:
|
||||||
|
warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1))
|
||||||
|
warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
|
||||||
|
|
||||||
|
features = torch.matmul(features, warp_mat)
|
||||||
|
|
||||||
|
# add white noise
|
||||||
|
if self.transform_args['white_noise_std'] > 0:
|
||||||
|
features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
|
||||||
|
|
||||||
|
# add constant offset noise
|
||||||
|
if self.transform_args['constant_offset_std'] > 0:
|
||||||
|
features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
|
||||||
|
|
||||||
|
# add random walk noise
|
||||||
|
if self.transform_args['random_walk_std'] > 0:
|
||||||
|
features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
|
||||||
|
|
||||||
|
# randomly cutoff part of the data timecourse
|
||||||
|
if self.transform_args['random_cut'] > 0:
|
||||||
|
cut = np.random.randint(0, self.transform_args['random_cut'])
|
||||||
|
features = features[:, cut:, :]
|
||||||
|
n_time_steps = n_time_steps - cut
|
||||||
|
|
||||||
|
# Apply Gaussian smoothing to data
|
||||||
|
# This is done in both training and validation
|
||||||
|
if self.transform_args['smooth_data']:
|
||||||
|
features = gauss_smooth(
|
||||||
|
inputs = features,
|
||||||
|
device = self.device,
|
||||||
|
smooth_kernel_std = self.transform_args['smooth_kernel_std'],
|
||||||
|
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self, 'model_dtype'):
|
||||||
|
features = features.to(self.model_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
return features, n_time_steps
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
'''
|
||||||
|
Train the model
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Set model to train mode (specificially to make sure dropout layers are engaged)
|
||||||
|
self.model.train()
|
||||||
|
|
||||||
|
# create vars to track performance
|
||||||
|
train_losses = []
|
||||||
|
val_losses = []
|
||||||
|
val_PERs = []
|
||||||
|
val_results = []
|
||||||
|
|
||||||
|
val_steps_since_improvement = 0
|
||||||
|
|
||||||
|
# training params
|
||||||
|
save_best_checkpoint = self.args.get('save_best_checkpoint', True)
|
||||||
|
early_stopping = self.args.get('early_stopping', True)
|
||||||
|
|
||||||
|
early_stopping_val_steps = self.args['early_stopping_val_steps']
|
||||||
|
|
||||||
|
train_start_time = time.time()
|
||||||
|
|
||||||
|
# train for specified number of batches
|
||||||
|
self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
|
||||||
|
for i, batch in enumerate(self.train_loader):
|
||||||
|
|
||||||
|
self.model.train()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Train step
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
labels = batch['seq_class_ids']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
|
||||||
|
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
|
||||||
|
with self.autocast_context():
|
||||||
|
|
||||||
|
# Apply augmentations to the data
|
||||||
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
|
||||||
|
|
||||||
|
# Ensure proper dtype handling for TPU mixed precision
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
|
# Get phoneme predictions using inference mode during training
|
||||||
|
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
|
||||||
|
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||||
|
if features.dtype != self.model_dtype:
|
||||||
|
features = features.to(self.model_dtype)
|
||||||
|
|
||||||
|
# Forward pass: enable full adversarial mode if configured and past warmup
|
||||||
|
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
|
||||||
|
if use_full:
|
||||||
|
clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
|
||||||
|
else:
|
||||||
|
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||||
|
|
||||||
|
# Calculate CTC Loss
|
||||||
|
if use_full:
|
||||||
|
# Clean CTC loss
|
||||||
|
clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
|
||||||
|
clean_loss = self.ctc_loss(
|
||||||
|
clean_log_probs,
|
||||||
|
labels,
|
||||||
|
adjusted_lens,
|
||||||
|
phone_seq_lens
|
||||||
|
)
|
||||||
|
clean_loss = torch.mean(clean_loss)
|
||||||
|
|
||||||
|
# Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗)
|
||||||
|
noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
|
||||||
|
noisy_loss = self.ctc_loss(
|
||||||
|
noisy_log_probs,
|
||||||
|
labels,
|
||||||
|
adjusted_lens,
|
||||||
|
phone_seq_lens
|
||||||
|
)
|
||||||
|
noisy_loss = torch.mean(noisy_loss)
|
||||||
|
|
||||||
|
# Optional noise energy regularization
|
||||||
|
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
|
||||||
|
if self.adv_noise_l2_weight > 0.0:
|
||||||
|
noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
|
||||||
|
|
||||||
|
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
||||||
|
else:
|
||||||
|
log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
|
||||||
|
loss = self.ctc_loss(
|
||||||
|
log_probs=log_probs,
|
||||||
|
targets=labels,
|
||||||
|
input_lengths=adjusted_lens,
|
||||||
|
target_lengths=phone_seq_lens
|
||||||
|
)
|
||||||
|
loss = torch.mean(loss) # take mean loss over batches
|
||||||
|
|
||||||
|
# Use Accelerator's backward for distributed training
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
# Clip gradient using Accelerator's clip_grad_norm_
|
||||||
|
if self.args['grad_norm_clip_value'] > 0:
|
||||||
|
grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
|
||||||
|
max_norm = self.args['grad_norm_clip_value'])
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
self.learning_rate_scheduler.step()
|
||||||
|
|
||||||
|
# Save training metrics
|
||||||
|
train_step_duration = time.time() - start_time
|
||||||
|
train_losses.append(loss.detach().item())
|
||||||
|
|
||||||
|
# Incrementally log training progress
|
||||||
|
if i % self.args['batches_per_train_log'] == 0:
|
||||||
|
self.logger.info(f'Train batch {i}: ' +
|
||||||
|
f'loss: {(loss.detach().item()):.2f} ' +
|
||||||
|
f'grad norm: {grad_norm:.2f} '
|
||||||
|
f'time: {train_step_duration:.3f}')
|
||||||
|
|
||||||
|
# Incrementally run a test step
|
||||||
|
if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
|
||||||
|
self.logger.info(f"Running test after training batch: {i}")
|
||||||
|
|
||||||
|
# Calculate metrics on val data
|
||||||
|
start_time = time.time()
|
||||||
|
val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
|
||||||
|
val_step_duration = time.time() - start_time
|
||||||
|
|
||||||
|
|
||||||
|
# Log info
|
||||||
|
self.logger.info(f'Val batch {i}: ' +
|
||||||
|
f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
|
||||||
|
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
|
||||||
|
f'time: {val_step_duration:.3f}')
|
||||||
|
|
||||||
|
if self.args['log_individual_day_val_PER']:
|
||||||
|
for day in val_metrics['day_PERs'].keys():
|
||||||
|
self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
|
||||||
|
|
||||||
|
# Save metrics
|
||||||
|
val_PERs.append(val_metrics['avg_PER'])
|
||||||
|
val_losses.append(val_metrics['avg_loss'])
|
||||||
|
val_results.append(val_metrics)
|
||||||
|
|
||||||
|
# Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
|
||||||
|
new_best = False
|
||||||
|
if val_metrics['avg_PER'] < self.best_val_PER:
|
||||||
|
self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
|
||||||
|
self.best_val_PER = val_metrics['avg_PER']
|
||||||
|
self.best_val_loss = val_metrics['avg_loss']
|
||||||
|
new_best = True
|
||||||
|
elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
|
||||||
|
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
|
||||||
|
self.best_val_loss = val_metrics['avg_loss']
|
||||||
|
new_best = True
|
||||||
|
|
||||||
|
if new_best:
|
||||||
|
|
||||||
|
# Checkpoint if metrics have improved
|
||||||
|
if save_best_checkpoint:
|
||||||
|
self.logger.info(f"Checkpointing model")
|
||||||
|
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
|
||||||
|
|
||||||
|
# save validation metrics to pickle file
|
||||||
|
if self.args['save_val_metrics']:
|
||||||
|
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
|
||||||
|
pickle.dump(val_metrics, f)
|
||||||
|
|
||||||
|
val_steps_since_improvement = 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
val_steps_since_improvement +=1
|
||||||
|
|
||||||
|
# Optionally save this validation checkpoint, regardless of performance
|
||||||
|
if self.args['save_all_val_steps']:
|
||||||
|
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
|
||||||
|
self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
|
||||||
|
break
|
||||||
|
|
||||||
|
# Log final training steps
|
||||||
|
training_duration = time.time() - train_start_time
|
||||||
|
|
||||||
|
|
||||||
|
self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
|
||||||
|
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
if self.args['save_final_model']:
|
||||||
|
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
|
||||||
|
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
|
||||||
|
|
||||||
|
train_stats = {}
|
||||||
|
train_stats['train_losses'] = train_losses
|
||||||
|
train_stats['val_losses'] = val_losses
|
||||||
|
train_stats['val_PERs'] = val_PERs
|
||||||
|
train_stats['val_metrics'] = val_results
|
||||||
|
|
||||||
|
return train_stats
|
||||||
|
|
||||||
|
def validation(self, loader, return_logits = False, return_data = False):
|
||||||
|
'''
|
||||||
|
Calculate metrics on the validation dataset
|
||||||
|
'''
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
# Record metrics
|
||||||
|
if return_logits:
|
||||||
|
metrics['logits'] = []
|
||||||
|
metrics['n_time_steps'] = []
|
||||||
|
|
||||||
|
if return_data:
|
||||||
|
metrics['input_features'] = []
|
||||||
|
|
||||||
|
metrics['decoded_seqs'] = []
|
||||||
|
metrics['true_seq'] = []
|
||||||
|
metrics['phone_seq_lens'] = []
|
||||||
|
metrics['transcription'] = []
|
||||||
|
metrics['losses'] = []
|
||||||
|
metrics['block_nums'] = []
|
||||||
|
metrics['trial_nums'] = []
|
||||||
|
metrics['day_indicies'] = []
|
||||||
|
|
||||||
|
total_edit_distance = 0
|
||||||
|
total_seq_length = 0
|
||||||
|
|
||||||
|
# Calculate PER for each specific day
|
||||||
|
day_per = {}
|
||||||
|
for d in range(len(self.args['dataset']['sessions'])):
|
||||||
|
if self.args['dataset']['dataset_probability_val'][d] == 1:
|
||||||
|
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
|
||||||
|
|
||||||
|
for i, batch in enumerate(loader):
|
||||||
|
|
||||||
|
# Data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
labels = batch['seq_class_ids']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
|
||||||
|
# Determine if we should perform validation on this batch
|
||||||
|
day = day_indicies[0].item()
|
||||||
|
if self.args['dataset']['dataset_probability_val'][day] == 0:
|
||||||
|
if self.args['log_val_skip_logs']:
|
||||||
|
self.logger.info(f"Skipping validation on day {day}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
with self.autocast_context():
|
||||||
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||||
|
|
||||||
|
# Ensure proper dtype handling for TPU mixed precision
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
|
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||||
|
model_param = next(self.model.parameters()) if self.model is not None else None
|
||||||
|
if model_param is not None and features.dtype != model_param.dtype:
|
||||||
|
features = features.to(model_param.dtype)
|
||||||
|
|
||||||
|
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||||
|
|
||||||
|
val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
|
||||||
|
loss = self.ctc_loss(
|
||||||
|
val_log_probs,
|
||||||
|
labels,
|
||||||
|
adjusted_lens,
|
||||||
|
phone_seq_lens,
|
||||||
|
)
|
||||||
|
loss = torch.mean(loss)
|
||||||
|
|
||||||
|
metrics['losses'].append(loss.cpu().detach().numpy())
|
||||||
|
|
||||||
|
# Calculate PER per day and also avg over entire validation set
|
||||||
|
batch_edit_distance = 0
|
||||||
|
decoded_seqs = []
|
||||||
|
for iterIdx in range(logits.shape[0]):
|
||||||
|
decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
|
||||||
|
decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
|
||||||
|
decoded_seq = decoded_seq.cpu().detach().numpy()
|
||||||
|
decoded_seq = np.array([i for i in decoded_seq if i != 0])
|
||||||
|
|
||||||
|
trueSeq = np.array(
|
||||||
|
labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
|
||||||
|
|
||||||
|
decoded_seqs.append(decoded_seq)
|
||||||
|
|
||||||
|
day = batch['day_indicies'][0].item()
|
||||||
|
|
||||||
|
day_per[day]['total_edit_distance'] += batch_edit_distance
|
||||||
|
day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
|
||||||
|
|
||||||
|
|
||||||
|
total_edit_distance += batch_edit_distance
|
||||||
|
total_seq_length += torch.sum(phone_seq_lens)
|
||||||
|
|
||||||
|
# Record metrics
|
||||||
|
if return_logits:
|
||||||
|
metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
|
||||||
|
metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
|
||||||
|
|
||||||
|
if return_data:
|
||||||
|
metrics['input_features'].append(batch['input_features'].cpu().numpy())
|
||||||
|
|
||||||
|
metrics['decoded_seqs'].append(decoded_seqs)
|
||||||
|
metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
|
||||||
|
metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
|
||||||
|
metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
|
||||||
|
metrics['losses'].append(loss.detach().item())
|
||||||
|
metrics['block_nums'].append(batch['block_nums'].numpy())
|
||||||
|
metrics['trial_nums'].append(batch['trial_nums'].numpy())
|
||||||
|
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
|
||||||
|
|
||||||
|
if isinstance(total_seq_length, torch.Tensor):
|
||||||
|
total_length_value = float(total_seq_length.item())
|
||||||
|
else:
|
||||||
|
total_length_value = float(total_seq_length)
|
||||||
|
|
||||||
|
avg_PER = total_edit_distance / max(total_length_value, 1e-6)
|
||||||
|
|
||||||
|
metrics['day_PERs'] = day_per
|
||||||
|
metrics['avg_PER'] = avg_PER
|
||||||
|
metrics['avg_loss'] = float(np.mean(metrics['losses']))
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
|
||||||
|
'''
|
||||||
|
TPU-compatible inference method for generating phoneme logits
|
||||||
|
'''
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.autocast_context():
|
||||||
|
# Apply data transformations (no augmentation for inference)
|
||||||
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||||
|
|
||||||
|
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||||
|
if features.dtype != self.model_dtype:
|
||||||
|
features = features.to(self.model_dtype)
|
||||||
|
|
||||||
|
# Get phoneme predictions
|
||||||
|
logits = self.model(features, day_indicies, None, False, mode)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def inference_batch(self, batch, mode='inference'):
|
||||||
|
'''
|
||||||
|
Inference method for processing a full batch
|
||||||
|
'''
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# Data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.autocast_context():
|
||||||
|
# Apply data transformations (no augmentation for inference)
|
||||||
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||||
|
|
||||||
|
# Calculate adjusted sequence lengths for CTC with proper dtype handling
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
|
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||||
|
if features.dtype != self.model_dtype:
|
||||||
|
features = features.to(self.model_dtype)
|
||||||
|
|
||||||
|
# Get phoneme predictions
|
||||||
|
logits = self.model(features, day_indicies, None, False, mode)
|
||||||
|
|
||||||
|
return logits, adjusted_lens
|
27
model_training_nnn_tpu/start_tpu_training.sh
Normal file
27
model_training_nnn_tpu/start_tpu_training.sh
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# TPU XLA Multi-threading Environment Setup
|
||||||
|
# Set these BEFORE starting Python to ensure they take effect
|
||||||
|
|
||||||
|
echo "Setting up XLA multi-threading environment..."
|
||||||
|
|
||||||
|
# Get CPU core count
|
||||||
|
CPU_CORES=$(nproc)
|
||||||
|
echo "Detected $CPU_CORES CPU cores"
|
||||||
|
|
||||||
|
# Set XLA compilation flags
|
||||||
|
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
|
||||||
|
export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES
|
||||||
|
|
||||||
|
# Additional XLA optimizations
|
||||||
|
export XLA_USE_BF16=1
|
||||||
|
export TPU_CORES=8
|
||||||
|
|
||||||
|
# Print current settings
|
||||||
|
echo "XLA_FLAGS: $XLA_FLAGS"
|
||||||
|
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
|
||||||
|
echo "XLA_USE_BF16: $XLA_USE_BF16"
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
echo "Starting TPU training..."
|
||||||
|
python train_model.py --config_path rnn_args.yaml
|
162
model_training_nnn_tpu/test_simple_model.py
Normal file
162
model_training_nnn_tpu/test_simple_model.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
简化模型测试脚本 - 验证XLA编译是否正常工作
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# 设置XLA环境变量(必须在导入torch_xla之前)
|
||||||
|
os.environ['XLA_FLAGS'] = (
|
||||||
|
'--xla_cpu_multi_thread_eigen=true '
|
||||||
|
'--xla_cpu_enable_fast_math=true '
|
||||||
|
f'--xla_force_host_platform_device_count={os.cpu_count()}'
|
||||||
|
)
|
||||||
|
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
|
||||||
|
os.environ['XLA_USE_BF16'] = '1'
|
||||||
|
|
||||||
|
print(f"🔧 XLA环境变量设置:")
|
||||||
|
print(f" CPU核心数: {os.cpu_count()}")
|
||||||
|
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
|
||||||
|
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||||||
|
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
class SimpleModel(nn.Module):
|
||||||
|
"""简化的测试模型"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(512, 256)
|
||||||
|
self.gru = nn.GRU(256, 128, batch_first=True)
|
||||||
|
self.linear2 = nn.Linear(128, 41) # 41个音素类别
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.relu(self.linear1(x))
|
||||||
|
x, _ = self.gru(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_xla_compilation():
|
||||||
|
"""测试XLA编译速度"""
|
||||||
|
print("\n🚀 开始简化模型XLA编译测试...")
|
||||||
|
|
||||||
|
# 检查TPU设备
|
||||||
|
device = xm.xla_device()
|
||||||
|
print(f"📱 TPU设备: {device}")
|
||||||
|
print(f"🌍 TPU World Size: {xm.xrt_world_size()}")
|
||||||
|
|
||||||
|
# 创建简化模型
|
||||||
|
model = SimpleModel().to(device)
|
||||||
|
print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
batch_size = 8 # 小批次
|
||||||
|
seq_len = 100 # 短序列
|
||||||
|
x = torch.randn(batch_size, seq_len, 512, device=device)
|
||||||
|
|
||||||
|
print(f"📥 输入形状: {x.shape}")
|
||||||
|
|
||||||
|
# 首次前向传播 - 触发XLA编译
|
||||||
|
print(f"🔄 开始首次前向传播 (XLA编译)...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(x)
|
||||||
|
|
||||||
|
compile_time = time.time() - start_time
|
||||||
|
print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒")
|
||||||
|
print(f"📤 输出形状: {output.shape}")
|
||||||
|
|
||||||
|
# 再次前向传播 - 使用编译后的图
|
||||||
|
print(f"🔄 第二次前向传播 (使用编译后的图)...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output2 = model(x)
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒")
|
||||||
|
|
||||||
|
# 性能对比
|
||||||
|
speedup = compile_time / execution_time if execution_time > 0 else float('inf')
|
||||||
|
print(f"\n📈 性能分析:")
|
||||||
|
print(f" 编译时间: {compile_time:.2f}秒")
|
||||||
|
print(f" 执行时间: {execution_time:.4f}秒")
|
||||||
|
print(f" 加速比: {speedup:.1f}x")
|
||||||
|
|
||||||
|
if compile_time < 60: # 1分钟内编译完成
|
||||||
|
print("✅ XLA编译正常!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ XLA编译过慢,可能有问题")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_training_step():
|
||||||
|
"""测试训练步骤"""
|
||||||
|
print("\n🎯 测试简化训练步骤...")
|
||||||
|
|
||||||
|
device = xm.xla_device()
|
||||||
|
model = SimpleModel().to(device)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
# 创建训练数据
|
||||||
|
x = torch.randn(4, 50, 512, device=device)
|
||||||
|
labels = torch.randint(0, 41, (4, 50), device=device)
|
||||||
|
|
||||||
|
print(f"🔄 开始训练步骤 (包含反向传播)...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
outputs = model(x)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = criterion(outputs.view(-1, 41), labels.view(-1))
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")
|
||||||
|
|
||||||
|
return step_time < 120 # 2分钟内完成
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=" * 60)
|
||||||
|
print("🧪 XLA编译快速测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 测试1: 简单模型编译
|
||||||
|
compilation_ok = test_xla_compilation()
|
||||||
|
|
||||||
|
if compilation_ok:
|
||||||
|
# 测试2: 训练步骤
|
||||||
|
training_ok = test_training_step()
|
||||||
|
|
||||||
|
if training_ok:
|
||||||
|
print("\n✅ 所有测试通过! 可以尝试完整模型训练")
|
||||||
|
print("💡 建议:")
|
||||||
|
print(" 1. 确保有足够内存 (32GB+)")
|
||||||
|
print(" 2. 减小batch_size (比如从32改为16)")
|
||||||
|
print(" 3. 使用gradient_accumulation_steps补偿")
|
||||||
|
else:
|
||||||
|
print("\n⚠️ 训练步骤较慢,建议优化")
|
||||||
|
else:
|
||||||
|
print("\n❌ XLA编译有问题,需要检查环境")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n💥 测试失败: {e}")
|
||||||
|
print("💡 可能的问题:")
|
||||||
|
print(" - TPU资源不可用")
|
||||||
|
print(" - PyTorch XLA安装问题")
|
||||||
|
print(" - 内存不足")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
25
model_training_nnn_tpu/train_model.py
Normal file
25
model_training_nnn_tpu/train_model.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import argparse
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from rnn_trainer import BrainToTextDecoder_Trainer
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
|
||||||
|
parser.add_argument('--config_path', default='rnn_args.yaml',
|
||||||
|
help='Path to configuration file (default: rnn_args.yaml)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
config = OmegaConf.load(args.config_path)
|
||||||
|
|
||||||
|
# Initialize trainer
|
||||||
|
trainer = BrainToTextDecoder_Trainer(config)
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
print("Training completed successfully!")
|
||||||
|
print(f"Best validation PER: {trainer.best_val_PER:.5f}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Reference in New Issue
Block a user