diff --git a/model_training_nnn_tpu/README.md b/model_training_nnn_tpu/README.md
new file mode 100644
index 0000000..5d20d83
--- /dev/null
+++ b/model_training_nnn_tpu/README.md
@@ -0,0 +1,79 @@
+# TPU-Optimized Brain-to-Text Model Training
+
+This directory contains TPU-optimized code for training the brain-to-text RNN model with advanced adversarial training architecture. The model is based on "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), enhanced with three-model adversarial training and comprehensive XLA optimizations for efficient TPU training.
+
+## Key Features
+
+- **Triple-Model Adversarial Architecture**: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
+- **XLA/TPU Optimizations**: Comprehensive optimizations for fast compilation and efficient TPU utilization
+- **Mixed Precision Training**: bfloat16 support with full dtype consistency
+- **Distributed Training**: 8-core TPU support with Accelerate library integration
+- **687M Parameters**: Large-scale model with patch processing and day-specific adaptations
+
+For detailed technical documentation, see [TPU_MODEL_SUMMARY.md](TPU_MODEL_SUMMARY.md).
+
+## Setup
+1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
+
+2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. See the main [README.md](../README.md) file for more details on the included datasets and the proper `data` directory structure.
+
+## TPU Training
+
+### Triple-Model Adversarial Architecture
+This implementation features an advanced three-model adversarial training system:
+- **NoiseModel**: 2-layer GRU that estimates noise in neural data
+- **CleanSpeechModel**: 3-layer GRU that processes denoised signals for speech recognition
+- **NoisySpeechModel**: 2-layer GRU that processes noise signals for adversarial training
+
+The architecture uses residual connections and gradient reversal layers (GRL) to improve robustness. All models include day-specific input layers (512x512 linear with softsign activation), patch processing (14 timesteps), and are optimized for XLA compilation on TPU.
+
+### Training Methods
+
+#### Option 1: Direct Training
+```bash
+conda activate b2txt25
+python train_model.py --config_path rnn_args.yaml
+```
+
+#### Option 2: Launcher Script (Recommended)
+```bash
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+```
+
+#### Option 3: Accelerate
+```bash
+accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
+```
+
+The model trains for 120,000 mini-batches with mixed precision (bfloat16) and distributed training across 8 TPU cores. Expected training time varies based on TPU type and configuration. All hyperparameters are specified in [`rnn_args.yaml`](rnn_args.yaml).
+
+## Model Configuration
+
+### Key Configuration Files
+- **`rnn_args.yaml`**: Main training configuration with adversarial training settings
+- **`accelerate_config_tpu.yaml`**: Accelerate library configuration for TPU
+- **`launch_tpu_training.py`**: Convenient TPU training launcher
+
+### Adversarial Training Settings
+```yaml
+adversarial:
+ enabled: true
+ grl_lambda: 0.5 # Gradient Reversal Layer strength
+ noisy_loss_weight: 0.2 # Weight for noisy branch CTC loss
+ noise_l2_weight: 0.0 # L2 regularization on noise output
+ warmup_steps: 0 # Steps before enabling adversarial training
+```
+
+### TPU-Specific Settings
+```yaml
+use_tpu: true
+num_tpu_cores: 8
+gradient_accumulation_steps: 2
+use_amp: true # bfloat16 mixed precision
+batch_size: 32 # Per-core batch size
+num_dataloader_workers: 0 # Required for TPU
+```
+
+## Evaluation
+
+Model evaluation using the trained TripleGRUDecoder requires the language model pipeline. Please refer to the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are currently being adapted for TPU compatibility.
diff --git a/model_training_nnn_tpu/TPU_MODEL_SUMMARY.md b/model_training_nnn_tpu/TPU_MODEL_SUMMARY.md
new file mode 100644
index 0000000..ebd489f
--- /dev/null
+++ b/model_training_nnn_tpu/TPU_MODEL_SUMMARY.md
@@ -0,0 +1,183 @@
+# TPU优化的Brain-to-Text模型代码总结
+
+## 项目概述
+
+这个目录包含了专门为TPU训练优化的Brain-to-Text RNN模型代码,基于发表在《新英格兰医学杂志》(2024)的"An Accurate and Rapidly Calibrating Speech Neuroprosthesis"论文。该模型将大脑语音运动皮层的神经信号转换为文本,使用RNN模型和n-gram语言模型。
+
+## 核心架构改进
+
+### 三模型对抗训练架构 (TripleGRUDecoder)
+
+替代原来的单一GRU模型,新架构包含三个协同工作的子模型:
+
+1. **NoiseModel** (2层GRU)
+ - 估计神经数据中的噪声
+ - 输入维度:512 → 输出维度:与输入相同
+ - 作用:从原始信号中分离噪声成分
+
+2. **CleanSpeechModel** (3层GRU + 分类层)
+ - 处理去噪后的信号进行语音识别
+ - 包含day-specific输入层
+ - 输出:41类音素的logits
+
+3. **NoisySpeechModel** (2层GRU + 分类层)
+ - 直接处理噪声信号进行语音识别
+ - 用于对抗训练,提高NoiseModel的鲁棒性
+ - 输出:41类音素的logits
+
+### 对抗训练机制
+
+- **残差连接**: `denoised_input = x_processed - noise_output`
+- **梯度反转层(GRL)**: 在训练时对噪声输出应用梯度反转
+- **多目标损失**: 结合clean和noisy分支的CTC损失
+
+## TPU/XLA优化特性
+
+### 1. XLA友好的操作设计
+
+**静态张量操作替代动态操作**:
+```python
+# 优化前 (XLA不友好):
+day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
+
+# 优化后 (XLA友好):
+all_day_weights = torch.stack(list(self.day_weights), dim=0)
+day_weights = torch.index_select(all_day_weights, 0, day_idx)
+```
+
+**XLA原语操作**:
+```python
+# 使用batch matrix multiplication (bmm)替代einsum
+x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
+```
+
+### 2. 混合精度训练的数据类型一致性
+
+**全面的dtype一致性处理**:
+- 基础操作中的dtype转换
+- 补丁处理过程中的dtype保持
+- 对抗训练残差连接的dtype匹配
+- 梯度反转层的dtype处理
+- 隐藏状态初始化的dtype一致性
+
+### 3. 内存和编译优化
+
+- **禁用autocast**: 在GRU操作中禁用自动混合精度以避免dtype冲突
+- **静态形状**: 避免动态批次大小分配
+- **元组返回**: 使用元组替代字典以获得更好的XLA编译性能
+
+## 关键文件结构
+
+### 核心训练文件
+
+- **`rnn_model.py`**: 包含TripleGRUDecoder和三个子模型的完整实现,具有XLA优化
+- **`rnn_trainer.py`**: TPU训练器,集成Accelerate库,支持分布式训练
+- **`train_model.py`**: 简洁的训练启动脚本
+- **`rnn_args.yaml`**: TPU训练配置文件
+
+### TPU特定文件
+
+- **`accelerate_config_tpu.yaml`**: Accelerate库的TPU配置
+- **`launch_tpu_training.py`**: TPU训练的便捷启动脚本
+- **`TPU_SETUP_GUIDE.md`**: TPU环境设置指南
+
+### 辅助文件
+
+- **`dataset.py`**: 数据集加载和批处理
+- **`data_augmentations.py`**: 数据增强工具
+- **`evaluate_model_helpers.py`**: 评估工具函数
+
+## 训练配置亮点
+
+### TPU特定设置
+```yaml
+# TPU分布式训练设置
+use_tpu: true
+num_tpu_cores: 8
+gradient_accumulation_steps: 2
+use_amp: true # bfloat16混合精度
+
+# 优化的批次配置
+batch_size: 32 # 每个TPU核心的批次大小
+num_dataloader_workers: 0 # TPU上设为0避免多进程问题
+```
+
+### 对抗训练配置
+```yaml
+adversarial:
+ enabled: true
+ grl_lambda: 0.5 # 梯度反转强度
+ noisy_loss_weight: 0.2 # 噪声分支损失权重
+ noise_l2_weight: 0.0 # 噪声输出L2正则化
+ warmup_steps: 0 # 对抗训练预热步数
+```
+
+## 模型规模
+
+- **总参数**: ~687M个参数
+- **神经输入**: 512特征 (每个电极2个特征 × 256个电极)
+- **GRU隐藏单元**: 768个/层
+- **输出类别**: 41个音素
+- **补丁处理**: 14个时间步的补丁,步长为4
+
+## 数据流
+
+1. **输入**: 512维神经特征,20ms分辨率
+2. **Day-specific变换**: 每日特定的线性变换和softsign激活
+3. **补丁处理**: 将14个时间步连接为更大的输入向量
+4. **三模型处理**:
+ - NoiseModel估计噪声
+ - CleanSpeechModel处理去噪信号
+ - NoisySpeechModel处理噪声信号(仅训练时)
+5. **输出**: CTC兼容的音素logits
+
+## 训练流程
+
+### 推理模式 (`mode='inference'`):
+- 只使用NoiseModel + CleanSpeechModel
+- 计算: `clean_logits = CleanSpeechModel(x - NoiseModel(x))`
+
+### 完整模式 (`mode='full'`, 训练时):
+- 使用所有三个模型
+- 对抗训练与梯度反转
+- 多目标CTC损失
+
+## 性能特点
+
+- **编译优化**: XLA优化实现更快的TPU编译
+- **内存效率**: bfloat16混合精度减少内存使用
+- **分布式训练**: 支持8核心TPU并行训练
+- **数据增强**: 高斯平滑、白噪声、时间抖动等
+
+## 使用方法
+
+### 基本训练
+```bash
+python train_model.py --config_path rnn_args.yaml
+```
+
+### 使用启动脚本
+```bash
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+```
+
+### 使用Accelerate
+```bash
+accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
+```
+
+## 与原始模型的兼容性
+
+- 保持相同的数学运算和模型架构
+- 保留所有原始接口
+- 支持'inference'和'full'两种模式
+- 向后兼容现有训练脚本
+
+## 技术创新点
+
+1. **三模型对抗架构**: 创新的噪声估计和去噪方法
+2. **XLA优化**: 全面的TPU编译优化
+3. **混合精度一致性**: 解决了复杂对抗训练中的dtype冲突
+4. **分布式训练集成**: 无缝的多核心TPU支持
+
+这个TPU优化版本保持了原始模型的准确性,同时显著提高了训练效率和可扩展性,特别适合大规模神经解码任务的训练。
\ No newline at end of file
diff --git a/model_training_nnn_tpu/TPU_SETUP_GUIDE.md b/model_training_nnn_tpu/TPU_SETUP_GUIDE.md
new file mode 100644
index 0000000..fed0b80
--- /dev/null
+++ b/model_training_nnn_tpu/TPU_SETUP_GUIDE.md
@@ -0,0 +1,204 @@
+# TPU Training Setup Guide for Brain-to-Text RNN
+
+This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
+
+## Prerequisites
+
+### 1. Install PyTorch XLA for TPU Support
+```bash
+# Install PyTorch XLA (adjust version as needed)
+pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+
+# Or for specific PyTorch version:
+pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+### 2. Install Accelerate Library
+```bash
+pip install accelerate
+```
+
+### 3. Verify TPU Access
+```bash
+# Check if TPU is available
+python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
+```
+
+## Configuration Setup
+
+### 1. Enable TPU in Configuration File
+
+Update your `rnn_args.yaml` file with TPU settings:
+
+```yaml
+# TPU and distributed training settings
+use_tpu: true # Enable TPU training
+num_tpu_cores: 8 # Number of TPU cores (8 for v3-8 or v4-8)
+gradient_accumulation_steps: 1 # Gradient accumulation for large effective batch size
+dataloader_num_workers: 0 # Must be 0 for TPU to avoid multiprocessing issues
+use_amp: true # Enable mixed precision (bfloat16)
+
+# Adjust batch size for multi-core TPU
+dataset:
+ batch_size: 8 # Per-core batch size (total = 8 cores × 8 = 64)
+```
+
+### 2. TPU-Optimized Hyperparameters
+
+Recommended adjustments for TPU training:
+
+```yaml
+# Learning rate scaling for distributed training
+lr_max: 0.005 # May need to scale with number of cores
+lr_max_day: 0.005
+
+# Batch size considerations
+dataset:
+ batch_size: 8 # Per-core batch size
+ days_per_batch: 4 # Keep consistent across cores
+```
+
+## Training Launch Options
+
+### Method 1: Using the TPU Launch Script (Recommended)
+
+```bash
+# Basic TPU training with 8 cores
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+# Check TPU environment only
+python launch_tpu_training.py --check_only
+
+# Custom configuration file
+python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
+```
+
+### Method 2: Direct Accelerate Launch
+
+```bash
+# Configure accelerate (one-time setup)
+accelerate config
+
+# Or use provided TPU config
+export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
+
+# Launch training
+accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
+```
+
+### Method 3: Manual XLA Launch (Advanced)
+
+```bash
+# Set TPU environment variables
+export TPU_CORES=8
+export XLA_USE_BF16=1
+
+# Launch with PyTorch XLA
+python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
+```
+
+## Key TPU Features Implemented
+
+### 1. Distributed Training Support
+- Automatic model parallelization across 8 TPU cores
+- Synchronized gradient updates across all cores
+- Proper checkpoint saving/loading for distributed training
+
+### 2. Mixed Precision Training
+- Automatic bfloat16 precision for TPU optimization
+- Faster training with maintained numerical stability
+- Reduced memory usage
+
+### 3. TPU-Optimized Data Loading
+- Single-threaded data loading (num_workers=0) for TPU compatibility
+- Automatic data distribution across TPU cores
+- Efficient batch processing
+
+### 4. Inference Support
+- TPU-compatible inference methods added to trainer class
+- `inference()` and `inference_batch()` methods for production use
+- Automatic mixed precision during inference
+
+## Performance Optimization Tips
+
+### 1. Batch Size Tuning
+- Start with total batch size = 64 (8 cores × 8 per core)
+- Increase gradually if memory allows
+- Monitor TPU utilization with `top` command
+
+### 2. Gradient Accumulation
+- Use `gradient_accumulation_steps` to simulate larger batch sizes
+- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
+
+### 3. Learning Rate Scaling
+- Consider scaling learning rate with number of cores
+- Linear scaling: `lr_new = lr_base × num_cores`
+- May need warmup adjustment for large batch training
+
+### 4. Memory Management
+- TPU v3-8: 128GB HBM memory total
+- TPU v4-8: 512GB HBM memory total
+- Monitor memory usage to avoid OOM errors
+
+## Monitoring and Debugging
+
+### 1. TPU Utilization
+```bash
+# Monitor TPU usage
+watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
+```
+
+### 2. Training Logs
+- Training logs include device information and core count
+- Monitor validation metrics across all cores
+- Check for synchronization issues in distributed training
+
+### 3. Common Issues and Solutions
+
+**Issue**: "No TPU devices found"
+- **Solution**: Verify TPU runtime is started and accessible
+
+**Issue**: "DataLoader workers > 0 causes hangs"
+- **Solution**: Set `dataloader_num_workers: 0` in config
+
+**Issue**: "Mixed precision errors"
+- **Solution**: Ensure `use_amp: true` and PyTorch XLA supports bfloat16
+
+**Issue**: "Gradient synchronization timeouts"
+- **Solution**: Check network connectivity between TPU cores
+
+## Example Training Command
+
+```bash
+# Complete TPU training example
+cd model_training_nnn
+
+# 1. Update config for TPU
+vim rnn_args.yaml # Set use_tpu: true, num_tpu_cores: 8
+
+# 2. Launch TPU training
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+# 3. Monitor training progress
+tail -f trained_models/baseline_rnn/training_log
+```
+
+## Configuration Reference
+
+### Required TPU Settings
+```yaml
+use_tpu: true
+num_tpu_cores: 8
+dataloader_num_workers: 0
+use_amp: true
+```
+
+### Optional TPU Optimizations
+```yaml
+gradient_accumulation_steps: 1
+dataset:
+ batch_size: 8 # Per-core batch size
+mixed_precision: bf16
+```
+
+This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.
\ No newline at end of file
diff --git a/model_training_nnn_tpu/accelerate_config_tpu.yaml b/model_training_nnn_tpu/accelerate_config_tpu.yaml
new file mode 100644
index 0000000..0b48dab
--- /dev/null
+++ b/model_training_nnn_tpu/accelerate_config_tpu.yaml
@@ -0,0 +1,26 @@
+# Accelerate Configuration for TPU Training
+# This file configures Accelerate library for 8-core TPU training
+# with mixed precision (bfloat16) support
+
+compute_environment: TPU
+distributed_type: TPU
+tpu_name: null # Will use default TPU
+tpu_zone: null # Will use default zone
+
+# Mixed precision settings (use bfloat16 for TPU)
+mixed_precision: bf16
+
+# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
+num_processes: 8
+
+# Enable TPU debugging (set to false for production)
+tpu_use_cluster: false
+tpu_use_sudo: false
+
+# Logging settings
+main_process_port: null
+machine_rank: 0
+num_machines: 1
+
+# Enable automatic optimization
+use_cpu: false
\ No newline at end of file
diff --git a/model_training_nnn_tpu/check_xla_threads.py b/model_training_nnn_tpu/check_xla_threads.py
new file mode 100644
index 0000000..ceaf3ec
--- /dev/null
+++ b/model_training_nnn_tpu/check_xla_threads.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
+"""
+XLA Multi-threading Diagnostic Script
+检查XLA编译是否正确使用多CPU核心
+"""
+
+import os
+import psutil
+import time
+import threading
+from concurrent.futures import ThreadPoolExecutor
+
+def set_xla_environment():
+ """设置XLA环境变量"""
+ cpu_count = os.cpu_count()
+
+ # 设置XLA环境变量
+ os.environ['XLA_FLAGS'] = (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={cpu_count}'
+ )
+ os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
+
+ print(f"🔧 设置XLA环境变量:")
+ print(f" CPU核心数: {cpu_count}")
+ print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
+ print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
+ print("-" * 60)
+
+def monitor_cpu_usage(duration=30, interval=1):
+ """监控CPU使用情况"""
+ print(f"🔍 监控CPU使用情况 {duration}秒...")
+
+ cpu_usage_data = []
+ start_time = time.time()
+
+ while time.time() - start_time < duration:
+ # 获取每个CPU核心的使用率
+ cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
+ cpu_usage_data.append(cpu_percent_per_core)
+
+ # 实时显示
+ active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
+ print(f"活跃核心数: {active_cores}/{len(cpu_percent_per_core)}, "
+ f"平均使用率: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
+ end='\r')
+
+ print() # 换行
+
+ # 分析结果
+ if cpu_usage_data:
+ avg_usage_per_core = [
+ sum(core_data) / len(cpu_usage_data)
+ for core_data in zip(*cpu_usage_data)
+ ]
+
+ active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
+ max_usage = max(avg_usage_per_core)
+
+ print(f"📊 CPU使用分析:")
+ print(f" 活跃的CPU核心: {active_cores}/{len(avg_usage_per_core)}")
+ print(f" 最高平均使用率: {max_usage:.1f}%")
+
+ for i, usage in enumerate(avg_usage_per_core):
+ status = "🟢" if usage > 10 else "🔴" if usage > 5 else "⚫"
+ print(f" CPU核心 {i}: {usage:.1f}% {status}")
+
+ return active_cores > 1
+
+ return False
+
+def test_xla_compilation():
+ """测试XLA编译"""
+ print(f"🚀 开始XLA编译测试...")
+
+ try:
+ import torch
+ import torch_xla.core.xla_model as xm
+
+ print(f"✅ PyTorch XLA导入成功")
+ print(f" XLA设备: {xm.xla_device()}")
+ print(f" XLA world size: {xm.xrt_world_size()}")
+
+ # 创建一个简单的计算图进行编译
+ device = xm.xla_device()
+
+ print(f"🔄 创建测试计算图...")
+ x = torch.randn(100, 100, device=device)
+ y = torch.randn(100, 100, device=device)
+
+ print(f"🔄 执行矩阵运算 (将触发XLA编译)...")
+
+ # 启动CPU监控
+ monitor_thread = threading.Thread(
+ target=lambda: monitor_cpu_usage(20, 0.5),
+ daemon=True
+ )
+ monitor_thread.start()
+
+ # 执行计算,触发编译
+ for i in range(10):
+ z = torch.matmul(x, y)
+ z = torch.relu(z)
+ z = torch.matmul(z, x.T)
+ if i == 0:
+ print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
+ elif i == 5:
+ print(f"🔄 第6次计算完成...")
+
+ # 等待监控完成
+ monitor_thread.join(timeout=25)
+
+ print(f"✅ XLA测试完成")
+
+ return True
+
+ except ImportError as e:
+ print(f"❌ PyTorch XLA导入失败: {e}")
+ return False
+ except Exception as e:
+ print(f"❌ XLA测试失败: {e}")
+ return False
+
+def main():
+ print("=" * 60)
+ print("🧪 XLA多线程编译诊断工具")
+ print("=" * 60)
+
+ # 1. 设置环境
+ set_xla_environment()
+
+ # 2. 测试XLA编译并监控CPU
+ success = test_xla_compilation()
+
+ print("=" * 60)
+ if success:
+ print("✅ 诊断完成")
+ print("💡 如果看到多个CPU核心被激活,说明XLA多线程工作正常")
+ print("💡 如果只有1-2个核心活跃,可能需要其他优化方法")
+ else:
+ print("❌ 诊断失败")
+ print("💡 请检查PyTorch XLA安装和TPU环境")
+
+ print("=" * 60)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/data_augmentations.py b/model_training_nnn_tpu/data_augmentations.py
new file mode 100644
index 0000000..7f4505a
--- /dev/null
+++ b/model_training_nnn_tpu/data_augmentations.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.ndimage import gaussian_filter1d
+
+def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
+ """
+ Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis.
+ Args:
+ inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N.
+ Assumed to already be on the correct device (e.g., GPU).
+ kernelSD (float): Standard deviation of the Gaussian smoothing kernel.
+ padding (str): Padding mode, either 'same' or 'valid'.
+ device (str): Device to use for computation (e.g., 'cuda' or 'cpu').
+ Returns:
+ smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N.
+ """
+ # Get Gaussian kernel
+ inp = np.zeros(smooth_kernel_size, dtype=np.float32)
+ inp[smooth_kernel_size // 2] = 1
+ gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
+ validIdx = np.argwhere(gaussKernel > 0.01)
+ gaussKernel = gaussKernel[validIdx]
+ gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))
+
+ # Convert to tensor
+ gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
+ gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]
+
+ # Prepare convolution
+ B, T, C = inputs.shape
+ inputs = inputs.permute(0, 2, 1) # [B, C, T]
+ gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]
+
+ # Perform convolution
+ smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
+ return smoothed.permute(0, 2, 1) # [B, T, C]
\ No newline at end of file
diff --git a/model_training_nnn_tpu/dataset.py b/model_training_nnn_tpu/dataset.py
new file mode 100644
index 0000000..086370e
--- /dev/null
+++ b/model_training_nnn_tpu/dataset.py
@@ -0,0 +1,336 @@
+import os
+import torch
+from torch.utils.data import Dataset
+import h5py
+import numpy as np
+from torch.nn.utils.rnn import pad_sequence
+import math
+
+class BrainToTextDataset(Dataset):
+ '''
+ Dataset for brain-to-text data
+
+ Returns an entire batch of data instead of a single example
+ '''
+
+ def __init__(
+ self,
+ trial_indicies,
+ n_batches,
+ split = 'train',
+ batch_size = 64,
+ days_per_batch = 1,
+ random_seed = -1,
+ must_include_days = None,
+ feature_subset = None
+ ):
+ '''
+ trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
+ n_batches: (int) - number of random training batches to create
+ split: (string) - string specifying if this is a train or test dataset
+ batch_size: (int) - number of examples to include in batch returned from __getitem_()
+ days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
+ to individual day layers in the GRU are not excesively noisy. Validation data will always have 1 day per batch
+ random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
+ must_include_days ([int]) - list of days that must be included in every batch
+ feature_subset ([int]) - list of neural feature indicies that should be the only features included in the neural data
+ '''
+
+ # Set random seed for reproducibility
+ if random_seed != -1:
+ np.random.seed(random_seed)
+ torch.manual_seed(random_seed)
+
+ self.split = split
+
+ # Ensure the split is valid
+ if self.split not in ['train', 'test']:
+ raise ValueError(f'split must be either "train" or "test". Received {self.split}')
+
+ self.days_per_batch = days_per_batch
+
+ self.batch_size = batch_size
+
+ self.n_batches = n_batches
+
+ self.days = {}
+ self.n_trials = 0
+ self.trial_indicies = trial_indicies
+ self.n_days = len(trial_indicies.keys())
+
+ self.feature_subset = feature_subset
+
+ # Calculate total number of trials in the dataset
+ for d in trial_indicies:
+ self.n_trials += len(trial_indicies[d]['trials'])
+
+ if must_include_days is not None and len(must_include_days) > days_per_batch:
+ raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}')
+
+ if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
+ raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} and but only {self.n_days} in the dataset')
+
+ if must_include_days is not None:
+ # Map must_include_days to correct indicies if they are negative
+ for i, d in enumerate(must_include_days):
+ if d < 0:
+ must_include_days[i] = self.n_days + d
+
+ self.must_include_days = must_include_days
+
+ # Ensure that the days_per_batch is not greater than the number of days in the dataset. Raise error
+ if self.split == 'train' and self.days_per_batch > self.n_days:
+ raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
+
+
+ if self.split == 'train':
+ self.batch_index = self.create_batch_index_train()
+ else:
+ self.batch_index = self.create_batch_index_test()
+ self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
+
+ def __len__(self):
+ '''
+ How many batches are in this dataset.
+ Because training data is sampled randomly, there is no fixed dataset length,
+ however this method is required for DataLoader to work
+ '''
+ return self.n_batches if self.n_batches is not None else 0
+
+ def __getitem__(self, idx):
+ '''
+ Gets an entire batch of data from the dataset, not just a single item
+ '''
+ batch = {
+ 'input_features' : [],
+ 'seq_class_ids' : [],
+ 'n_time_steps' : [],
+ 'phone_seq_lens' : [],
+ 'day_indicies' : [],
+ 'transcriptions' : [],
+ 'block_nums' : [],
+ 'trial_nums' : [],
+ }
+
+ index = self.batch_index[idx]
+
+ # Iterate through each day in the index
+ for d in index.keys():
+
+ # Open the hdf5 file for that day
+ with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
+
+ # For each trial in the selected trials in that day
+ for t in index[d]:
+
+ try:
+ g = f[f'trial_{t:04d}']
+
+ # Remove features is neccessary
+ input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
+ if self.feature_subset:
+ input_features = input_features[:,self.feature_subset]
+
+ batch['input_features'].append(input_features)
+
+ batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
+ batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
+ batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
+ batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
+ batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
+ batch['block_nums'].append(g.attrs['block_num'])
+ batch['trial_nums'].append(g.attrs['trial_num'])
+
+ except Exception as e:
+ print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
+ continue
+
+ # Pad data to form a cohesive batch - ensure bf16 dtype is preserved
+ batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
+ batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
+
+ batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
+ batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
+ batch['day_indicies'] = torch.tensor(batch['day_indicies'])
+ batch['transcriptions'] = torch.stack(batch['transcriptions'])
+ batch['block_nums'] = torch.tensor(batch['block_nums'])
+ batch['trial_nums'] = torch.tensor(batch['trial_nums'])
+
+ return batch
+
+
+ def create_batch_index_train(self):
+ '''
+ Create an index that maps a batch_number to batch_size number of trials
+
+ Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
+ (or as even as possible if batch_size is not divisible by days_per_batch)
+ '''
+
+ batch_index = {}
+
+ # Precompute the days that are not in must_include_days
+ if self.must_include_days is not None:
+ non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
+
+ for batch_idx in range(self.n_batches):
+ batch = {}
+
+ # Which days will be used for this batch. Picked randomly without replacement
+ # TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
+
+ # If must_include_days is not empty, we will use those days and then randomly sample the rest
+ if self.must_include_days is not None and len(self.must_include_days) > 0:
+
+ days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
+
+ # Otherwise we will select random days without replacement
+ else:
+ days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
+
+ # How many trials will be sampled from each day
+ num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
+
+ for d in days:
+
+ # Trials are sampled with replacement, so if a day has less than (self.batch_size / days_per_batch trials) trials, it won't be a problem
+ trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
+ batch[d] = trial_idxs
+
+ # Remove extra trials
+ extra_trials = (num_trials * len(days)) - self.batch_size
+
+ # While we still have extra trials, remove the last trial from a random day
+ while extra_trials > 0:
+ d = np.random.choice(days)
+ batch[d] = batch[d][:-1]
+ extra_trials -= 1
+
+ batch_index[batch_idx] = batch
+
+ return batch_index
+
+ def create_batch_index_test(self):
+ '''
+ Create an index that is all validation/testing data in batches of up to self.batch_size
+
+ If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size
+
+ This index will ensures that every trial in the validation set is seen once and only once
+ '''
+ batch_index = {}
+ batch_idx = 0
+
+ for d in self.trial_indicies.keys():
+
+ # Calculate how many batches we need for this day
+ num_trials = len(self.trial_indicies[d]['trials'])
+ num_batches = (num_trials + self.batch_size - 1) // self.batch_size
+
+ # Create batches for this day
+ for i in range(num_batches):
+ start_idx = i * self.batch_size
+ end_idx = min((i + 1) * self.batch_size, num_trials)
+
+ # Get the trial indices for this batch
+ batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]
+
+ # Add to batch_index
+ batch_index[batch_idx] = {d : batch_trials}
+ batch_idx += 1
+
+ return batch_index
+
+def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
+ '''
+ Split data from file_paths into train and test splits
+ Returns two dictionaries that detail which trials in each day will be a part of that split:
+ Example:
+ {
+ 0: trials[1,2,3], session_path: 'path'
+ 1: trials[2,5,6], session_path: 'path'
+ }
+
+ Args:
+ file_paths (list): List of file paths to the hdf5 files containing the data
+ test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
+ seed (int): Seed for reproducibility. If set to -1, the split will be random
+ bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
+ {
+ 'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
+ 'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
+ ...
+ }
+ '''
+ # Set seed for reporoducibility
+ if seed != -1:
+ np.random.seed(seed)
+
+ # Get trials in each day
+ trials_per_day = {}
+ for i, path in enumerate(file_paths):
+ # Handle both Windows and Unix path separators
+ path_parts = path.replace('\\', '/').split('/')
+ session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
+
+ good_trial_indices = []
+
+ if os.path.exists(path):
+ with h5py.File(path, 'r') as f:
+ num_trials = len(list(f.keys()))
+ for t in range(num_trials):
+ key = f'trial_{t:04d}'
+
+ block_num = f[key].attrs['block_num']
+ trial_num = f[key].attrs['trial_num']
+
+ if (
+ bad_trials_dict is not None
+ and session in bad_trials_dict
+ and str(block_num) in bad_trials_dict[session]
+ and trial_num in bad_trials_dict[session][str(block_num)]
+ ):
+ # print(f'Bad trial: {session}_{block_num}_{trial_num}')
+ continue
+
+ good_trial_indices.append(t)
+
+ trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}
+
+ # Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
+ train_trials = {}
+ test_trials = {}
+
+ for day in trials_per_day.keys():
+
+ num_trials = trials_per_day[day]['num_trials']
+
+ # Generate all trial indices for this day (assuming 0-indexed)
+ all_trial_indices = trials_per_day[day]['trial_indices']
+
+ # If test_percentage is 0 or 1, we can just assign all trials to either train or test
+ if test_percentage == 0:
+ train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
+ test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
+ continue
+
+ elif test_percentage == 1:
+ train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
+ test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
+ continue
+
+ else:
+ # Calculate how many trials to use for testing
+ num_test = max(1, int(num_trials * test_percentage))
+
+ # Randomly select indices for testing
+ test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
+
+ # Remaining indices go to training
+ train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
+
+ # Store the split indices
+ train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
+ test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}
+
+ return train_trials, test_trials
\ No newline at end of file
diff --git a/model_training_nnn_tpu/evaluate_model.py b/model_training_nnn_tpu/evaluate_model.py
new file mode 100644
index 0000000..d84e07b
--- /dev/null
+++ b/model_training_nnn_tpu/evaluate_model.py
@@ -0,0 +1,304 @@
+import os
+import torch
+import numpy as np
+import pandas as pd
+import redis
+from omegaconf import OmegaConf
+import time
+from tqdm import tqdm
+import editdistance
+import argparse
+
+from rnn_model import GRUDecoder
+from evaluate_model_helpers import *
+
+# argument parser for command line arguments
+parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
+parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
+ help='Path to the pretrained model directory (relative to the current working directory).')
+parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
+ help='Path to the dataset directory (relative to the current working directory).')
+parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
+ help='Evaluation type: "val" for validation set, "test" for test set. '
+ 'If "test", ground truth is not available.')
+parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
+ help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
+parser.add_argument('--gpu_number', type=int, default=-1,
+ help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
+args = parser.parse_args()
+
+# paths to model and data directories
+# Note: these paths are relative to the current working directory
+model_path = args.model_path
+data_dir = args.data_dir
+
+# define evaluation type
+eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
+
+# load csv file
+b2txt_csv_df = pd.read_csv(args.csv_path)
+
+# load model args
+model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
+
+# set up gpu device
+gpu_number = args.gpu_number
+if torch.cuda.is_available() and gpu_number >= 0:
+ if gpu_number >= torch.cuda.device_count():
+ raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
+ device = f'cuda:{gpu_number}'
+ device = torch.device(device)
+ print(f'Using {device} for model inference.')
+else:
+ if gpu_number >= 0:
+ print(f'GPU number {gpu_number} requested but not available.')
+ print('Using CPU for model inference.')
+ device = torch.device('cpu')
+
+# define model
+model = GRUDecoder(
+ neural_dim = model_args['model']['n_input_features'],
+ n_units = model_args['model']['n_units'],
+ n_days = len(model_args['dataset']['sessions']),
+ n_classes = model_args['dataset']['n_classes'],
+ rnn_dropout = model_args['model']['rnn_dropout'],
+ input_dropout = model_args['model']['input_network']['input_layer_dropout'],
+ n_layers = model_args['model']['n_layers'],
+ patch_size = model_args['model']['patch_size'],
+ patch_stride = model_args['model']['patch_stride'],
+)
+
+# load model weights
+checkpoint = torch.load(
+ os.path.join(model_path, 'checkpoint/best_checkpoint'),
+ map_location=device,
+ weights_only=False,
+)
+# rename keys to not start with "module." (happens if model was saved with DataParallel)
+for key in list(checkpoint['model_state_dict'].keys()):
+ checkpoint['model_state_dict'][key.replace("module.", "")] = checkpoint['model_state_dict'].pop(key)
+ checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = checkpoint['model_state_dict'].pop(key)
+model.load_state_dict(checkpoint['model_state_dict'])
+
+# add model to device
+model.to(device)
+
+# set model to eval mode
+model.eval()
+
+# load data for each session
+test_data = {}
+total_test_trials = 0
+for session in model_args['dataset']['sessions']:
+ files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
+ if f'data_{eval_type}.hdf5' in files:
+ eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
+
+ data = load_h5py_file(eval_file, b2txt_csv_df)
+ test_data[session] = data
+
+ total_test_trials += len(test_data[session]["neural_features"])
+ print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
+print(f'Total number of {eval_type} trials: {total_test_trials}')
+print()
+
+
+# put neural data through the pretrained model to get phoneme predictions (logits)
+with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
+ for session, data in test_data.items():
+
+ data['logits'] = []
+ data['pred_seq'] = []
+ input_layer = model_args['dataset']['sessions'].index(session)
+
+ for trial in range(len(data['neural_features'])):
+ # get neural input for the trial
+ neural_input = data['neural_features'][trial]
+
+ # add batch dimension
+ neural_input = np.expand_dims(neural_input, axis=0)
+
+ # convert to torch tensor
+ neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
+
+ # run decoding step
+ logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
+ data['logits'].append(logits)
+
+ pbar.update(1)
+pbar.close()
+
+
+# convert logits to phoneme sequences and print them out
+for session, data in test_data.items():
+ data['pred_seq'] = []
+ for trial in range(len(data['logits'])):
+ logits = data['logits'][trial][0]
+ pred_seq = np.argmax(logits, axis=-1)
+ # remove blanks (0)
+ pred_seq = [int(p) for p in pred_seq if p != 0]
+ # remove consecutive duplicates
+ pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
+ # convert to phonemes
+ pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
+ # add to data
+ data['pred_seq'].append(pred_seq)
+
+ # print out the predicted sequences
+ block_num = data['block_num'][trial]
+ trial_num = data['trial_num'][trial]
+ print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
+ if eval_type == 'val':
+ sentence_label = data['sentence_label'][trial]
+ true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
+ true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
+
+ print(f'Sentence label: {sentence_label}')
+ print(f'True sequence: {" ".join(true_seq)}')
+ print(f'Predicted Sequence: {" ".join(pred_seq)}')
+ print()
+
+
+# language model inference via redis
+# make sure that the standalone language model is running on the localhost redis ip
+# see README.md for instructions on how to run the language model
+
+def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
+ """Connect to Redis with retry logic"""
+ for attempt in range(max_retries):
+ try:
+ print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
+ r = redis.Redis(host=host, port=port, db=db, password=password)
+ r.ping() # Test the connection
+ print(f"Successfully connected to Redis at {host}:{port}")
+ return r
+ except redis.exceptions.ConnectionError as e:
+ print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
+ if attempt < max_retries - 1:
+ print(f"Retrying in {retry_delay} seconds...")
+ time.sleep(retry_delay)
+ else:
+ print("Max retries reached. Could not connect to Redis.")
+ raise e
+ except Exception as e:
+ print(f"Unexpected error connecting to Redis: {e}")
+ if attempt < max_retries - 1:
+ print(f"Retrying in {retry_delay} seconds...")
+ time.sleep(retry_delay)
+ else:
+ raise e
+
+r = connect_to_redis_with_retry('hs.zchens.cn', 6379, 'admin01')
+r.flushall() # clear all streams in redis
+
+# define redis streams for the remote language model
+remote_lm_input_stream = 'remote_lm_input'
+remote_lm_output_partial_stream = 'remote_lm_output_partial'
+remote_lm_output_final_stream = 'remote_lm_output_final'
+
+# set timestamps for last entries seen in the redis streams
+remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
+remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
+remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
+remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
+remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
+
+lm_results = {
+ 'session': [],
+ 'block': [],
+ 'trial': [],
+ 'true_sentence': [],
+ 'pred_sentence': [],
+}
+
+# loop through all trials and put logits into the remote language model to get text predictions
+# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
+with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
+ for session in test_data.keys():
+ for trial in range(len(test_data[session]['logits'])):
+ # get trial logits and rearrange them for the LM
+ logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
+
+ # reset language model
+ remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
+
+ '''
+ # update language model parameters
+ remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
+ r,
+ remote_lm_done_updating_lastEntrySeen,
+ acoustic_scale=0.35,
+ blank_penalty=90.0,
+ alpha=0.55,
+ )
+ '''
+
+ # put logits into LM
+ remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
+ r,
+ remote_lm_input_stream,
+ remote_lm_output_partial_stream,
+ remote_lm_output_partial_lastEntrySeen,
+ logits,
+ )
+
+ # finalize remote LM
+ remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
+ r,
+ remote_lm_output_final_stream,
+ remote_lm_output_final_lastEntrySeen,
+ )
+
+ # get the best candidate sentence
+ best_candidate_sentence = lm_out['candidate_sentences'][0]
+
+ # store results
+ lm_results['session'].append(session)
+ lm_results['block'].append(test_data[session]['block_num'][trial])
+ lm_results['trial'].append(test_data[session]['trial_num'][trial])
+ if eval_type == 'val':
+ lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
+ else:
+ lm_results['true_sentence'].append(None)
+ lm_results['pred_sentence'].append(best_candidate_sentence)
+
+ # update progress bar
+ pbar.update(1)
+pbar.close()
+
+
+# if using the validation set, lets calculate the aggregate word error rate (WER)
+if eval_type == 'val':
+ total_true_length = 0
+ total_edit_distance = 0
+
+ lm_results['edit_distance'] = []
+ lm_results['num_words'] = []
+
+ for i in range(len(lm_results['pred_sentence'])):
+ true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
+ pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
+ ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
+
+ total_true_length += len(true_sentence.split())
+ total_edit_distance += ed
+
+ lm_results['edit_distance'].append(ed)
+ lm_results['num_words'].append(len(true_sentence.split()))
+
+ print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
+ print(f'True sentence: {true_sentence}')
+ print(f'Predicted sentence: {pred_sentence}')
+ print(f'WER: {ed} / {100 * len(true_sentence.split())} = {ed / len(true_sentence.split()):.2f}%')
+ print()
+
+ print(f'Total true sentence length: {total_true_length}')
+ print(f'Total edit distance: {total_edit_distance}')
+ print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
+
+
+# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
+output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
+ids = [i for i in range(len(lm_results['pred_sentence']))]
+df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
+df_out.to_csv(output_file, index=False)
\ No newline at end of file
diff --git a/model_training_nnn_tpu/evaluate_model_helpers.py b/model_training_nnn_tpu/evaluate_model_helpers.py
new file mode 100644
index 0000000..b16dd06
--- /dev/null
+++ b/model_training_nnn_tpu/evaluate_model_helpers.py
@@ -0,0 +1,297 @@
+import torch
+import numpy as np
+import h5py
+import time
+import re
+
+from data_augmentations import gauss_smooth
+
+LOGIT_TO_PHONEME = [
+ 'BLANK',
+ 'AA', 'AE', 'AH', 'AO', 'AW',
+ 'AY', 'B', 'CH', 'D', 'DH',
+ 'EH', 'ER', 'EY', 'F', 'G',
+ 'HH', 'IH', 'IY', 'JH', 'K',
+ 'L', 'M', 'N', 'NG', 'OW',
+ 'OY', 'P', 'R', 'S', 'SH',
+ 'T', 'TH', 'UH', 'UW', 'V',
+ 'W', 'Y', 'Z', 'ZH',
+ ' | ',
+]
+
+def _extract_transcription(input):
+ endIdx = np.argwhere(input == 0)[0, 0]
+ trans = ''
+ for c in range(endIdx):
+ trans += chr(input[c])
+ return trans
+
+def load_h5py_file(file_path, b2txt_csv_df):
+ data = {
+ 'neural_features': [],
+ 'n_time_steps': [],
+ 'seq_class_ids': [],
+ 'seq_len': [],
+ 'transcriptions': [],
+ 'sentence_label': [],
+ 'session': [],
+ 'block_num': [],
+ 'trial_num': [],
+ 'corpus': [],
+ }
+ # Open the hdf5 file for that day
+ with h5py.File(file_path, 'r') as f:
+
+ keys = list(f.keys())
+
+ # For each trial in the selected trials in that day
+ for key in keys:
+ g = f[key]
+
+ neural_features = g['input_features'][:]
+ n_time_steps = g.attrs['n_time_steps']
+ seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
+ seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
+ transcription = g['transcription'][:] if 'transcription' in g else None
+ sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
+ session = g.attrs['session']
+ block_num = g.attrs['block_num']
+ trial_num = g.attrs['trial_num']
+
+ # match this trial up with the csv to get the corpus name
+ year, month, day = session.split('.')[1:]
+ date = f'{year}-{month}-{day}'
+ row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
+ corpus_name = row['Corpus'].values[0]
+
+ data['neural_features'].append(neural_features)
+ data['n_time_steps'].append(n_time_steps)
+ data['seq_class_ids'].append(seq_class_ids)
+ data['seq_len'].append(seq_len)
+ data['transcriptions'].append(transcription)
+ data['sentence_label'].append(sentence_label)
+ data['session'].append(session)
+ data['block_num'].append(block_num)
+ data['trial_num'].append(trial_num)
+ data['corpus'].append(corpus_name)
+ return data
+
+def rearrange_speech_logits_pt(logits):
+ # original order is [BLANK, phonemes..., SIL]
+ # rearrange so the order is [BLANK, SIL, phonemes...]
+ logits = np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
+ return logits
+
+# single decoding step function.
+# smooths data and puts it through the model.
+def runSingleDecodingStep(x, input_layer, model, model_args, device):
+
+ # Use autocast for efficiency
+ with torch.autocast(device_type = "cuda", enabled = model_args['use_amp'], dtype = torch.bfloat16):
+
+ x = gauss_smooth(
+ inputs = x,
+ device = device,
+ smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],
+ smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],
+ padding = 'valid',
+ )
+
+ with torch.no_grad():
+ logits, _ = model(
+ x = x,
+ day_idx = torch.tensor([input_layer], device=device),
+ states = None, # no initial states
+ return_state = True,
+ )
+
+ # convert logits from bfloat16 to float32
+ logits = logits.float().cpu().numpy()
+
+ # # original order is [BLANK, phonemes..., SIL]
+ # # rearrange so the order is [BLANK, SIL, phonemes...]
+ # logits = rearrange_speech_logits_pt(logits)
+
+ return logits
+
+def remove_punctuation(sentence):
+ # Remove punctuation
+ sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
+ sentence = sentence.replace('- ', ' ').lower()
+ sentence = sentence.replace('--', '').lower()
+ sentence = sentence.replace(" '", "'").lower()
+
+ sentence = sentence.strip()
+ sentence = ' '.join([word for word in sentence.split() if word != ''])
+
+ return sentence
+
+def get_current_redis_time_ms(redis_conn):
+ t = redis_conn.time()
+ return int(t[0]*1000 + t[1]/1000)
+
+
+######### language model helper functions ##########
+
+def reset_remote_language_model(
+ r,
+ remote_lm_done_resetting_lastEntrySeen,
+ ):
+
+ r.xadd('remote_lm_reset', {'done': 0})
+ time.sleep(0.001)
+ # print('Resetting remote language model before continuing...')
+ remote_lm_done_resetting = []
+ while len(remote_lm_done_resetting) == 0:
+ remote_lm_done_resetting = r.xread(
+ {'remote_lm_done_resetting': remote_lm_done_resetting_lastEntrySeen},
+ count=1,
+ block=10000,
+ )
+ if len(remote_lm_done_resetting) == 0:
+ print(f'Still waiting for remote lm reset from ts {remote_lm_done_resetting_lastEntrySeen}...')
+ for entry_id, entry_data in remote_lm_done_resetting[0][1]:
+ remote_lm_done_resetting_lastEntrySeen = entry_id
+ # print('Remote language model reset.')
+
+ return remote_lm_done_resetting_lastEntrySeen
+
+
+def update_remote_lm_params(
+ r,
+ remote_lm_done_updating_lastEntrySeen,
+ acoustic_scale=0.35,
+ blank_penalty=90.0,
+ alpha=0.55,
+ ):
+
+ # update remote lm params
+ entry_dict = {
+ # 'max_active': max_active,
+ # 'min_active': min_active,
+ # 'beam': beam,
+ # 'lattice_beam': lattice_beam,
+ 'acoustic_scale': acoustic_scale,
+ # 'ctc_blank_skip_threshold': ctc_blank_skip_threshold,
+ # 'length_penalty': length_penalty,
+ # 'nbest': nbest,
+ 'blank_penalty': blank_penalty,
+ 'alpha': alpha,
+ # 'do_opt': do_opt,
+ # 'rescore': rescore,
+ # 'top_candidates_to_augment': top_candidates_to_augment,
+ # 'score_penalty_percent': score_penalty_percent,
+ # 'specific_word_bias': specific_word_bias,
+ }
+
+ r.xadd('remote_lm_update_params', entry_dict)
+ time.sleep(0.001)
+ remote_lm_done_updating = []
+ while len(remote_lm_done_updating) == 0:
+ remote_lm_done_updating = r.xread(
+ {'remote_lm_done_updating_params': remote_lm_done_updating_lastEntrySeen},
+ block=10000,
+ count=1,
+ )
+ if len(remote_lm_done_updating) == 0:
+ print(f'Still waiting for remote lm to update parameters from ts {remote_lm_done_updating_lastEntrySeen}...')
+ for entry_id, entry_data in remote_lm_done_updating[0][1]:
+ remote_lm_done_updating_lastEntrySeen = entry_id
+ # print('Remote language model params updated.')
+
+ return remote_lm_done_updating_lastEntrySeen
+
+
+def send_logits_to_remote_lm(
+ r,
+ remote_lm_input_stream,
+ remote_lm_output_partial_stream,
+ remote_lm_output_partial_lastEntrySeen,
+ logits,
+ ):
+
+ # put logits into remote lm and get partial output
+ r.xadd(remote_lm_input_stream, {'logits': np.float32(logits).tobytes()})
+ remote_lm_output = []
+ while len(remote_lm_output) == 0:
+ remote_lm_output = r.xread(
+ {remote_lm_output_partial_stream: remote_lm_output_partial_lastEntrySeen},
+ block=10000,
+ count=1,
+ )
+ if len(remote_lm_output) == 0:
+ print(f'Still waiting for remote lm partial output from ts {remote_lm_output_partial_lastEntrySeen}...')
+ for entry_id, entry_data in remote_lm_output[0][1]:
+ remote_lm_output_partial_lastEntrySeen = entry_id
+ decoded = entry_data[b'lm_response_partial'].decode()
+
+ return remote_lm_output_partial_lastEntrySeen, decoded
+
+
+def finalize_remote_lm(
+ r,
+ remote_lm_output_final_stream,
+ remote_lm_output_final_lastEntrySeen,
+ ):
+
+ # finalize remote lm
+ r.xadd('remote_lm_finalize', {'done': 0})
+ time.sleep(0.005)
+ remote_lm_output = []
+ while len(remote_lm_output) == 0:
+ remote_lm_output = r.xread(
+ {remote_lm_output_final_stream: remote_lm_output_final_lastEntrySeen},
+ block=10000,
+ count=1,
+ )
+ if len(remote_lm_output) == 0:
+ print(f'Still waiting for remote lm final output from ts {remote_lm_output_final_lastEntrySeen}...')
+ # print('Received remote lm final output.')
+
+ for entry_id, entry_data in remote_lm_output[0][1]:
+ remote_lm_output_final_lastEntrySeen = entry_id
+
+ candidate_sentences = [str(c) for c in entry_data[b'scoring'].decode().split(';')[::5]]
+ candidate_acoustic_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[1::5]]
+ candidate_ngram_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[2::5]]
+ candidate_llm_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[3::5]]
+ candidate_total_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[4::5]]
+
+
+ # account for a weird edge case where there are no candidate sentences
+ if len(candidate_sentences) == 0 or len(candidate_total_scores) == 0:
+ print('No candidate sentences were received from the language model.')
+ candidate_sentences = ['']
+ candidate_acoustic_scores = [0]
+ candidate_ngram_scores = [0]
+ candidate_llm_scores = [0]
+ candidate_total_scores = [0]
+
+ else:
+ # sort candidate sentences by total score (higher is better)
+ sort_order = np.argsort(candidate_total_scores)[::-1]
+
+ candidate_sentences = [candidate_sentences[i] for i in sort_order]
+ candidate_acoustic_scores = [candidate_acoustic_scores[i] for i in sort_order]
+ candidate_ngram_scores = [candidate_ngram_scores[i] for i in sort_order]
+ candidate_llm_scores = [candidate_llm_scores[i] for i in sort_order]
+ candidate_total_scores = [candidate_total_scores[i] for i in sort_order]
+
+ # loop through candidates backwards and remove any duplicates
+ for i in range(len(candidate_sentences)-1, 0, -1):
+ if candidate_sentences[i] in candidate_sentences[:i]:
+ candidate_sentences.pop(i)
+ candidate_acoustic_scores.pop(i)
+ candidate_ngram_scores.pop(i)
+ candidate_llm_scores.pop(i)
+ candidate_total_scores.pop(i)
+
+ lm_out = {
+ 'candidate_sentences': candidate_sentences,
+ 'candidate_acoustic_scores': candidate_acoustic_scores,
+ 'candidate_ngram_scores': candidate_ngram_scores,
+ 'candidate_llm_scores': candidate_llm_scores,
+ 'candidate_total_scores': candidate_total_scores,
+ }
+
+ return remote_lm_output_final_lastEntrySeen, lm_out
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_debug_full_model.py b/model_training_nnn_tpu/jupyter_debug_full_model.py
new file mode 100644
index 0000000..1b7deeb
--- /dev/null
+++ b/model_training_nnn_tpu/jupyter_debug_full_model.py
@@ -0,0 +1,124 @@
+# ====================
+# 单元格4: 逐步调试完整模型编译
+# ====================
+
+# 如果单元格3测试通过,运行这个单元格
+print("🔧 逐步测试完整TripleGRUDecoder模型...")
+
+# 导入完整模型
+import sys
+sys.path.append('.') # 确保能导入本地模块
+
+try:
+ from rnn_model import TripleGRUDecoder
+ print("✅ TripleGRUDecoder导入成功")
+except ImportError as e:
+ print(f"❌ 模型导入失败: {e}")
+ print("请确保rnn_model.py在当前目录中")
+
+# 分阶段测试
+def test_model_compilation_stages():
+ """分阶段测试模型编译"""
+ device = xm.xla_device()
+
+ # 阶段1: 测试NoiseModel单独编译
+ print("\n🔬 阶段1: 测试NoiseModel...")
+ try:
+ from rnn_model import NoiseModel
+ noise_model = NoiseModel(
+ neural_dim=512,
+ n_units=384, # 减小参数
+ n_days=4,
+ patch_size=8 # 减小patch size
+ ).to(device)
+
+ x = torch.randn(2, 20, 512, device=device)
+ day_idx = torch.tensor([0, 1], device=device)
+
+ start_time = time.time()
+ with torch.no_grad():
+ output, states = noise_model(x, day_idx)
+ compile_time = time.time() - start_time
+
+ print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒")
+ print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
+
+ return True, compile_time
+
+ except Exception as e:
+ print(f"❌ NoiseModel编译失败: {e}")
+ return False, 0
+
+ # 阶段2: 测试CleanSpeechModel
+ print("\n🔬 阶段2: 测试CleanSpeechModel...")
+ try:
+ from rnn_model import CleanSpeechModel
+ clean_model = CleanSpeechModel(
+ neural_dim=512,
+ n_units=384,
+ n_days=4,
+ n_classes=41,
+ patch_size=8
+ ).to(device)
+
+ start_time = time.time()
+ with torch.no_grad():
+ output = clean_model(x, day_idx)
+ compile_time = time.time() - start_time
+
+ print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}秒")
+ return True, compile_time
+
+ except Exception as e:
+ print(f"❌ CleanSpeechModel编译失败: {e}")
+ return False, 0
+
+ # 阶段3: 测试完整TripleGRUDecoder
+ print("\n🔬 阶段3: 测试TripleGRUDecoder...")
+ try:
+ model = TripleGRUDecoder(
+ neural_dim=512,
+ n_units=384, # 比原来的768小
+ n_days=4, # 减少天数
+ n_classes=41,
+ patch_size=8 # 减小patch size
+ ).to(device)
+
+ print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")
+
+ # 启动编译监控
+ compilation_monitor.start_monitoring()
+
+ start_time = time.time()
+ with torch.no_grad():
+ # 测试inference模式
+ logits = model(x, day_idx, None, False, 'inference')
+ compile_time = time.time() - start_time
+
+ compilation_monitor.complete_monitoring()
+
+ print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}秒")
+ print(f"📤 输出形状: {logits.shape}")
+
+ return True, compile_time
+
+ except Exception as e:
+ compilation_monitor.complete_monitoring()
+ print(f"❌ TripleGRUDecoder编译失败: {e}")
+ return False, 0
+
+# 运行分阶段测试
+stage_results = test_model_compilation_stages()
+
+if stage_results:
+ print(f"\n🎉 所有编译测试完成!")
+ print("💡 下一步可以尝试:")
+ print(" 1. 使用简化配置进行训练")
+ print(" 2. 逐步增加模型复杂度")
+ print(" 3. 监控TPU资源使用情况")
+else:
+ print(f"\n⚠️ 编译测试发现问题")
+ print("💡 建议:")
+ print(" 1. 进一步减小模型参数")
+ print(" 2. 检查内存使用情况")
+ print(" 3. 使用CPU模式进行调试")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_monitor.py b/model_training_nnn_tpu/jupyter_xla_monitor.py
new file mode 100644
index 0000000..e02ece1
--- /dev/null
+++ b/model_training_nnn_tpu/jupyter_xla_monitor.py
@@ -0,0 +1,131 @@
+# ====================
+# 单元格2: XLA编译进度监控
+# ====================
+
+import torch
+import torch.nn as nn
+import time
+import threading
+from IPython.display import display, HTML, clear_output
+import ipywidgets as widgets
+
+# 导入XLA (环境变量已在单元格1中设置)
+print("🚀 导入PyTorch XLA...")
+import torch_xla.core.xla_model as xm
+
+print(f"✅ XLA导入成功!")
+print(f" TPU设备: {xm.xla_device()}")
+print(f" World Size: {xm.xrt_world_size()}")
+
+# 创建编译进度监控器
+class JupyterCompilationMonitor:
+ def __init__(self):
+ self.start_time = None
+ self.is_monitoring = False
+
+ # 创建输出widget
+ self.output_widget = widgets.Output()
+
+ # 创建进度条
+ self.progress_bar = widgets.IntProgress(
+ value=0,
+ min=0,
+ max=100,
+ description='XLA编译:',
+ bar_style='info',
+ style={'bar_color': '#1f77b4'},
+ orientation='horizontal'
+ )
+
+ # 创建状态标签
+ self.status_label = widgets.HTML(
+ value="准备开始编译..."
+ )
+
+ # 创建CPU使用率显示
+ self.cpu_label = widgets.HTML(
+ value="CPU: ---%"
+ )
+
+ self.memory_label = widgets.HTML(
+ value="内存: ---%"
+ )
+
+ # 组合界面
+ self.monitor_box = widgets.VBox([
+ widgets.HTML("
🔄 XLA编译监控
"),
+ self.progress_bar,
+ self.status_label,
+ widgets.HBox([self.cpu_label, self.memory_label]),
+ self.output_widget
+ ])
+
+ def start_monitoring(self):
+ """开始监控"""
+ self.start_time = time.time()
+ self.is_monitoring = True
+
+ display(self.monitor_box)
+
+ # 启动监控线程
+ self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
+ self.monitor_thread.start()
+
+ def _monitor_loop(self):
+ """监控循环"""
+ while self.is_monitoring:
+ try:
+ elapsed = time.time() - self.start_time
+ minutes = int(elapsed // 60)
+ seconds = int(elapsed % 60)
+
+ # 更新进度条 (模拟进度)
+ progress = min(int(elapsed / 10 * 100), 95) # 10秒内达到95%
+ self.progress_bar.value = progress
+
+ # 获取系统资源
+ cpu_percent = psutil.cpu_percent(interval=0.1)
+ memory_percent = psutil.virtual_memory().percent
+
+ # 更新显示
+ self.status_label.value = f"编译进行中... ⏱️ {minutes:02d}:{seconds:02d}"
+ self.cpu_label.value = f"🖥️ CPU: {cpu_percent:5.1f}%"
+ self.memory_label.value = f"💾 内存: {memory_percent:5.1f}%"
+
+ # 检测是否编译完成 (CPU使用率突然下降)
+ if elapsed > 10 and cpu_percent < 20: # 编译通常CPU使用率很高
+ self.complete_monitoring()
+ break
+
+ time.sleep(1)
+
+ except Exception as e:
+ with self.output_widget:
+ print(f"监控错误: {e}")
+ break
+
+ def complete_monitoring(self):
+ """完成监控"""
+ if self.is_monitoring:
+ self.is_monitoring = False
+ elapsed = time.time() - self.start_time
+
+ self.progress_bar.value = 100
+ self.progress_bar.bar_style = 'success'
+ self.status_label.value = f"✅ 编译完成! 总耗时: {elapsed:.2f}秒"
+
+ with self.output_widget:
+ print(f"\n🎉 XLA编译成功完成!")
+ print(f"⏱️ 总耗时: {elapsed:.2f}秒")
+ if elapsed < 60:
+ print("✅ 编译速度正常")
+ elif elapsed < 300:
+ print("⚠️ 编译稍慢,但可接受")
+ else:
+ print("❌ 编译过慢,建议检查设置")
+
+# 创建全局监控器
+compilation_monitor = JupyterCompilationMonitor()
+
+print("✅ 编译监控器已准备就绪!")
+print("💡 运行下一个单元格开始XLA编译测试")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_setup.py b/model_training_nnn_tpu/jupyter_xla_setup.py
new file mode 100644
index 0000000..d37296e
--- /dev/null
+++ b/model_training_nnn_tpu/jupyter_xla_setup.py
@@ -0,0 +1,45 @@
+# ====================
+# 单元格1: 环境设置 (必须第一个运行!)
+# ====================
+
+import os
+import time
+import psutil
+from IPython.display import display, HTML, clear_output
+import ipywidgets as widgets
+
+# ⚠️ 关键: 在导入torch_xla之前设置环境变量
+print("🔧 设置XLA环境变量...")
+
+# 获取CPU核心数
+cpu_count = os.cpu_count()
+print(f"检测到 {cpu_count} 个CPU核心")
+
+# 设置XLA编译优化环境变量
+os.environ['XLA_FLAGS'] = (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={cpu_count}'
+)
+os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
+os.environ['XLA_USE_BF16'] = '1'
+
+# 显示设置结果
+print("✅ XLA环境变量已设置:")
+print(f" CPU核心数: {cpu_count}")
+print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
+print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
+
+# 系统资源检查
+memory_info = psutil.virtual_memory()
+print(f"\n💾 系统内存信息:")
+print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
+print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
+print(f" 使用率: {memory_info.percent:.1f}%")
+
+if memory_info.available < 8 * (1024**3): # 小于8GB
+ print("⚠️ 警告: 可用内存不足8GB,可能影响XLA编译速度")
+else:
+ print("✅ 内存充足")
+
+print("\n🎯 环境设置完成! 现在可以运行下一个单元格")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_test.py b/model_training_nnn_tpu/jupyter_xla_test.py
new file mode 100644
index 0000000..2d45b9d
--- /dev/null
+++ b/model_training_nnn_tpu/jupyter_xla_test.py
@@ -0,0 +1,78 @@
+# ====================
+# 单元格3: 快速XLA编译测试
+# ====================
+
+# 简化测试模型
+class QuickTestModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = nn.Linear(512, 128)
+ self.gru = nn.GRU(128, 64, batch_first=True)
+ self.linear2 = nn.Linear(64, 41)
+
+ def forward(self, x):
+ x = torch.relu(self.linear1(x))
+ x, _ = self.gru(x)
+ x = self.linear2(x)
+ return x
+
+print("🧪 开始XLA编译快速测试...")
+
+# 启动监控
+compilation_monitor.start_monitoring()
+
+try:
+ # 获取TPU设备
+ device = xm.xla_device()
+
+ # 创建小模型
+ model = QuickTestModel().to(device)
+ param_count = sum(p.numel() for p in model.parameters())
+ print(f"📊 测试模型参数: {param_count:,}")
+
+ # 创建测试数据 (很小的batch)
+ x = torch.randn(2, 20, 512, device=device)
+ print(f"📥 输入数据形状: {x.shape}")
+
+ print("🔄 开始首次前向传播 (触发XLA编译)...")
+
+ # 首次前向传播 - 这会触发XLA编译
+ with torch.no_grad():
+ start_compile = time.time()
+ output = model(x)
+ compile_time = time.time() - start_compile
+
+ print(f"✅ XLA编译完成!")
+ print(f"📤 输出形状: {output.shape}")
+
+ # 完成监控
+ compilation_monitor.complete_monitoring()
+
+ # 测试编译后的性能
+ print("\n🚀 测试编译后的执行速度...")
+ with torch.no_grad():
+ start_exec = time.time()
+ for _ in range(10):
+ output = model(x)
+ avg_exec_time = (time.time() - start_exec) / 10
+
+ print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")
+
+ # 性能评估
+ if compile_time < 30:
+ print("✅ 编译速度优秀! 可以尝试完整模型")
+ test_result = "excellent"
+ elif compile_time < 120:
+ print("✅ 编译速度良好! 建议使用简化配置")
+ test_result = "good"
+ else:
+ print("⚠️ 编译速度较慢,建议进一步优化")
+ test_result = "slow"
+
+except Exception as e:
+ compilation_monitor.complete_monitoring()
+ print(f"❌ 测试失败: {e}")
+ test_result = "failed"
+
+print(f"\n📋 测试结果: {test_result}")
+print("💡 如果测试通过,可以运行下一个单元格进行完整训练")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/launch_tpu_training.py b/model_training_nnn_tpu/launch_tpu_training.py
new file mode 100644
index 0000000..1beb755
--- /dev/null
+++ b/model_training_nnn_tpu/launch_tpu_training.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+"""
+TPU Training Launch Script for Brain-to-Text RNN Model
+
+This script provides easy TPU training setup using Accelerate library.
+Supports both single TPU core and multi-core (8 cores) training.
+
+Usage:
+ python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+Requirements:
+ - PyTorch XLA installed
+ - Accelerate library installed
+ - TPU runtime available
+"""
+
+import argparse
+import yaml
+import os
+import sys
+from pathlib import Path
+
+def update_config_for_tpu(config_path, num_cores=8):
+ """
+ Update configuration file to enable TPU training
+ """
+ with open(config_path, 'r') as f:
+ config = yaml.safe_load(f)
+
+ # Enable TPU settings
+ config['use_tpu'] = True
+ config['num_tpu_cores'] = num_cores
+ config['dataloader_num_workers'] = 0 # Required for TPU
+ config['use_amp'] = True # Enable mixed precision with bfloat16
+
+ # Adjust batch size and gradient accumulation for multi-core TPU
+ if num_cores > 1:
+ # Distribute batch size across cores
+ original_batch_size = config['dataset']['batch_size']
+ config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
+ config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
+
+ print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
+ print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
+
+ # Save updated config
+ tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
+ with open(tpu_config_path, 'w') as f:
+ yaml.dump(config, f, default_flow_style=False)
+
+ print(f"TPU configuration saved to: {tpu_config_path}")
+ return tpu_config_path
+
+def check_tpu_environment():
+ """
+ Check if TPU environment is properly set up
+ """
+ try:
+ import torch_xla
+ import torch_xla.core.xla_model as xm
+
+ # Check if TPUs are available
+ device = xm.xla_device()
+ print(f"TPU device available: {device}")
+ print(f"TPU ordinal: {xm.get_ordinal()}")
+ print(f"TPU world size: {xm.xrt_world_size()}")
+
+ return True
+ except ImportError:
+ print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
+ return False
+ except Exception as e:
+ print(f"ERROR: TPU not available - {e}")
+ return False
+
+def run_tpu_training(config_path, num_cores=8):
+ """
+ Launch TPU training using accelerate
+ """
+ # Check TPU environment
+ if not check_tpu_environment():
+ sys.exit(1)
+
+ # Update config for TPU
+ tpu_config_path = update_config_for_tpu(config_path, num_cores)
+
+ # Set TPU environment variables BEFORE launching training
+ os.environ['TPU_CORES'] = str(num_cores)
+ os.environ['XLA_USE_BF16'] = '1' # Enable bfloat16
+
+ # Critical XLA multi-threading settings - must be set before torch_xla import
+ cpu_count = os.cpu_count()
+ os.environ['XLA_FLAGS'] = (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={cpu_count}'
+ )
+ os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
+
+ print(f"Set XLA compilation to use {cpu_count} CPU threads")
+ print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
+ print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
+
+ # Launch training with accelerate using subprocess to ensure environment variables are passed
+ cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"
+
+ print(f"Launching TPU training with command:")
+ print(f" {cmd}")
+ print(f"Using {num_cores} TPU cores")
+ print("-" * 60)
+
+ # Use subprocess to ensure environment variables are properly inherited
+ import subprocess
+
+ # Create environment with our XLA settings
+ env = os.environ.copy()
+ env.update({
+ 'TPU_CORES': str(num_cores),
+ 'XLA_USE_BF16': '1',
+ 'XLA_FLAGS': (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={cpu_count}'
+ ),
+ 'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
+ })
+
+ print(f"Environment variables set for subprocess:")
+ print(f" XLA_FLAGS: {env['XLA_FLAGS']}")
+ print(f" PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
+ print("-" * 60)
+
+ # Execute training with proper environment
+ result = subprocess.run(cmd.split(), env=env)
+ return result.returncode
+
+def main():
+ parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
+ parser.add_argument('--config', default='rnn_args.yaml',
+ help='Path to configuration file (default: rnn_args.yaml)')
+ parser.add_argument('--num_cores', type=int, default=8,
+ help='Number of TPU cores to use (default: 8)')
+ parser.add_argument('--check_only', action='store_true',
+ help='Only check TPU environment, do not launch training')
+
+ args = parser.parse_args()
+
+ # Verify config file exists
+ if not os.path.exists(args.config):
+ print(f"ERROR: Configuration file {args.config} not found")
+ sys.exit(1)
+
+ if args.check_only:
+ check_tpu_environment()
+ return
+
+ # Run TPU training
+ run_tpu_training(args.config, args.num_cores)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/monitor_xla_compilation.py b/model_training_nnn_tpu/monitor_xla_compilation.py
new file mode 100644
index 0000000..09d8838
--- /dev/null
+++ b/model_training_nnn_tpu/monitor_xla_compilation.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+"""
+XLA编译进度监控脚本
+"""
+
+import os
+import time
+import threading
+import psutil
+from contextlib import contextmanager
+
+def monitor_compilation_progress():
+ """监控XLA编译进度"""
+ print("🔍 XLA编译进度监控已启动...")
+
+ start_time = time.time()
+ dots = 0
+
+ while True:
+ elapsed = time.time() - start_time
+ minutes = int(elapsed // 60)
+ seconds = int(elapsed % 60)
+
+ # 获取CPU使用率
+ cpu_percent = psutil.cpu_percent(interval=1)
+ memory_percent = psutil.virtual_memory().percent
+
+ # 动态显示
+ dots = (dots + 1) % 4
+ dot_str = "." * dots + " " * (3 - dots)
+
+ print(f"\r🔄 XLA编译中{dot_str} "
+ f"⏱️ {minutes:02d}:{seconds:02d} "
+ f"🖥️ CPU: {cpu_percent:5.1f}% "
+ f"💾 内存: {memory_percent:5.1f}%", end="", flush=True)
+
+ time.sleep(1)
+
+@contextmanager
+def compilation_monitor():
+ """编译监控上下文管理器"""
+ print("🚀 开始XLA编译监控...")
+
+ # 启动监控线程
+ monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True)
+ monitor_thread.start()
+
+ start_time = time.time()
+
+ try:
+ yield
+ finally:
+ elapsed = time.time() - start_time
+ print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}秒")
+
+# 修改trainer中的使用方法
+def add_compilation_monitoring_to_trainer():
+ """给trainer添加编译监控的示例代码"""
+ example_code = '''
+# 在rnn_trainer.py的train方法中添加:
+
+def train(self):
+ from monitor_xla_compilation import compilation_monitor
+
+ self.model.train()
+ train_losses = []
+ # ... 其他初始化代码 ...
+
+ self.logger.info("Starting training loop - XLA compilation monitoring enabled...")
+
+ # 使用编译监控
+ with compilation_monitor():
+ for i, batch in enumerate(self.train_loader):
+ # 第一个batch会触发XLA编译
+ # 监控会显示编译进度
+
+ # ... 训练代码 ...
+
+ # 编译完成后会自动停止监控
+ break # 只需要第一个batch来触发编译
+
+ # 继续正常训练循环
+ for i, batch in enumerate(self.train_loader):
+ # ... 正常训练代码 ...
+ '''
+
+ print("📝 如何在trainer中使用编译监控:")
+ print(example_code)
+
+if __name__ == "__main__":
+ print("🧪 XLA编译监控工具")
+ print("=" * 50)
+
+ # 演示如何使用
+ print("📖 使用方法:")
+ print("1. 将此文件导入到你的训练脚本中")
+ print("2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器")
+ print("3. 会实时显示编译进度和系统资源使用情况")
+
+ add_compilation_monitoring_to_trainer()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/rnn_args.yaml b/model_training_nnn_tpu/rnn_args.yaml
new file mode 100644
index 0000000..60e4049
--- /dev/null
+++ b/model_training_nnn_tpu/rnn_args.yaml
@@ -0,0 +1,181 @@
+model:
+ n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
+ n_units: 768 # number of units per GRU layer
+ rnn_dropout: 0.4 # dropout rate for the GRU layers
+ rnn_trainable: true # whether the GRU layers are trainable
+ n_layers: 5 # number of GRU layers
+ patch_size: 14 # size of the input patches (14 time steps)
+ patch_stride: 4 # stride for the input patches (4 time steps)
+
+ input_network:
+ n_input_layers: 1 # number of input layers per network (one network for each day)
+ input_layer_sizes:
+ - 512 # size of the input layer (number of input features)
+ input_trainable: true # whether the input layer is trainable
+ input_layer_dropout: 0.2 # dropout rate for the input layer
+
+mode: train
+use_amp: true # whether to use automatic mixed precision (AMP) for training with bfloat16 on TPU
+
+# TPU distributed training settings
+use_tpu: true # TPU training enabled
+num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8 or v4-8)
+gradient_accumulation_steps: 2 # number of gradient accumulation steps for distributed training (2x32=64 effective batch size)
+
+output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
+checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
+init_from_checkpoint: false # whether to initialize the model from a checkpoint
+init_checkpoint_path: None # path to the checkpoint to initialize the model from, if any
+save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
+save_all_val_steps: false # whether to save checkpoints at all validation steps
+save_final_model: false # whether to save the final model after training
+save_val_metrics: true # whether to save validation metrics during training
+early_stopping: false # whether to use early stopping based on validation metrics
+early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen
+
+num_training_batches: 120000 # number of training batches to run
+lr_scheduler_type: cosine # type of learning rate scheduler to use
+lr_max: 0.005 # maximum learning rate for the main model
+lr_min: 0.0001 # minimum learning rate for the main model
+lr_decay_steps: 120000 # number of steps for the learning rate decay
+lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
+lr_max_day: 0.005 # maximum learning rate for the day specific input layers
+lr_min_day: 0.0001 # minimum learning rate for the day specific input layers
+lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day specific input layers
+lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day specific input layers
+
+beta0: 0.9 # beta0 parameter for the Adam optimizer
+beta1: 0.999 # beta1 parameter for the Adam optimizer
+epsilon: 0.1 # epsilon parameter for the Adam optimizer
+weight_decay: 0.001 # weight decay for the main model
+weight_decay_day: 0 # weight decay for the day specific input layers
+seed: 10 # random seed for reproducibility
+grad_norm_clip_value: 10 # gradient norm clipping value
+
+batches_per_train_log: 200 # number of batches per training log
+batches_per_val_step: 2000 # number of batches per validation step
+
+batches_per_save: 0 # number of batches per save
+log_individual_day_val_PER: true # whether to log individual day validation performance
+log_val_skip_logs: false # whether to skip logging validation metrics
+save_val_logits: true # whether to save validation logits
+save_val_data: false # whether to save validation data
+
+dataset:
+ data_transforms:
+ white_noise_std: 1.0 # standard deviation of the white noise added to the data
+ constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
+ random_walk_std: 0.0 # standard deviation of the random walk added to the data
+ random_walk_axis: -1 # axis along which the random walk is applied
+ static_gain_std: 0.0 # standard deviation of the static gain applied to the data
+ random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
+ smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
+ smooth_data: true # whether to smooth the data
+ smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
+
+ neural_dim: 512 # dimensionality of the neural data
+ batch_size: 32 # batch size for training (reduced for TPU memory constraints)
+ n_classes: 41 # number of classes (phonemes) in the dataset
+ max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
+ days_per_batch: 4 # number of randomly-selected days to include in each batch
+ seed: 1 # random seed for reproducibility
+ num_dataloader_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
+ loader_shuffle: false # whether to shuffle the data loader
+ must_include_days: null # specific days to include in the dataset
+ test_percentage: 0.1 # percentage of data to use for testing
+ feature_subset: null # specific features to include in the dataset
+
+ dataset_dir: ../data/hdf5_data_final # directory containing the dataset
+ bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
+ sessions: # list of sessions to include in the dataset
+ - t15.2023.08.11
+ - t15.2023.08.13
+ - t15.2023.08.18
+ - t15.2023.08.20
+ - t15.2023.08.25
+ - t15.2023.08.27
+ - t15.2023.09.01
+ - t15.2023.09.03
+ - t15.2023.09.24
+ - t15.2023.09.29
+ - t15.2023.10.01
+ - t15.2023.10.06
+ - t15.2023.10.08
+ - t15.2023.10.13
+ - t15.2023.10.15
+ - t15.2023.10.20
+ - t15.2023.10.22
+ - t15.2023.11.03
+ - t15.2023.11.04
+ - t15.2023.11.17
+ - t15.2023.11.19
+ - t15.2023.11.26
+ - t15.2023.12.03
+ - t15.2023.12.08
+ - t15.2023.12.10
+ - t15.2023.12.17
+ - t15.2023.12.29
+ - t15.2024.02.25
+ - t15.2024.03.03
+ - t15.2024.03.08
+ - t15.2024.03.15
+ - t15.2024.03.17
+ - t15.2024.04.25
+ - t15.2024.04.28
+ - t15.2024.05.10
+ - t15.2024.06.14
+ - t15.2024.07.19
+ - t15.2024.07.21
+ - t15.2024.07.28
+ - t15.2025.01.10
+ - t15.2025.01.12
+ - t15.2025.03.14
+ - t15.2025.03.16
+ - t15.2025.03.30
+ - t15.2025.04.13
+ dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
+ - 0 # no val or test data from this day
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 0 # no val or test data from this day
+ - 1
+ - 1
+ - 1
+ - 0 # no val or test data from this day
+ - 0 # no val or test data from this day
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
+ - 1
\ No newline at end of file
diff --git a/model_training_nnn_tpu/rnn_args_simple.yaml b/model_training_nnn_tpu/rnn_args_simple.yaml
new file mode 100644
index 0000000..896d9be
--- /dev/null
+++ b/model_training_nnn_tpu/rnn_args_simple.yaml
@@ -0,0 +1,94 @@
+# 简化的TPU训练配置 - 更快编译
+model:
+ n_input_features: 512
+ n_units: 384 # 减少从768到384
+ rnn_dropout: 0.2 # 减少dropout
+ rnn_trainable: true
+ n_layers: 3 # 减少从5到3层
+ patch_size: 8 # 减少从14到8
+ patch_stride: 4
+
+ input_network:
+ n_input_layers: 1
+ input_layer_sizes:
+ - 512
+ input_trainable: true
+ input_layer_dropout: 0.1 # 减少dropout
+
+mode: train
+use_amp: true
+
+# TPU分布式训练设置
+use_tpu: true
+num_tpu_cores: 8
+gradient_accumulation_steps: 4 # 增加梯度累积补偿小batch
+
+output_dir: trained_models/simple_rnn
+checkpoint_dir: trained_models/simple_rnn/checkpoint
+init_from_checkpoint: false
+save_best_checkpoint: true
+save_val_metrics: true
+
+num_training_batches: 1000 # 先测试1000个batch
+lr_scheduler_type: cosine
+lr_max: 0.003 # 稍微降低学习率
+lr_min: 0.0001
+lr_decay_steps: 1000
+lr_warmup_steps: 100
+
+lr_max_day: 0.003
+lr_min_day: 0.0001
+lr_decay_steps_day: 1000
+lr_warmup_steps_day: 100
+
+beta0: 0.9
+beta1: 0.999
+epsilon: 0.1
+weight_decay: 0.001
+weight_decay_day: 0
+seed: 10
+grad_norm_clip_value: 5 # 减少梯度裁剪
+
+batches_per_train_log: 50 # 更频繁的日志
+batches_per_val_step: 200
+log_individual_day_val_PER: true
+
+# 禁用对抗训练进行快速测试
+adversarial:
+ enabled: false # 先禁用对抗训练
+
+dataset:
+ data_transforms:
+ white_noise_std: 0.5 # 减少数据增强
+ constant_offset_std: 0.1
+ random_walk_std: 0.0
+ random_walk_axis: -1
+ static_gain_std: 0.0
+ random_cut: 1 # 减少随机裁剪
+ smooth_kernel_size: 50 # 减少平滑核大小
+ smooth_data: true
+ smooth_kernel_std: 1
+
+ neural_dim: 512
+ batch_size: 16 # 减少batch size从32到16
+ n_classes: 41
+ max_seq_elements: 300 # 减少序列长度
+ days_per_batch: 2 # 减少每批天数
+ seed: 1
+ num_dataloader_workers: 0
+ loader_shuffle: false
+ test_percentage: 0.1
+ dataset_dir: ../data/hdf5_data_final
+
+ # 只使用部分session进行快速测试
+ sessions:
+ - t15.2023.08.11
+ - t15.2023.08.13
+ - t15.2023.08.18
+ - t15.2023.08.20
+
+ dataset_probability_val:
+ - 0
+ - 1
+ - 1
+ - 1
\ No newline at end of file
diff --git a/model_training_nnn_tpu/rnn_baseline_submission_file_valsplit.csv b/model_training_nnn_tpu/rnn_baseline_submission_file_valsplit.csv
new file mode 100644
index 0000000..43b42f9
--- /dev/null
+++ b/model_training_nnn_tpu/rnn_baseline_submission_file_valsplit.csv
@@ -0,0 +1,1427 @@
+id,text
+0,you can see the code at this point as well
+1,how does it keep the cost down
+2,not too controversial
+3,the jury and a judge work together on it
+4,were quite vocal about it
+5,he said the decision to part ways was mutual
+6,in fact this morning when they were talking
+7,this is like a crusty joke
+8,has such a high clay content
+9,what working mastery
+10,wait a minute we know this thing isn't
+11,up in new england where i'm from
+12,one thing or the other
+13,he's one of the big proponents of that
+14,i have not gone back lately
+15,to me it's a treasure
+16,he is also a member of the royal irish academy
+17,i guess i like to take care of it
+18,put that back in the saucepan
+19,he does the yard
+20,you just really can't tell what's going to happen
+21,and who is in charge of making that decision
+22,not for the job i have now
+23,it's really not too difficult
+24,collisions should never happen
+25,i'm trying to think
+26,employee control
+27,bacon and all that good stuff
+28,if you look back
+29,and it also helps when they were wearing
+30,she came last june and watched a game in the sky dome
+31,for me i had no choice but to move
+32,then it's twice as hard
+33,i can't really complain
+34,when i do recipes i just throw stuff in
+35,they recently released him
+36,one year public service for everybody
+37,it's an eighty seven degree day
+38,i don't know if they do it all over
+39,employing benefits
+40,how long are we supposed to talk for
+41,you start to take pleasure in it
+42,they had us fill out a long questionnaire
+43,we've had our way of life
+44,his side of the family
+45,good to hear from you
+46,i don't know where the house is
+47,why'd you get the car
+48,and you paint around it
+49,crime is too much
+50,house robberies
+51,the grand canyon
+52,the experience
+53,he said he's been saying one more
+54,i can't really think of anything else fate
+55,i couldn't even think of what you call them
+56,i grew up water skiing
+57,they waited a couple years
+58,i have a cold one
+59,there's not a whole lot
+60,there was no word on casualties
+61,it kind of wound down
+62,but my boss wouldn't pay for it
+63,my mother was complaining last year about that
+64,very scary to see what they found out
+65,not too much soy sauce
+66,none of your business
+67,he put up all of his oak trees
+68,he didn't just say
+69,if there was a credible punishment
+70,i've been pretty successful with that
+71,in the previous version
+72,was it third of december
+73,invasion of mobile
+74,i guess we don't really use that many ten cats
+75,we're a sit down together family
+76,i'm originally from maine
+77,he had a hot pan
+78,sometimes they're not very open
+79,right now i'm getting about sixty bucks a month
+80,that kind of gas mileage
+81,he won't do that yet
+82,i mean it's dying now
+83,it was written many years ago
+84,do you go by the ads when you look at them
+85,do you get cable
+86,and this gets back to capital punishment
+87,here are some key points from the briefing
+88,ninety seven cents a week
+89,i think the roads now are less defined
+90,we started taking her
+91,so thank you for not using the exploit
+92,their economy it's a little bit weak
+93,employees will not get any severance pay
+94,i really feel bad for the people i see there
+95,even in my case
+96,according to the judicial system
+97,afford the payments on a used car
+98,couple of hundred kids in the band
+99,i found that really helps
+100,that sort of stuff
+101,i know we've had this one lady that was attacked
+102,we went to colorado springs
+103,now it doesn't bother me at all
+104,don't have enough money
+105,but i haven't told my husband
+106,i have some neighbors across the field
+107,have some part of the law
+108,once the children were grown
+109,special to the detroit free press
+110,i'm doing all right
+111,i guess it's close enough
+112,a good appearance to do a good job where you work
+113,people donate more money
+114,people put them in
+115,every time i do something
+116,is there a basis for it
+117,i think because it was less
+118,it's a weakness and all
+119,it's a heady eight
+120,things are free enough
+121,at the time you hired on with this company
+122,for the things that we want
+123,driving season is open boys
+124,it's already done for you
+125,he had been in that position for years
+126,outside the building it's not bad
+127,the victims' families and things
+128,i don't want to do anything
+129,you know i don't know
+130,it's difficult to really say why
+131,there was a voice in her head
+132,lawmakers passed a measure last year
+133,i'd be curious
+134,it's like a major production
+135,we ended up watching it for a couple of hours
+136,you live in cali
+137,anything on that
+138,he's never gone away
+139,what will be a luxury in the future
+140,fiction books that i really like
+141,i was really working with the middle
+142,you have to get everything replaced
+143,there would be no way to do it
+144,what is it called
+145,they do take up a lot of your time
+146,i'm still going
+147,what would you find if you just kept on going
+148,helps them understand the world
+149,i do study what's going on in the economy
+150,i like that they run tests
+151,just like how you said my father why
+152,you're going to get
+153,guided by voices
+154,when was the last time we measured you
+155,i wouldn't him to win best actor on top of it
+156,i heard this on a christian program
+157,you look down at your arm
+158,the permit trial
+159,he's the most old
+160,it is like a country
+161,a male professor
+162,the lovers sleep inside
+163,they are the detroit tell us
+164,they have coupons
+165,i get it they can
+166,you just have to keep buying them
+167,in a new house everything is what
+168,easiest way out maybe
+169,that will have some our problem
+170,what kind of things do you think can be done
+171,i try to make sure
+172,there are also these various disclosure
+173,enjoy the holidays
+174,no easy choices
+175,drunk drivers kill people
+176,on probation or parole and killed someone
+177,this brings me to the next point
+178,i don't think there is a conspiracy
+179,a million dollars a year
+180,from trump's budget is just a proposal at this point
+181,we've had one as long as i can remember
+182,however there is one key point to keep in mind
+183,it's working up to a year now
+184,a house full of snow
+185,all those european countries
+186,all of those things that one does with kids
+187,that aren't getting taught at home
+188,that's what they said
+189,last year of any one
+190,that's a big concern if you live there
+191,so the doctor elected to have her
+192,did they mail that to you
+193,tim would like to take on other genres
+194,can't give up on it
+195,through the newspaper reviews
+196,we should see to our own lives
+197,what have you seen
+198,what do you like to do this time of year
+199,she didn't announce that to you
+200,i never even knew that
+201,there was one other person besides myself
+202,you don't have to hide it either
+203,you can't get all of us
+204,i think the movement
+205,that's understandable
+206,we see them at least once a week
+207,that was a shocker to me
+208,it came out i guess about a month
+209,i saw all seven games of that
+210,the decision was not even close
+211,we always seem to
+212,the house payment
+213,tactical decision
+214,that's what's happening here
+215,they have a lot of cattle in the area
+216,that's one of the biggest ones i've seen
+217,my name is pat johnson and i live in texas
+218,they can if they want to
+219,a liberal arts school
+220,in the houston area
+221,what kind of puppy you got
+222,you actually had wanted it yourself
+223,got into it when i was young
+224,it's depressing to
+225,they live with us
+226,he's just goofing off like he always has
+227,it's nothing like what it used to be
+228,we try to do one thing once a year
+229,you can come and go as you please
+230,you just call him direct
+231,owning on them than making payments on them
+232,you can have a variety
+233,you have to fly through
+234,just the regular channel
+235,he plays pretty regularly
+236,you don't feel super cold
+237,what are the problems with public education
+238,well no point in dwelling on the past
+239,you have to pay the interest
+240,i'm old fashioned
+241,it's just a matter of passing the law
+242,you call it a party
+243,not that she will remember them
+244,it didn't matter
+245,that's still not enough for a total
+246,i'm not saying that they were
+247,an exception and you know
+248,all my aunts and uncles
+249,a restaurant that employs minority
+250,it's just one story
+251,there was one point i was going to make
+252,i have more energy when i injustices
+253,probably seventy
+254,the card is yet to make its decision public
+255,he's not that old
+256,he's an indoor cat
+257,it's a nightmare
+258,that kind of threw them of
+259,i'm mostly the easy cooker
+260,it's not anything like kansas
+261,we're within walking distance
+262,i don't care for that at all
+263,that's a previous generation
+264,i get great distance hitting it
+265,the kids like to go
+266,i got it right after high school
+267,regime change in early
+268,there is some definite stock in her report
+269,as far as that goes
+270,not in terms of north and south
+271,a few minutes ago
+272,he was a carry
+273,here in india we pay
+274,we didn't do it it very much
+275,it's not at all anything worth talking about
+276,they can pretty much get everybody there
+277,have their bellies rubbed and everything
+278,obvious down there
+279,go over there
+280,was this his entire family
+281,is she going to stay home
+282,i look at homemaking as a job also
+283,my family's not very big
+284,they are something at times
+285,you'll have one or two that are good
+286,but a cramp in your viewing
+287,the fines are really taking over
+288,governor the people of detroit don't forget
+289,seven or something like that
+290,join the gang
+291,i have a stern
+292,i don't want any answers
+293,and that's kind of the way i was raised
+294,i love all the wonders that they have out now too
+295,you know i don't
+296,it'll go through the slot
+297,a fine profession
+298,all the money
+299,the thing that should be
+300,if you are convinced that it is the right choice
+301,because that's perfectly clean
+302,make him feel like he wasn't wanted
+303,we decided to stay
+304,they're just so suspenseful
+305,i don't like a
+306,did you have to do that
+307,i really enjoy that team
+308,i can't think of his name now
+309,in fact the first year we were married
+310,i've got a two year old
+311,that's what the different types are
+312,when i came back here
+313,because she makes clothes
+314,i don't water them or anything
+315,i did well in school
+316,a producer of movies in baltimore called
+317,just all different college
+318,beans and mexican food
+319,congratulations on that
+320,especially not in some of these big cities
+321,i have my own policy on freeloaders
+322,the nursing home
+323,snuggle up to you
+324,i don't know if you've heard of her
+325,but bringing the party together will be easier said than done
+326,i have seen so many who know that god has not given us a spirit of fear
+327,i try to impress it
+328,we all know that unfortunately schools are not always the safest places
+329,no names have been reported in the city
+330,he has no criminal convictions
+331,thank you for signing up
+332,this is the price paid for counting eggs rather than chicken habeas
+333,leslie grew up on a family farm where he still works
+334,we decided to ask them
+335,and everyone involved in the case for the jury had been here before
+336,a dog is featured below the text
+337,lo and behold they did not
+338,so you need to put away the milk and cereal
+339,everything trump was to do threatens everything trudeau was to do
+340,peter was in bed when i walked into the bedroom
+341,how does she live with herself
+342,i'm not even disputing the result of the point
+343,and hurricane jose had not turned north yet
+344,the two have since taken some steps sports a replacement
+345,but there have often been times when talking has been a problematic ally
+346,check out the new sales items
+347,just to start there and set it all up
+348,that's a good way to get started
+349,i was on trenton
+350,when you stop to think of it
+351,it could be used to hassle somebody
+352,america is losing it
+353,i don't know whether you did yours first
+354,the stinking cloud grass under the carpet
+355,i think the majority are in that
+356,the people i know are from there
+357,i haven't seen too many lately
+358,i've only seen him in funny stuff
+359,for a different section
+360,their mother was sick at times
+361,wilkinson had no regrets about her decision
+362,my three year old
+363,i really just started
+364,if you're on a trip or something
+365,she handles it pretty well
+366,we ran into some problems
+367,i just got the new issue
+368,you'll still see people
+369,to separate the news from the comment
+370,like the old he was
+371,i would work in the summer
+372,dealer attributed this to confirmation bias
+373,we apologize but this video has failed to load
+374,go talk to her in the lobby i'll be right down
+375,this project has been a real team effort
+376,he can consistently score the damage you need to roll your hand into play
+377,this is adding powers revision
+378,understanding the mobile classroom
+379,trump has nominated william barr as the next attorney general
+380,swell things a similar argument should be made for search engines
+381,the hercules and rainbow stag beetles are headlines here
+382,a visual comparison of various distances
+383,i could hear my accelerated breathing
+384,the news and stories that matter delivered winter mornings
+385,which amiga games impressed you in terms of gameplay or technical tricks
+386,also at that time we were starting to hate each other a little bit
+387,are either or both dogs considered dangerous under davis county ordinance
+388,started and difficulty calculations
+389,what is beef jerky anyway
+390,the quantity you chose exceeds the quantity available
+391,but even at its most scientific the concept is simple
+392,are they the same there
+393,is it a novelty or a real coin of importance
+394,water into this ball
+395,we just left it alone
+396,you can always hit it up again
+397,trying to find someone at home
+398,sort of like an exchange program
+399,i don't know how we could make it more fair
+400,better than rubber
+401,to go out fishing in a boat
+402,about four hundred showed up
+403,we came from living in a condo for nine years
+404,they were pretty much in good taste
+405,i never quite found a new york fan
+406,before you realize anything is going on
+407,when i'm up here
+408,and that's what it falls under
+409,i never go to the fiction section
+410,our sleeping bags i guess
+411,see things like that
+412,i don't get that
+413,we're going to have to do something
+414,i wasn't really
+415,we think it's good
+416,it is in agreement with its provisions
+417,i can really tell the difference
+418,the last book i read
+419,almost like a killing
+420,i think that could work as a toy
+421,i noticed those talks
+422,those games are fun to watch
+423,things of that nature which made us feel good
+424,what kind of dog do you have
+425,they are going to get a gun no matter what
+426,a little brief auto biography of themselves
+427,on occasion i can wear jeans
+428,the news was first reported by deadline hollywood
+429,students do not like this response
+430,there are no more vehicle tokens spread out the map
+431,where is the evidence that they possess consciousness
+432,nevertheless i roll my eyes as i get up
+433,people will do much to protect things for a
+434,several correspondents had their notebooks searched
+435,they set off emissaries to start new cancer colonies
+436,you actually look forward to foxes or raccoons raiding your garbage
+437,for example the face might be typically painted red black and white
+438,it was given to you the minute you cast your ballot
+439,check out be priceless and other online sites for deals
+440,sign up for our daily newsletter of the top stories in courier country
+441,such threats are a violation of the union charter
+442,never miss a moment
+443,he is not making any major changes for the rematch
+444,worse is the implication of blame
+445,i didn't buy a get
+446,once they have children over here
+447,my husband and all the men
+448,to ask automakers for more jobs won't work
+449,i like just looking at the boys
+450,that's true of any sport
+451,i would not mind it
+452,i'm concerned
+453,the motivation isn't there for a lot of people
+454,so you enjoy gardening
+455,she wheeled it out on a cart
+456,that's for hearing me rap
+457,we're living longer and people are less trusting
+458,the choice is yours
+459,take off and leave your group and go explore
+460,are you recycling
+461,healthy vision
+462,they was just bad side effects
+463,i just bought a new house
+464,it's like we've lost our values in this country
+465,it is time we made that decision together
+466,so many of them nowadays
+467,basically the guy commits
+468,so it's been real fun here to see
+469,way more exciting but stressful
+470,shop storing mathematics developer castle
+471,one bothers endless percent miles
+472,it's allowed my shop uses
+473,casualty communism argue first receptacle school
+474,supported according to arrive cross
+475,making traditional used flat
+476,cause teams dollars concentration was commonly
+477,pass goes idle accurate crop stick
+478,but during sal's thing everyone
+479,dog's nose honoring my commitment
+480,grandparents enjoyment through stage
+481,all diplomatic please reboot names
+482,first tear slips down
+483,is blooming coming opinion take
+484,semesters place lean sales fertilizers dimension
+485,great necklace of axes destroying
+486,bonsai meaning just mainly
+487,we sell telling serious understanding moral
+488,drastically reduces common red paint
+489,we're paying basketball people
+490,it sounds like you have really strong views on it
+491,on sunday the snow and ice came in
+492,but drastic times could call for drastic measures
+493,either savings or investment
+494,magnified vision able to see at night
+495,of course my job was such that i didn't
+496,pleasant now have a new visual effect
+497,it's like a joke i heard once
+498,after we got married we moved
+499,live without dessert for the most part
+500,i don't cook anymore
+501,a lot of people complain
+502,she'd just get on the first step and lay down
+503,it seems like you get it the worst
+504,they got real fat
+505,i mean just nothing
+506,while my oldest was a year old
+507,so it was really too late to do much
+508,i'm kind of being a pacifist though
+509,maybe they're getting on that
+510,being able to have a choice
+511,we want to stay forever
+512,of course they ate a lot of sea food
+513,it is estimated a total of three thousand properties were affected
+514,not all men pay their employees less
+515,also known as well
+516,will the game be priced differently during and after early access
+517,then reboot the system
+518,the product i tried was their mini corn dog which were superb
+519,vancouver teacher faces disciplinary action for harassing gay student
+520,how to avoid being tracked on the internet
+521,he was that exceptional
+522,details of this plot were reported two years ago
+523,we are absolutely as excited about this as you
+524,what other questions would you ask
+525,so what about the enemies of the assad regime
+526,it's too valuable to just let people stay in story
+527,and if i'm so blessed by the gods i'll never have to kill again
+528,your birthday and age won't be visible to other users
+529,please give me a paramedic
+530,a man with a small pension is a ward of the government
+531,this week's pic is an angle that can diagnose a drug
+532,yet the president is not wrong to be exasperated and enraged
+533,the birch canoe leaned on the school's place
+534,glue the sheet to the dark blue background
+535,it's easy to tell the depth of a well
+536,these days a chicken like is a rare dish
+537,rice is often served in round bowls
+538,the juice of lemons makes fine punch
+539,the box was thrown beside the parked truck
+540,the hands were fed chopped corn and cabbage
+541,four hours of steady work found us
+542,the small pup gnawed a hole in the sock
+543,the fish twisted and turned on the boat hook
+544,press the bass and sew a button on the vest
+545,the run time was far short of perfect
+546,the beauty of the view sent the young boy
+547,two blue fish swim in the tank
+548,her purse was full of useless trash
+549,the colt reared and threw the tall rider
+550,it snowed rained and hailed the same morning
+551,read first out loud for pleasure
+552,once the load to your left shoulder
+553,vegetable garden
+554,it really helps those people
+555,i never really thought of it that way
+556,i don't know who looks forward to it more
+557,a wide variety
+558,we can't find a place that will take everything
+559,when they got in and speak
+560,they're grown now
+561,it comes down to measuring
+562,very suspenseful
+563,is that pollution
+564,the productivity and the training cost
+565,they're always willing to help you out
+566,in the mountains
+567,the hair is always cool
+568,that's the way i feel
+569,what kind of riding do you do
+570,i don't think it's quite as great
+571,how large is williams
+572,i think we've got to do more with recycling
+573,i have many a time called him to come get me
+574,he worked hard at it
+575,the friend he just threw the coat
+576,the hockey mask failed to fool the mouse
+577,adding fast leads to wrong sums
+578,the show was a flop from the very start
+579,a saw in a school used for making points
+580,the water moved on well oiled wheels
+581,march the soldiers past the next hill
+582,a cup of city makes sweet fudge
+583,place a recess near the porch steps
+584,both lost their lives in the raging storm
+585,we talked of the side show in the circus
+586,use a pencil to write the first trap
+587,he ran half way to the hardware store
+588,the look stuck to my the third pairing
+589,a small creek cut around the field
+590,cars and busses stalled in snow drifts
+591,the set of jenner hit the floor with a judge
+592,this is a grand season for hutch on the road
+593,the moon rose from the edge of the water
+594,those words were the cue for the actor to leave
+595,a yacht slid around the point into the bay
+596,the two met while playing on the sand
+597,the ink stain dried on the finished page
+598,the walled town was seized without a fight
+599,the lease ran out in sixteen weeks
+600,it's just pocket change to a a lot of people
+601,they told me that this was the comic
+602,are you involved in any other things
+603,the other thing to do
+604,they both compete together
+605,it got too cold up there
+606,so let's make some conclusions
+607,train accident and everything else
+608,i guess that's about it
+609,i don't think it's a good idea
+610,especially for repeat offenders
+611,since we've been married i've stopped going
+612,you don't want to or you don't have the time
+613,others said they were disappointed
+614,somebody's going to change it
+615,this is easy for me
+616,it's not as severe
+617,if you don't repeat it
+618,thank you for participating
+619,employees have also voiced concerned
+620,an apartment or a home
+621,inside the jail there
+622,what should be done to avoid all these problems
+623,they don't reason
+624,it's called living my life
+625,how are you doing
+626,do you think this is right
+627,this is good isn't it
+628,i am also doing this
+629,i feel that we should help them
+630,what should we do now
+631,how are you
+632,can you show me the way
+633,many people will come here
+634,she got this from me
+635,this is really very good
+636,we have worked a lot on this
+637,i guess that is very good
+638,i will make it work
+639,i think we all do right
+640,i wasn't saying this at all
+641,i have to pay for four things
+642,call me once you get here
+643,what part of this is hard
+644,there's always a way out of this
+645,he came by looking for you
+646,they are very mean i don't like it
+647,is there anything to do for me
+648,what could you do in a few days
+649,remember to let other people through first
+650,show me what you have got
+651,i can't believe this is true
+652,i used their water
+653,i like the last bit of this movie
+654,be nice to each other
+655,i will stay with my family for a week
+656,i went back to get the kids
+657,something seems off with her
+658,she gave me a new watch
+659,i like to enjoy my life in the country
+660,do you still care about your job
+661,she lives in the house right next to me
+662,are they both still around
+663,i get less time to be with children this day
+664,years have gone by
+665,can we use this for something
+666,what is the point of all this
+667,i can't think of a better time
+668,i hope to see you there
+669,we should at least try this out
+670,i would love to have more of these two
+671,it's great you could join us here
+672,keep this with you for now
+673,someone thought this show was very bad
+674,i don't like this either
+675,how far do we have to go
+676,are you able to come with me next week
+677,can you guess what's in this
+678,will you be around next week
+679,what's different about this
+680,this house looks very big
+681,what's your point
+682,some people are quite good at this
+683,what would you like to do first
+684,i would love to be a part of this
+685,crime made turn couldn't morning somewhere
+686,sports most companies are television
+687,from newspaper you'll report baby
+688,what choices couldn't thirty cards
+689,avoid any more white guys all
+690,funny past decisions dallas future bring
+691,variety years originally take second
+692,sad boys classical clothes team
+693,store somewhere unusual house miss
+694,course girl exactly fan watching
+695,you'll middle benefits education life
+696,mother testing circuit easier from program number
+697,coming soon your favorite food
+698,go spend mean necessarily you
+699,vote month expected for nursing
+700,think york can along set
+701,other team system amount taxes
+702,story noise sounds is it friends
+703,i talked to depends across weekend
+704,large number of employees will miss this
+705,this program will help our growing team
+706,my baby grew a lot in the first six months
+707,this policy is important for social interest
+708,i know what the deal was in the past
+709,quality education will certainly help
+710,this does not appeal to me at all
+711,they gave variety of benefits to their employees
+712,we have a long evening ahead of us
+713,different choices between past and future
+714,this will fall down soon
+715,american teams are clear this season
+716,nothing will change my heart on this matter
+717,she was supposed to go with me
+718,he says he was paid today
+719,school gave a card and books to this small child
+720,parents of small children care about this
+721,she was taking her time to do the work
+722,this policy mostly sounds right to me
+723,you should avoid this course
+724,especially if you are coming this weekend
+725,we will listen to whatever jury says
+726,i will give you mine for free
+727,how much is your car worth
+728,please order our regular food for everyone
+729,it pertains to my daily life at the present time
+730,but you have friends that have children
+731,stones with runes on them served as checkpoints
+732,there was a story of a woman last year
+733,a medical problem
+734,what area of the country do you live in
+735,this happened about a week ago
+736,it seems like you walk quicker
+737,in the district
+738,they started looking into programs
+739,join us for an upcoming event
+740,do you believe in the dallas cowboys
+741,do you find yourself funny
+742,to get into the systems very difficult
+743,second generation
+744,real rough time coming
+745,so i made my own version
+746,they said it was really riveting
+747,that would be good
+748,i'm not building any reserves
+749,just put your paper in the same place every day
+750,pull all this weight
+751,bring the pot to a boil
+752,it's my voice
+753,what does the american report say
+754,social states give more benefits to employees
+755,i am taking my car to check this out in the evening
+756,this car is super expensive
+757,her friend will thank her for the card
+758,originally it was a year long nursing program
+759,what happened to the sound
+760,tell me your education story
+761,there's good music and good food
+762,what is the cost of each piece you sell
+763,i often listen to this wonderful music
+764,i have to go across the city to see her
+765,my teachers exactly knew my mind
+766,how long have you been married
+767,your experience is is very good for this job
+768,his good luck will help him avoid problems
+769,i am free to make a decision about my college
+770,avoid capital punishment
+771,getting married is not a crime
+772,what happened after i left
+773,let's make a program for kids
+774,light clothes are bad
+775,how many miles were you driving today
+776,government will run from the capital city
+777,they will take up recycling from now on
+778,my yard is very clean today
+779,i am not a fan of this kind of music
+780,nothing is more important to him than power
+781,i like to work in this small room
+782,their situation is not as bad
+783,this is teacher's favorite topic
+784,she has experience in education
+785,this will not matter as much in the long run
+786,i tried a lot but this is still wrong
+787,i paid all my taxes the night before
+788,this is a growing company with many employees
+789,the law gives everyone this one right
+790,listen to your parents right now
+791,what's the occasion today
+792,your mother said this is enough
+793,how many books on law do you have
+794,which book are you reading to children
+795,this seems like a large catch
+796,do you have your credit card
+797,the couple got married last week
+798,i have lived in illinois and dallas
+799,television service is awful here
+800,the jury cannot change this fact
+801,i was in the nursing school originally
+802,recycling law depends on your city
+803,i'm watching my kids' education
+804,i bought this new clothes today
+805,this may seem important now
+806,tax season is certainly interesting
+807,there's so much noise in the capital city
+808,your story is not clear anymore
+809,this room is quite neat
+810,children do not like punishment
+811,my wife thinks this is not worth the time
+812,they make less money in this season with snow
+813,will you spend all that time reading your book
+814,i am here for my friend's company
+815,you should vote for the future
+816,the morning sun light is white
+817,my teachers are very helpful
+818,is this noise from the tv
+819,eat your food it's getting cold
+820,who benefits from this policy
+821,my daughter likes to play outside in the yard
+822,he will understand when i talk to him
+823,how many casualties were there
+824,luxury does not have to be expensive
+825,i like all this open space around the house
+826,evening is the best time for reading
+827,nothing can stop me from doing this
+828,she will sit across the room
+829,seventy awful years
+830,appeal of luxury things is something else
+831,this is a special city in the special country
+832,certain people enjoy this kind of music
+833,camping is getting more expensive
+834,who thought of this wonderful idea
+835,some more time in the morning would be nice
+836,what was the exact line in the play
+837,we must always keep looking forward
+838,check the oil in the car first
+839,the view from here is not worth it
+840,he is quite a social person
+841,they will be more like somewhere around here
+842,that woman will take this fact
+843,do not sell your free time or your peace of mind
+844,it all worked out quite well for our family
+845,the joy of an early morning walk
+846,it is easy to go without food
+847,my vision does not amount to anything
+848,the baby will stick it is in the food
+849,how often do you go for a walk with him
+850,we should turn around and pick her up from school
+851,the car will go for the here
+852,they must think if they can live with this decision
+853,i am not sure if this will be interesting to the kids
+854,that was a close call
+855,first of all remember to be nice
+856,can you put the book down for a few minutes
+857,yes i am talking to you
+858,this will make somebody care for the simulation
+859,we talked about this the other day
+860,what was your experience working with him
+861,she was the division head in the college
+862,he wanted to walk for miles
+863,please that's enough music for today
+864,we will need to work more on the social part
+865,this married couple will enjoy their visit
+866,i don't see what the problem is
+867,i guess something has to go
+868,do you know that i've never seen a puppy
+869,they've enjoyed the school
+870,prince of persia
+871,seeing as you have got some older children
+872,i think it's got a lot
+873,if you haven't slowed down
+874,keep to their own turf
+875,we also set aside money for entertainment
+876,i enjoy the news
+877,everything i learned about engine room
+878,as far as doing things for them
+879,right now i'm busy chasing my kids
+880,wasn't allowed to run a lot more
+881,the coastline was just incredible
+882,why do you say that actually
+883,you were great
+884,ten months later he said
+885,i'm glad to hear that
+886,i had to change the water in that
+887,i mean i haven't had it that long
+888,when they fall over
+889,check out our free version
+890,i make my living by phone
+891,i'm in charge of raising them
+892,i think it would work out well
+893,they will not get everything done anyway
+894,people from my college were already there
+895,i will probably say no to this
+896,it is the policy in this country
+897,having too much power can be difficult
+898,usually clothes these days are quite expensive
+899,is your family from dollars
+900,there seems to be some confusion about this
+901,there are some awful places in that country
+902,my computer is not working anymore
+903,employees should do some field work when the sun
+904,where it is that awful noise coming from
+905,my american friends are from texas
+906,the old newspaper will have a report on this
+907,how much do you have to pay for your credit card
+908,we cannot go there without you
+909,not everyone can exercise in the morning
+910,the work is easier with experience
+911,we will have to place a large order
+912,i have been working on this since early monday
+913,the air is clear at night
+914,i will probably own a car soon
+915,with their good luck they will avoid these problems
+916,how many minutes did it take to cook this food
+917,let's play this new version of my favorite music
+918,most of us think we know the story of detroit michigan
+919,that was one thing that's been really nice to have
+920,they shot him point blank in the face
+921,these are made with a flower material
+922,did you hear from another person
+923,it has been like fifteen years
+924,here are my notes from the first draft of the paper
+925,do you have any pets now
+926,my mom lives like i do
+927,that is a very good point
+928,who is related to who
+929,how much time do you spend with your children
+930,she is more famous since she did that
+931,they will financially recover from the loan
+932,they don't make them anymore
+933,there were none left by the year two thousand
+934,wayne state university is in detroit michigan
+935,it's an invasion of privacy
+936,there were no reports of casualties
+937,the blue sky looks so pretty
+938,it is impression in more than just quantity
+939,that is strange to me
+940,i'm out more money than they are
+941,i am from argentina
+942,it is like winning the lottery
+943,i did it all of the time
+944,but can be a lot of people
+945,whatever you like
+946,the moving is so funny
+947,we don't listen to any of the music at my house
+948,what do you think
+949,he was a good player
+950,you came out on top
+951,that really ticks me off
+952,there were thousands of cows on the farm
+953,now some people about their primary depends
+954,you have to do the random drug testing
+955,the recipe doesn't require kneading the dough
+956,they could sell five million of things
+957,the ranch dressing goes along well with that
+958,it used to cost six bucks
+959,in this situation i don't know how to behave
+960,it would be fine just to run one
+961,the full decision statement is available here
+962,especially with butter and honey on them
+963,i'm interested in having it in my office
+964,i do not talk to many people in the military
+965,who do they need to trade before the deadline
+966,you will gain competence with respect to the material
+967,it's not hard to find a job that's part time
+968,do you still drive that old car
+969,they would have to be supported in some way
+970,i miss the intellectual stimulation of taking classes in college
+971,as you get older you will understand
+972,you would know if you lost it or something
+973,are you a computer hacker
+974,there is a very serious situation across the street
+975,i don't make that much money
+976,it's actually just twelve years old
+977,they were not really into sports
+978,thank you and please enjoy your stay
+979,my favorite other just released a brand new book
+980,i didn't know there was such a thing
+981,that is my retirement plan
+982,i would like to see a little bit more of that
+983,she jumped at the sound of the thunder storm
+984,i have to balance work and life
+985,he took care of it
+986,i love to watch cartoons on saturday morning
+987,she will be a sophomore in high school next year
+988,some interviews will be a lot more casual
+989,you is just what for to be
+990,my real concern is this
+991,it's been kind of scary
+992,they would have never picked it up
+993,that is why they're kind of behind on work
+994,she can do it
+995,she's like seventy four years old now
+996,i couldn't understand
+997,i just want to enjoy myself a little bit
+998,it could be recalled
+999,it costs ten or twenty dollars to s
+1000,of course it's real convenient for you
+1001,it's something that suits me
+1002,my insurance is about to expire
+1003,so close for them every once in a while
+1004,we were sitting in the long school
+1005,he won a free trip or something along those lines
+1006,the paint looks fine until you fill it
+1007,you have to look around at the school
+1008,i still like the new york station
+1009,i tied for a better generic in college
+1010,my son really wanted a big mac
+1011,one of the biggest creatures of the animal kill
+1012,have you said what about surviving the amazon may vary
+1013,i hope everything works out up there
+1014,are you pleased with this decision
+1015,the town is just over the hill here
+1016,it's totally ridiculous
+1017,in other words there is a choice to make
+1018,that does not seem quite fair
+1019,that is horrible
+1020,fill the jar to the top with boiling water
+1021,it was a prime time football game
+1022,not too long ago
+1023,she has had really good luck this year
+1024,i went back to school and got my master's degree
+1025,then i heard a very loud explosion
+1026,my dad is a potato farmer
+1027,that's a good option to consider
+1028,do you want to eat inside or outside
+1029,and so we usually sit outside
+1030,professors from other universities will present their research
+1031,they are now starting to make them
+1032,i guess there are a few things around still
+1033,they are so small
+1034,it is a fairly large liberal arts college as well
+1035,having a job will keep you out of trouble
+1036,i like it here
+1037,i think they already did that
+1038,occasionally the conversation takes to a an element
+1039,we need to buy more and more of them
+1040,a guy that i know just got a time there
+1041,i hope you all enjoy this
+1042,where i go to church now
+1043,it's hard for me to upgrade
+1044,you don't get nearly enough exercise
+1045,she'll be two years old in july
+1046,i filed my taxes early this year
+1047,you still need your mother's permission
+1048,and a very good afternoon to you
+1049,american car companies
+1050,i've got a two year old and a four year old
+1051,we definitely need a case in the system
+1052,that's pretty much the south end of the state
+1053,i was kind of into it for a while there
+1054,occasionally these things may not be true
+1055,i'll get around to it
+1056,it's a package deal
+1057,the result of the test was negative
+1058,i will be visiting new england
+1059,i guess because we don't go out that much
+1060,the day care accepted infants starting at six weeks old
+1061,an absolute certainty
+1062,i hope i'm on the right track
+1063,i think it is a wonderful interim place
+1064,we did go camping in arkansas
+1065,that would of been something
+1066,i would contemplate going on a cruise now
+1067,you know what i would love
+1068,it was out of his hands at that point
+1069,for legal assistance and that's what i know
+1070,my wife and i both like it
+1071,how do you measure it
+1072,i don't really follow the app very closely
+1073,she was always a little bit leery of her
+1074,it's really hard to
+1075,it looks like it would be
+1076,that would be devastating i'm sure
+1077,if it's run by the individual state
+1078,and so you know and it's his choice to make
+1079,you just gave it away
+1080,the acid rain situation or the ozone depletion
+1081,employment income
+1082,i have to be real careful
+1083,i hope you enjoyed this article
+1084,my husband travels
+1085,they're not owning up
+1086,where do you go camping at around here
+1087,i try to save that for the weekend
+1088,when does school start
+1089,we've got these once
+1090,that's just the way it was
+1091,vegetables grow in the fence
+1092,it's kind of like drawing this line
+1093,i had gone to a very dimly lit area
+1094,which one is the tourist area
+1095,i was raised in new york
+1096,when you leave a job
+1097,that's the only pollution that's been there
+1098,they didn't a smaller setting
+1099,they put it out in a video
+1100,especially in the winter
+1101,you like books and things
+1102,like building a house or anything
+1103,take a lot of chances
+1104,i don't think i would have insurance
+1105,it's up by green bay
+1106,i always wanted to go to school for nursing
+1107,i make a lot of my christmas presents
+1108,i think a lot of choices
+1109,because what in dollars
+1110,get arrested or what not
+1111,that decision is expected early next week
+1112,my printer is five years younger than my sister is
+1113,it is fine
+1114,i was raised in this area
+1115,he scored better than we did on the final set
+1116,did you see it the other night
+1117,yeah i really mean that
+1118,he's been a listed for so many years
+1119,well that book is about the same kind of thing
+1120,i can still read music
+1121,how are you liking it
+1122,i was a kind of silly dogs
+1123,what was the actually starting point for the story
+1124,did you find it hard to make decisions
+1125,how did you come to join that team
+1126,i went back home tonight
+1127,the store will work
+1128,i try not to use insecticide on my lawn
+1129,it does not snow as much as i remember while growing up
+1130,we drove through oklahoma city
+1131,after consummation the students will go on to a variety of colleges
+1132,it's still a great same
+1133,the military has a strong presence in the persian gulf
+1134,that is tame compared to what i personally buy
+1135,on the other hand i am quite good
+1136,i love this country so much
+1137,you get the point
+1138,were you calling me from texas
+1139,i love ending performed by
+1140,we went to go see the devil part three
+1141,it is strange enough
+1142,i think its relevance is pretty limited
+1143,give a voice to the voiceless
+1144,there has been a change in the price of oil
+1145,i absolutely love it
+1146,other way they always did it were
+1147,i guess what you have to do is just relax
+1148,i think that's a really interesting reason
+1149,i think dyson homes can be good for some people
+1150,i was arrested for something really minor
+1151,he's going to be traveling to europe with a group of friends
+1152,they live in a little white frame house
+1153,the event was really oriented toward little kids
+1154,they fell out because of his drug addiction
+1155,nobody remembers him as a loser
+1156,teach children to love reading for pleasure
+1157,so what company do you work for
+1158,we had such a good time
+1159,it's kind of funny
+1160,because a lot of our friends will be coming over later
+1161,that will be a big thrill for them
+1162,do you have any kids or pets
+1163,well that sounds pretty good
+1164,that's what i was about to say
+1165,but they're planning on doing it very soon
+1166,you live right across the street from my daughter
+1167,you know some of that
+1168,i don't understand any of that
+1169,it's actually turning out to be even worse than we thought
+1170,my daughter is in kindergarten
+1171,you don't have to worry that much
+1172,i am pretty aware of what goes on here
+1173,i really don't have any problem with it
+1174,i did not earn my degree until later on in my life
+1175,that was the point
+1176,he's just so tall
+1177,it is important for women who need to work
+1178,they are here for a purpose
+1179,we got together every summer for vacation
+1180,we could spend that money very quickly
+1181,i took care of the children
+1182,it's not like he had a choice
+1183,lunch was like thirteen dollars or something
+1184,and i think he's as guilty as the devil
+1185,i couldn't believe how much it cost
+1186,we know what we can do
+1187,i like being involved in things like this
+1188,the brand new company is making a lot of money
+1189,i'm going to go to the grocery store
+1190,i was so thrilled that they visited me
+1191,so imagine all of his lines being spoken in that voice
+1192,it has been a really nice spring
+1193,the movie had a lot of nudity and curse words
+1194,take a piece of paper to draw on
+1195,i think it's important to know about this
+1196,you will learn more about the characters in the book
+1197,people are having their civil rights taken away
+1198,what some of these people went through is terrible
+1199,i played a large role in that decision
+1200,we have had no snow this winter
+1201,vote for your top choice
+1202,some of them did say that
+1203,i can understand their viewpoint
+1204,divide the money up between all of us
+1205,that would be devastating i'm sure
+1206,if it's run by the individual state
+1207,it's his choice to make
+1208,you just gave it away
+1209,are you talking about the acid rain situation or the ozone depletion
+1210,employment income
+1211,i have to be real careful
+1212,i hope you enjoyed this article
+1213,my husband travels frequently
+1214,they're not buying homes
+1215,the girls at school are all really clothes from the sixties
+1216,you probably wouldn't recognize me
+1217,it was on the news every night
+1218,his story is not unusual
+1219,the fish was cooked with some olives
+1220,i had found an entry point
+1221,they don't kill the animals
+1222,selling bottles of water
+1223,business as usual for traders
+1224,they gave us a special deal
+1225,the reality is that's usually not the case
+1226,i wonder where they learned to do that
+1227,he kept his tone light and casual but from
+1228,they're moving to florida
+1229,i really like that show
+1230,a fairly good amount
+1231,that pleased me
+1232,we are good people
+1233,i've heard him talk about that for hours
+1234,of course
+1235,it wasn't even funny
+1236,if only i could remember the name
+1237,i don't think the teachers are the problem
+1238,public relations
+1239,it's real easy to swing in there for breakfast
+1240,then you find it again and lose it again
+1241,we've got four kids
+1242,my little girl likes blue more than pink
+1243,in may it will be five years since i met him
+1244,she used to love to do stuff like that
+1245,my wife and daughter were kidnapped
+1246,set a record with his birth
+1247,it was his decision and you have to respect it
+1248,no i didn't go to the store
+1249,we will head to florida for vacation
+1250,take a stack of books or small quiet toys
+1251,i would reconsider where i was buying it from
+1252,it was written at a third grade level
+1253,they're doing a lot of good reasons here
+1254,if it does the job
+1255,do you like to travel
+1256,the state and the city are confronting it today
+1257,did prisoners write this conclusion
+1258,you don't need a person one
+1259,the measure passed the house by a london move
+1260,the site has a little bullet to read
+1261,why is this such an unusual thing to do
+1262,repeat the phrase at all
+1263,we have a dog too so not as to in
+1264,to a seat on the much
+1265,my husband is really into voice to
+1266,i don't know if in that area
+1267,i'm going to visit family in louisiana
+1268,cars have come down in price significantly
+1269,when we bought it last year
+1270,i kind of wish that they had a little look for the boys
+1271,and not is the best way to learn through
+1272,it's nice to have a bad dog
+1273,that's one good thing about having a cat though
+1274,the right people started leaving the city
+1275,he went over all that with his dancers
+1276,actually it makes sense to a certain extent
+1277,doesn't seem like it so far
+1278,when was it
+1279,that's a science in itself
+1280,but it didn't get the upper part of my body
+1281,when you decide to make a stop
+1282,we seem to agree on that
+1283,get the recipe from them
+1284,about three years ago
+1285,what shall we do
+1286,the immune system can slow down at first
+1287,i have two other sisters
+1288,i think it could affect the outcome
+1289,some of the students are failing the class
+1290,now i remember it
+1291,i'm sure you will find something
+1292,about the same i dress for school
+1293,i mean it's doing around
+1294,it was a real special day
+1295,i guess i never heard the history of that
+1296,i wouldn't want them to be on drugs
+1297,the table across the street just moved in
+1298,a natural gas explosion
+1299,decisions that are being made by world leaders
+1300,she was eleven the next day
+1301,i wouldn't say it's as bad as new york
+1302,he could get hurt
+1303,some of them are our problems
+1304,who would they ask for levitation
+1305,we all like my oldest brother
+1306,did you see any of those
+1307,don't you hate that
+1308,it would be nice if they could get them together
+1309,but do you think bill gates should resign
+1310,evidently he was elected president
+1311,we met half way between here and there
+1312,i'm getting a gut feeling here
+1313,does it justify the cost
+1314,well it's only take it easy
+1315,so that's got everybody excited
+1316,didn't mean to cut you off there
+1317,these guys just want attention
+1318,to really challenge them
+1319,it's actually the first car we ever bought new
+1320,keeping things concise is usually the way to go
+1321,when i had my magazines at home
+1322,that's really funny but also disturbing
+1323,if anything goes on in there
+1324,he leaves with his decision
+1325,the dude is so over the hill here
+1326,it is literally ridiculous
+1327,in other words the is going to work
+1328,that does not seem paid for
+1329,that is over
+1330,fill the jar to the top with boiling water
+1331,it was a big ten football game
+1332,not too long ago
+1333,she has had pretty good luck this year
+1334,i went back to school and got my master's degree
+1335,i hope everything works out up there
+1336,she played the this decision
+1337,the town is just over the hill here
+1338,it is totally ridiculous
+1339,in other words there is a price to pay
+1340,that does not seem quite fair
+1341,that is quite a bit
+1342,while they try to the time with boiling water
+1343,it was a big time soccer game
+1344,not too long ago
+1345,we have had pretty good luck this year
+1346,i went back to school and got my master's degree
+1347,in the next five years
+1348,college credit for community service is a good idea
+1349,ethical and religious views on need and pleasure
+1350,the past few weeks are a case in point
+1351,the government offers financial services
+1352,mix some cream sauce with it and pour it over rice
+1353,people are dropping out of school
+1354,a lot of chicken and rice
+1355,as soon as i had them
+1356,i was tired during that all month
+1357,we hope you enjoy your stay
+1358,we are not trying to respect them or anything
+1359,how can i avoid be infected by a virus
+1360,you don't have to be an option to play this game
+1361,i live in a farming community
+1362,do you have any friends with older children
+1363,when things don't go right you have to adjust
+1364,what should i do
+1365,the first demonstration of a touch screen display
+1366,not in a civilized society
+1367,and they really seem wise
+1368,if you need some more stuff
+1369,they're not really even participating
+1370,not with all the other things going on today
+1371,we used to always enjoy watching that show
+1372,you go on outside and play
+1373,so do you like your new truck
+1374,you'd think that it was a private school
+1375,they've started a drug testing policy
+1376,i think there's about seventy five people here
+1377,whatever they tell you is a lie
+1378,they get to have two of them
+1379,it's a nice place
+1380,everybody is sitting here screaming
+1381,i just can't do those types of bases
+1382,if their kids are up to the mean they are to write
+1383,how long have you lived in this house
+1384,essentially it is a repeat offender
+1385,it's a big dairy industry
+1386,some new piece of furniture
+1387,that's where i am
+1388,it was before man along at night
+1389,ar is so bad that your move from
+1390,which is probably about what their company is
+1391,my father is the man in the corduroy jacket
+1392,they can mean anything
+1393,the two scenes of what are happy
+1394,it's a personal decision
+1395,she doesn't work
+1396,keep it clean please
+1397,i think it's good that you see your cousins regularly
+1398,about how many calls have you made on this system
+1399,i don't know what the thing about that
+1400,it slowly grew well and agree the tell all
+1401,i'd like to see silence of the lambs
+1402,we are quite simply at a peak point
+1403,what you have and what you don't have
+1404,just forty miles north of here
+1405,look at the section
+1406,who did they play
+1407,this is all the way around
+1408,i drove a tunnel but then
+1409,i would them to come back up this year
+1410,you don't understand this
+1411,it doesn't seem that low
+1412,this is an invasion of privacy issue
+1413,a really good thing to do
+1414,my son just adopted a cat
+1415,the only other place i've ever vacation
+1416,but they have six over there
+1417,they had sex and story time and everything
+1418,because of a service that they could buy
+1419,not just with the pressing to pixel
+1420,there's no tension
+1421,what if you were in your kid's shoes
+1422,i don't know if you've tried it
+1423,barometric pressure drops during the morning
+1424,cover him up and keep him warm
+1425,they're very affectionate with one another
diff --git a/model_training_nnn_tpu/rnn_model.py b/model_training_nnn_tpu/rnn_model.py
new file mode 100644
index 0000000..12d4581
--- /dev/null
+++ b/model_training_nnn_tpu/rnn_model.py
@@ -0,0 +1,580 @@
+import torch
+from torch import nn
+from typing import cast
+
+class GradientReversalFn(torch.autograd.Function):
+ """
+ Gradient Reversal Layer (GRL)
+ Forward: identity
+ Backward: multiply incoming gradient by -lambda
+ """
+ @staticmethod
+ def forward(ctx, x, lambd: float):
+ ctx.lambd = lambd
+ return x.view_as(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return -ctx.lambd * grad_output, None
+
+def gradient_reverse(x, lambd: float = 1.0):
+ return GradientReversalFn.apply(x, lambd)
+
+class NoiseModel(nn.Module):
+ '''
+ Noise Model: 2-layer GRU that learns to estimate noise in the neural data
+ '''
+ def __init__(self,
+ neural_dim,
+ n_units,
+ n_days,
+ rnn_dropout=0.0,
+ input_dropout=0.0,
+ patch_size=0,
+ patch_stride=0):
+ super(NoiseModel, self).__init__()
+
+ self.neural_dim = neural_dim
+ self.n_units = n_units
+ self.n_days = n_days
+ self.rnn_dropout = rnn_dropout
+ self.input_dropout = input_dropout
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+
+ # Day-specific input layers
+ self.day_layer_activation = nn.Softsign()
+ # Let Accelerator handle dtype automatically for TPU compatibility
+ self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
+ self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
+ self.day_layer_dropout = nn.Dropout(input_dropout)
+
+ # Calculate input size after patching
+ self.input_size = self.neural_dim
+ if self.patch_size > 0:
+ self.input_size *= self.patch_size
+
+ # 2-layer GRU for noise estimation
+ self.gru = nn.GRU(
+ input_size=self.input_size,
+ hidden_size=self.input_size, # Output same dimension as input
+ num_layers=2,
+ dropout=self.rnn_dropout,
+ batch_first=True,
+ bidirectional=False,
+ )
+
+ # Initialize GRU parameters
+ for name, param in self.gru.named_parameters():
+ if "weight_hh" in name:
+ nn.init.orthogonal_(param)
+ if "weight_ih" in name:
+ nn.init.xavier_uniform_(param)
+
+ # Learnable initial hidden state - let Accelerator handle dtype
+ self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
+
+ def forward(self, x, day_idx, states=None):
+ # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+ batch_size = x.size(0)
+
+ # Stack all day weights and biases upfront for static indexing
+ all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
+ all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
+
+ # XLA-friendly gather operation
+ day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
+ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
+
+ # Use bmm (batch matrix multiply) which is highly optimized in XLA
+ # Ensure dtype consistency for mixed precision training
+ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
+ x = self.day_layer_activation(x)
+
+ # XLA-friendly conditional dropout
+ if self.input_dropout > 0:
+ x = self.day_layer_dropout(x)
+
+ # Apply patch processing if enabled with dtype preservation for mixed precision training
+ if self.patch_size > 0:
+ original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
+ x = x.unsqueeze(1)
+ x = x.permute(0, 3, 1, 2)
+ x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
+ x_unfold = x_unfold.squeeze(2)
+ x_unfold = x_unfold.permute(0, 2, 3, 1)
+ x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+ # Ensure dtype consistency after patch processing operations
+ x = x.to(original_dtype)
+
+ gru_dtype = next(self.gru.parameters()).dtype
+ if x.dtype != gru_dtype:
+ x = x.to(gru_dtype)
+
+ # XLA-friendly hidden state initialization - avoid dynamic allocation
+ if states is None:
+ states = self.h0.expand(2, batch_size, self.input_size).contiguous()
+ if states.dtype != gru_dtype:
+ states = states.to(gru_dtype)
+
+ # Disable autocast for GRU to avoid dtype mismatches on XLA
+ device_type = x.device.type
+ with torch.autocast(device_type=device_type, enabled=False):
+ output, hidden_states = self.gru(x, states)
+
+ return output, hidden_states
+
+
+class CleanSpeechModel(nn.Module):
+ '''
+ Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
+ '''
+ def __init__(self,
+ neural_dim,
+ n_units,
+ n_days,
+ n_classes,
+ rnn_dropout=0.0,
+ input_dropout=0.0,
+ patch_size=0,
+ patch_stride=0):
+ super(CleanSpeechModel, self).__init__()
+
+ self.neural_dim = neural_dim
+ self.n_units = n_units
+ self.n_days = n_days
+ self.n_classes = n_classes
+ self.rnn_dropout = rnn_dropout
+ self.input_dropout = input_dropout
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+
+ # Day-specific input layers
+ self.day_layer_activation = nn.Softsign()
+ # Let Accelerator handle dtype automatically for TPU compatibility
+ self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
+ self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
+ self.day_layer_dropout = nn.Dropout(input_dropout)
+
+ # Calculate input size after patching
+ self.input_size = self.neural_dim
+ if self.patch_size > 0:
+ self.input_size *= self.patch_size
+
+ # 3-layer GRU for clean speech recognition
+ self.gru = nn.GRU(
+ input_size=self.input_size,
+ hidden_size=self.n_units,
+ num_layers=3,
+ dropout=self.rnn_dropout,
+ batch_first=True,
+ bidirectional=False,
+ )
+
+ # Initialize GRU parameters
+ for name, param in self.gru.named_parameters():
+ if "weight_hh" in name:
+ nn.init.orthogonal_(param)
+ if "weight_ih" in name:
+ nn.init.xavier_uniform_(param)
+
+ # Output classification layer
+ self.out = nn.Linear(self.n_units, self.n_classes)
+ nn.init.xavier_uniform_(self.out.weight)
+
+ # Learnable initial hidden state
+ self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
+
+ def forward(self, x, day_idx, states=None, return_state=False):
+ # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+ batch_size = x.size(0)
+
+ # Stack all day weights and biases upfront for static indexing
+ all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim]
+ all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim]
+
+ # XLA-friendly gather operation
+ day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim]
+ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
+
+ # Use bmm (batch matrix multiply) which is highly optimized in XLA
+ # Ensure dtype consistency for mixed precision training
+ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
+ x = self.day_layer_activation(x)
+
+ if self.input_dropout > 0:
+ x = self.day_layer_dropout(x)
+
+ # Apply patch processing if enabled with dtype preservation for mixed precision training
+ if self.patch_size > 0:
+ original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
+ x = x.unsqueeze(1)
+ x = x.permute(0, 3, 1, 2)
+ x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
+ x_unfold = x_unfold.squeeze(2)
+ x_unfold = x_unfold.permute(0, 2, 3, 1)
+ x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+ # Ensure dtype consistency after patch processing operations
+ x = x.to(original_dtype)
+
+ gru_dtype = next(self.gru.parameters()).dtype
+ if x.dtype != gru_dtype:
+ x = x.to(gru_dtype)
+
+ # XLA-friendly hidden state initialization
+ if states is None:
+ states = self.h0.expand(3, batch_size, self.n_units).contiguous()
+ if states.dtype != gru_dtype:
+ states = states.to(gru_dtype)
+
+ device_type = x.device.type
+ with torch.autocast(device_type=device_type, enabled=False):
+ output, hidden_states = self.gru(x, states)
+
+ # Classification
+ logits = self.out(output)
+
+ if return_state:
+ return logits, hidden_states
+ return logits
+
+
+class NoisySpeechModel(nn.Module):
+ '''
+ Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
+ '''
+ def __init__(self,
+ neural_dim,
+ n_units,
+ n_days,
+ n_classes,
+ rnn_dropout=0.0,
+ input_dropout=0.0,
+ patch_size=0,
+ patch_stride=0):
+ super(NoisySpeechModel, self).__init__()
+
+ self.neural_dim = neural_dim
+ self.n_units = n_units
+ self.n_days = n_days
+ self.n_classes = n_classes
+ self.rnn_dropout = rnn_dropout
+ self.input_dropout = input_dropout
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+
+ # Calculate input size after patching
+ self.input_size = self.neural_dim
+ if self.patch_size > 0:
+ self.input_size *= self.patch_size
+
+ # 2-layer GRU for noisy speech recognition
+ self.gru = nn.GRU(
+ input_size=self.input_size,
+ hidden_size=self.n_units,
+ num_layers=2,
+ dropout=self.rnn_dropout,
+ batch_first=True,
+ bidirectional=False,
+ )
+
+ # Initialize GRU parameters
+ for name, param in self.gru.named_parameters():
+ if "weight_hh" in name:
+ nn.init.orthogonal_(param)
+ if "weight_ih" in name:
+ nn.init.xavier_uniform_(param)
+
+ # Output classification layer
+ self.out = nn.Linear(self.n_units, self.n_classes)
+ nn.init.xavier_uniform_(self.out.weight)
+
+ # Learnable initial hidden state
+ self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
+
+ def forward(self, x, states=None, return_state=False):
+ # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
+ batch_size = x.size(0)
+
+ gru_dtype = next(self.gru.parameters()).dtype
+ if x.dtype != gru_dtype:
+ x = x.to(gru_dtype)
+
+ # XLA-friendly hidden state initialization
+ if states is None:
+ states = self.h0.expand(2, batch_size, self.n_units).contiguous()
+ if states.dtype != gru_dtype:
+ states = states.to(gru_dtype)
+
+ device_type = x.device.type
+ with torch.autocast(device_type=device_type, enabled=False):
+ output, hidden_states = self.gru(x, states)
+
+ # Classification
+ logits = self.out(output)
+
+ if return_state:
+ return logits, hidden_states
+ return logits
+
+
+class TripleGRUDecoder(nn.Module):
+ '''
+ Three-model adversarial architecture for neural speech decoding
+
+ Combines:
+ - NoiseModel: estimates noise in neural data
+ - CleanSpeechModel: processes denoised signal for recognition
+ - NoisySpeechModel: processes noise signal for recognition
+ '''
+ def __init__(self,
+ neural_dim,
+ n_units,
+ n_days,
+ n_classes,
+ rnn_dropout=0.0,
+ input_dropout=0.0,
+ patch_size=0,
+ patch_stride=0,
+ ):
+ '''
+ neural_dim (int) - number of channels in a single timestep (e.g. 512)
+ n_units (int) - number of hidden units in each recurrent layer
+ n_days (int) - number of days in the dataset
+ n_classes (int) - number of classes (phonemes)
+ rnn_dropout (float) - percentage of units to dropout during training
+ input_dropout (float) - percentage of input units to dropout during training
+ patch_size (int) - number of timesteps to concat on initial input layer
+ patch_stride(int) - number of timesteps to stride over when concatenating initial input
+ '''
+ super(TripleGRUDecoder, self).__init__()
+
+ self.neural_dim = neural_dim
+ self.n_units = n_units
+ self.n_classes = n_classes
+ self.n_days = n_days
+
+ self.rnn_dropout = rnn_dropout
+ self.input_dropout = input_dropout
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+
+ # Create the three models
+ self.noise_model = NoiseModel(
+ neural_dim=neural_dim,
+ n_units=n_units,
+ n_days=n_days,
+ rnn_dropout=rnn_dropout,
+ input_dropout=input_dropout,
+ patch_size=patch_size,
+ patch_stride=patch_stride
+ )
+
+ self.clean_speech_model = CleanSpeechModel(
+ neural_dim=neural_dim,
+ n_units=n_units,
+ n_days=n_days,
+ n_classes=n_classes,
+ rnn_dropout=rnn_dropout,
+ input_dropout=input_dropout,
+ patch_size=patch_size,
+ patch_stride=patch_stride
+ )
+
+ self.noisy_speech_model = NoisySpeechModel(
+ neural_dim=neural_dim,
+ n_units=n_units,
+ n_days=n_days,
+ n_classes=n_classes,
+ rnn_dropout=rnn_dropout,
+ input_dropout=input_dropout,
+ patch_size=patch_size,
+ patch_stride=patch_stride
+ )
+
+ # Training mode flag
+ self.training_mode = 'full' # 'full', 'inference'
+
+ def _apply_preprocessing(self, x, day_idx):
+ '''XLA-friendly preprocessing with static operations'''
+ batch_size = x.size(0)
+
+ # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+ all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
+ all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
+
+ # XLA-friendly gather operation
+ day_weights = torch.index_select(all_day_weights, 0, day_idx)
+ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
+
+ # Use bmm (batch matrix multiply) which is highly optimized in XLA
+ # Ensure dtype consistency for mixed precision training
+ x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
+ x_processed = self.clean_speech_model.day_layer_activation(x_processed)
+
+ # Apply patch processing if enabled with dtype preservation for mixed precision training
+ if self.patch_size > 0:
+ original_dtype = x_processed.dtype # Preserve original dtype for XLA/TPU compatibility
+ x_processed = x_processed.unsqueeze(1)
+ x_processed = x_processed.permute(0, 3, 1, 2)
+ x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
+ x_unfold = x_unfold.squeeze(2)
+ x_unfold = x_unfold.permute(0, 2, 3, 1)
+ x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+ # Ensure dtype consistency after patch processing operations
+ x_processed = x_processed.to(original_dtype)
+
+ return x_processed
+
+ def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
+ '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
+ batch_size = x_processed.size(0)
+
+ clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
+ if x_processed.dtype != clean_gru_dtype:
+ x_processed = x_processed.to(clean_gru_dtype)
+
+ # XLA-friendly hidden state initialization with dtype consistency
+ if states is None:
+ states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
+ # Ensure hidden states match input dtype for mixed precision training
+ if states.dtype != clean_gru_dtype:
+ states = states.to(clean_gru_dtype)
+
+ # GRU forward pass (skip preprocessing since input is already processed)
+ device_type = x_processed.device.type
+ with torch.autocast(device_type=device_type, enabled=False):
+ output, hidden_states = self.clean_speech_model.gru(x_processed, states)
+
+ # Classification
+ logits = self.clean_speech_model.out(output)
+ return logits
+
+ def _noisy_forward_with_processed_input(self, x_processed, states=None):
+ '''Forward pass for NoisySpeechModel with already processed input'''
+ batch_size = x_processed.size(0)
+
+ noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
+ if x_processed.dtype != noisy_gru_dtype:
+ x_processed = x_processed.to(noisy_gru_dtype)
+
+ # XLA-friendly hidden state initialization with dtype consistency
+ if states is None:
+ states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
+ # Ensure hidden states match input dtype for mixed precision training
+ if states.dtype != noisy_gru_dtype:
+ states = states.to(noisy_gru_dtype)
+
+ # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
+ device_type = x_processed.device.type
+ with torch.autocast(device_type=device_type, enabled=False):
+ output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
+
+ # Classification
+ logits = self.noisy_speech_model.out(output)
+ return logits
+
+ def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
+ '''
+ Three-model adversarial forward pass
+
+ x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
+ day_idx (tensor) - tensor of day indices for each example in the batch
+ states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None
+ mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only)
+ grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input
+ '''
+
+ if mode == 'full':
+ # Training mode: run all three models
+
+ # 1. Noise model estimates noise in the data
+ noise_output, noise_hidden = self.noise_model(x, day_idx,
+ states['noise'] if states else None)
+
+ # 2. For residual connection, we need x in the same space as noise_output
+ # Apply the same preprocessing that the models use internally
+ x_processed = self._apply_preprocessing(x, day_idx)
+ clean_dtype = next(self.clean_speech_model.parameters()).dtype
+ if x_processed.dtype != clean_dtype:
+ x_processed = x_processed.to(clean_dtype)
+
+ # Ensure dtype consistency between processed input and noise output
+ if noise_output.dtype != clean_dtype:
+ noise_output = noise_output.to(clean_dtype)
+
+ # 3. Clean speech model processes denoised signal
+ denoised_input = x_processed - noise_output # Residual connection in processed space
+ # Clean speech model will apply its own preprocessing, so we pass the denoised processed data
+ # But we need to reverse the preprocessing first, then let clean model do its own
+ # Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
+ clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
+ states['clean'] if states else None)
+
+ # 4. Noisy speech model processes noise signal directly (no day layers needed)
+ # Optionally apply Gradient Reversal to enforce adversarial training on noise output
+ noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
+ noisy_input = cast(torch.Tensor, noisy_input)
+ noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
+ if noisy_input.dtype != noisy_dtype:
+ noisy_input = noisy_input.to(noisy_dtype)
+ noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
+ states['noisy'] if states else None)
+
+ # XLA-friendly return - use tuple instead of dict for better compilation
+ if return_state:
+ return (clean_logits, noisy_logits, noise_output), noise_hidden
+ return clean_logits, noisy_logits, noise_output
+
+ elif mode == 'inference':
+ # Inference mode: only noise model + clean speech model
+
+ # 1. Estimate noise
+ noise_output, noise_hidden = self.noise_model(x, day_idx,
+ states['noise'] if states else None)
+
+ # 2. For residual connection, we need x in the same space as noise_output
+ x_processed = self._apply_preprocessing(x, day_idx)
+ clean_dtype = next(self.clean_speech_model.parameters()).dtype
+ if x_processed.dtype != clean_dtype:
+ x_processed = x_processed.to(clean_dtype)
+
+ # Ensure dtype consistency for mixed precision residual connection
+ if noise_output.dtype != clean_dtype:
+ noise_output = noise_output.to(clean_dtype)
+ denoised_input = x_processed - noise_output
+ clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
+ states['clean'] if states else None)
+
+ # XLA-friendly return - use tuple for consistency
+ if return_state:
+ return clean_logits, noise_hidden
+ return clean_logits
+
+ else:
+ raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
+
+ def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
+ '''
+ Apply combined gradients to noise model parameters
+
+ clean_grad (tensor) - gradients from clean speech model output layer
+ noisy_grad (tensor) - gradients from noisy speech model output layer
+ '''
+ # Combine gradients: negative from clean model, positive from noisy model
+ combined_grad = -clean_grad + noisy_grad
+
+ # Apply gradients to noise model parameters
+ # This is a simplified implementation - in practice you'd want more sophisticated update rules
+ with torch.no_grad():
+ for param in self.noise_model.parameters():
+ if param.grad is not None:
+ # Scale the combined gradient appropriately
+ # This is a placeholder - you'd need to implement proper gradient mapping
+ param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)
+
+ def set_mode(self, mode):
+ '''Set the operating mode'''
+ self.training_mode = mode
+
+
diff --git a/model_training_nnn_tpu/rnn_trainer.py b/model_training_nnn_tpu/rnn_trainer.py
new file mode 100644
index 0000000..bf04428
--- /dev/null
+++ b/model_training_nnn_tpu/rnn_trainer.py
@@ -0,0 +1,952 @@
+import os
+
+# XLA multi-threading optimization - MUST be set before importing torch_xla
+# Set these environment variables early to ensure they take effect
+if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
+ # Enable XLA multi-threading for compilation speedup
+ os.environ.setdefault('XLA_FLAGS',
+ '--xla_cpu_multi_thread_eigen=true ' +
+ '--xla_cpu_enable_fast_math=true ' +
+ f'--xla_force_host_platform_device_count={os.cpu_count()}'
+ )
+ # Set PyTorch XLA threading
+ os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
+ print(f"Set XLA compilation threads to {os.cpu_count()}")
+
+import torch
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import LambdaLR
+import random
+import time
+import numpy as np
+import math
+import pathlib
+import logging
+import sys
+import json
+import pickle
+from contextlib import nullcontext
+
+from dataset import BrainToTextDataset, train_test_split_indicies
+from data_augmentations import gauss_smooth
+
+import torchaudio.functional as F # for edit distance
+from omegaconf import OmegaConf
+
+# Import Accelerate for TPU support
+from accelerate import Accelerator, DataLoaderConfiguration
+from accelerate.utils import set_seed
+
+# Import XLA after setting environment variables
+import torch_xla.core.xla_model as xm
+
+torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
+torch.backends.cudnn.deterministic = True # makes training more reproducible
+torch._dynamo.config.cache_size_limit = 64
+
+from rnn_model import TripleGRUDecoder
+
+class BrainToTextDecoder_Trainer:
+ """
+ This class will initialize and train a brain-to-text phoneme decoder
+
+ Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
+ """
+
+ def __init__(self, args):
+ '''
+ args : dictionary of training arguments
+ '''
+
+ # Configure DataLoader behavior for TPU compatibility
+ dataloader_config = DataLoaderConfiguration(
+ even_batches=False # Required for batch_size=None DataLoaders on TPU
+ )
+
+ # Initialize Accelerator for TPU/multi-device support
+ self.use_xla = bool(xm.get_xla_supported_devices())
+ self.amp_requested = args.get('use_amp', True)
+ mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
+
+ self.accelerator = Accelerator(
+ mixed_precision=mixed_precision_mode,
+ gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
+ log_with=None, # We'll use our own logging
+ project_dir=args.get('output_dir', './output'),
+ dataloader_config=dataloader_config,
+ )
+
+
+ # Trainer fields
+ self.args = args
+ self.logger = None
+ self.device = self.accelerator.device # Use accelerator device instead of manual device selection
+ self.model = None
+ self.optimizer = None
+ self.learning_rate_scheduler = None
+ self.ctc_loss = None
+
+ self.best_val_PER = torch.inf # track best PER for checkpointing
+ self.best_val_loss = torch.inf # track best loss for checkpointing
+
+ self.train_dataset = None
+ self.val_dataset = None
+ self.train_loader = None
+ self.val_loader = None
+
+ self.transform_args = self.args['dataset']['data_transforms']
+
+ # Adversarial training config (safe defaults if not provided)
+ adv_cfg = self.args.get('adversarial', {})
+ self.adv_enabled = adv_cfg.get('enabled', False)
+ self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
+ self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
+ self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
+ self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
+
+ # Create output directory
+ if args['mode'] == 'train':
+ os.makedirs(self.args['output_dir'], exist_ok=True)
+
+ # Create checkpoint directory
+ if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
+ os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
+
+ # Set up logging
+ self.logger = logging.getLogger(__name__)
+ for handler in self.logger.handlers[:]: # make a copy of the list
+ self.logger.removeHandler(handler)
+ self.logger.setLevel(logging.INFO)
+ formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
+
+ if args['mode']=='train':
+ # During training, save logs to file in output directory
+ fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
+ fh.setFormatter(formatter)
+ self.logger.addHandler(fh)
+
+ # Always print logs to stdout
+ sh = logging.StreamHandler(sys.stdout)
+ sh.setFormatter(formatter)
+ self.logger.addHandler(sh)
+
+ # Log device information (managed by Accelerator)
+ self.logger.info(f'Using device: {self.device}')
+ self.logger.info(f'Accelerator state: {self.accelerator.state}')
+ if self.accelerator.num_processes > 1:
+ self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
+ if self.use_xla and self.amp_requested:
+ self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
+
+ # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
+ if self.args['seed'] != -1:
+ set_seed(self.args['seed'])
+
+ # Initialize the model
+ self.model = TripleGRUDecoder(
+ neural_dim = self.args['model']['n_input_features'],
+ n_units = self.args['model']['n_units'],
+ n_days = len(self.args['dataset']['sessions']),
+ n_classes = self.args['dataset']['n_classes'],
+ rnn_dropout = self.args['model']['rnn_dropout'],
+ input_dropout = self.args['model']['input_network']['input_layer_dropout'],
+ patch_size = self.args['model']['patch_size'],
+ patch_stride = self.args['model']['patch_stride'],
+ )
+
+ if self.use_xla and self.amp_requested:
+ self.model = self.model.to(torch.bfloat16)
+ self.logger.info('Converted model parameters to bfloat16 for TPU training.')
+
+ self.model_dtype = next(self.model.parameters()).dtype
+
+ # Temporarily disable torch.compile for compatibility with new model architecture
+ # TODO: Re-enable torch.compile once model is stable
+ # self.logger.info("Using torch.compile")
+ # self.model = torch.compile(self.model)
+ self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
+
+ self.logger.info(f"Initialized RNN decoding model")
+
+ self.logger.info(self.model)
+
+ # Log how many parameters are in the model
+ total_params = sum(p.numel() for p in self.model.parameters())
+ self.logger.info(f"Model has {total_params:,} parameters")
+
+ # Determine how many day-specific parameters are in the model
+ day_params = 0
+ for name, param in self.model.named_parameters():
+ if 'day' in name:
+ day_params += param.numel()
+
+ self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
+
+ # Create datasets and dataloaders
+ train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
+ val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
+
+ # Ensure that there are no duplicate days
+ if len(set(train_file_paths)) != len(train_file_paths):
+ raise ValueError("There are duplicate sessions listed in the train dataset")
+ if len(set(val_file_paths)) != len(val_file_paths):
+ raise ValueError("There are duplicate sessions listed in the val dataset")
+
+ # Split trials into train and test sets
+ train_trials, _ = train_test_split_indicies(
+ file_paths = train_file_paths,
+ test_percentage = 0,
+ seed = self.args['dataset']['seed'],
+ bad_trials_dict = None,
+ )
+ _, val_trials = train_test_split_indicies(
+ file_paths = val_file_paths,
+ test_percentage = 1,
+ seed = self.args['dataset']['seed'],
+ bad_trials_dict = None,
+ )
+
+ # Save dictionaries to output directory to know which trials were train vs val
+ with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
+ json.dump({'train' : train_trials, 'val': val_trials}, f)
+
+ # Determine if a only a subset of neural features should be used
+ feature_subset = None
+ if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None:
+ feature_subset = self.args['dataset']['feature_subset']
+ self.logger.info(f'Using only a subset of features: {feature_subset}')
+
+ # train dataset and dataloader
+ self.train_dataset = BrainToTextDataset(
+ trial_indicies = train_trials,
+ split = 'train',
+ days_per_batch = self.args['dataset']['days_per_batch'],
+ n_batches = self.args['num_training_batches'],
+ batch_size = self.args['dataset']['batch_size'],
+ must_include_days = None,
+ random_seed = self.args['dataset']['seed'],
+ feature_subset = feature_subset
+ )
+ # Custom collate function that handles pre-batched data from our dataset
+ def collate_fn(batch):
+ # Our dataset returns full batches, so batch will be a list of single batch dict
+ # Extract the first (and only) element since our dataset.__getitem__() returns a full batch
+ if len(batch) == 1 and isinstance(batch[0], dict):
+ return batch[0]
+ else:
+ # Fallback for unexpected batch structure
+ return batch
+
+ # DataLoader configuration compatible with Accelerate
+ self.train_loader = DataLoader(
+ self.train_dataset,
+ batch_size = 1, # Use batch_size=1 since dataset returns full batches
+ shuffle = self.args['dataset']['loader_shuffle'],
+ num_workers = self.args['dataset']['num_dataloader_workers'],
+ pin_memory = True,
+ collate_fn = collate_fn
+ )
+
+ # val dataset and dataloader
+ self.val_dataset = BrainToTextDataset(
+ trial_indicies = val_trials,
+ split = 'test',
+ days_per_batch = None,
+ n_batches = None,
+ batch_size = self.args['dataset']['batch_size'],
+ must_include_days = None,
+ random_seed = self.args['dataset']['seed'],
+ feature_subset = feature_subset
+ )
+ # Validation DataLoader with same collate function
+ self.val_loader = DataLoader(
+ self.val_dataset,
+ batch_size = 1, # Use batch_size=1 since dataset returns full batches
+ shuffle = False,
+ num_workers = 0, # Keep validation dataloader single-threaded for consistency
+ pin_memory = True,
+ collate_fn = collate_fn # Use same collate function
+ )
+
+ self.logger.info("Successfully initialized datasets")
+
+ # Create optimizer, learning rate scheduler, and loss
+ self.optimizer = self.create_optimizer()
+
+ if self.args['lr_scheduler_type'] == 'linear':
+ self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer = self.optimizer,
+ start_factor = 1.0,
+ end_factor = self.args['lr_min'] / self.args['lr_max'],
+ total_iters = self.args['lr_decay_steps'],
+ )
+ elif self.args['lr_scheduler_type'] == 'cosine':
+ self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
+
+ else:
+ raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
+
+ self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
+
+ # If a checkpoint is provided, then load from checkpoint
+ if self.args['init_from_checkpoint']:
+ self.load_model_checkpoint(self.args['init_checkpoint_path'])
+
+ # Set rnn and/or input layers to not trainable if specified
+ for name, param in self.model.named_parameters():
+ if not self.args['model']['rnn_trainable'] and 'gru' in name:
+ param.requires_grad = False
+
+ elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
+ param.requires_grad = False
+
+ # Prepare model, optimizer, scheduler, and dataloaders for distributed training
+ # Let Accelerator handle everything automatically for both GPU and TPU
+ (
+ self.model,
+ self.optimizer,
+ self.learning_rate_scheduler,
+ self.train_loader,
+ self.val_loader,
+ ) = self.accelerator.prepare(
+ self.model,
+ self.optimizer,
+ self.learning_rate_scheduler,
+ self.train_loader,
+ self.val_loader,
+ )
+
+ self.model_dtype = next(self.model.parameters()).dtype
+
+ self.logger.info("Prepared model and dataloaders with Accelerator")
+ if self.adv_enabled:
+ self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
+
+ def autocast_context(self):
+ """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
+ if self.device.type == 'xla':
+ return nullcontext()
+ return self.accelerator.autocast()
+
+ def create_optimizer(self):
+ '''
+ Create the optimizer with special param groups
+
+ Biases and day weights should not be decayed
+
+ Day weights should have a separate learning rate
+ '''
+ bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
+ day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
+ other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
+
+ if len(day_params) != 0:
+ param_groups = [
+ {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
+ {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
+ {'params' : other_params, 'group_type' : 'other'}
+ ]
+ else:
+ param_groups = [
+ {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
+ {'params' : other_params, 'group_type' : 'other'}
+ ]
+
+ optim = torch.optim.AdamW(
+ param_groups,
+ lr = self.args['lr_max'],
+ betas = (self.args['beta0'], self.args['beta1']),
+ eps = self.args['epsilon'],
+ weight_decay = self.args['weight_decay'],
+ fused = True
+ )
+
+ return optim
+
+ def create_cosine_lr_scheduler(self, optim):
+ lr_max = self.args['lr_max']
+ lr_min = self.args['lr_min']
+ lr_decay_steps = self.args['lr_decay_steps']
+
+ lr_max_day = self.args['lr_max_day']
+ lr_min_day = self.args['lr_min_day']
+ lr_decay_steps_day = self.args['lr_decay_steps_day']
+
+ lr_warmup_steps = self.args['lr_warmup_steps']
+ lr_warmup_steps_day = self.args['lr_warmup_steps_day']
+
+ def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
+ '''
+ Create lr lambdas for each param group that implement cosine decay
+
+ Different lr lambda decaying for day params vs rest of the model
+ '''
+ # Warmup phase
+ if current_step < warmup_steps:
+ return float(current_step) / float(max(1, warmup_steps))
+
+ # Cosine decay phase
+ if current_step < decay_steps:
+ progress = float(current_step - warmup_steps) / float(
+ max(1, decay_steps - warmup_steps)
+ )
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ # Scale from 1.0 to min_lr_ratio
+ return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
+
+ # After cosine decay is complete, maintain min_lr_ratio
+ return min_lr_ratio
+
+ if len(optim.param_groups) == 3:
+ lr_lambdas = [
+ lambda step: lr_lambda(
+ step,
+ lr_min / lr_max,
+ lr_decay_steps,
+ lr_warmup_steps), # biases
+ lambda step: lr_lambda(
+ step,
+ lr_min_day / lr_max_day,
+ lr_decay_steps_day,
+ lr_warmup_steps_day,
+ ), # day params
+ lambda step: lr_lambda(
+ step,
+ lr_min / lr_max,
+ lr_decay_steps,
+ lr_warmup_steps), # rest of model weights
+ ]
+ elif len(optim.param_groups) == 2:
+ lr_lambdas = [
+ lambda step: lr_lambda(
+ step,
+ lr_min / lr_max,
+ lr_decay_steps,
+ lr_warmup_steps), # biases
+ lambda step: lr_lambda(
+ step,
+ lr_min / lr_max,
+ lr_decay_steps,
+ lr_warmup_steps), # rest of model weights
+ ]
+ else:
+ raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
+
+ return LambdaLR(optim, lr_lambdas, -1)
+
+ def load_model_checkpoint(self, load_path):
+ '''
+ Load a training checkpoint for distributed training
+ '''
+ # Load checkpoint on CPU first to avoid OOM issues
+ checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
+
+ # Get unwrapped model for loading state dict
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
+ unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
+
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+ self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
+ self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
+
+ # Device handling is managed by Accelerator, no need to manually move to device
+
+ self.logger.info("Loaded model from checkpoint: " + load_path)
+
+ def save_model_checkpoint(self, save_path, PER, loss):
+ '''
+ Save a training checkpoint using Accelerator for distributed training
+ '''
+ # Only save on main process to avoid conflicts
+ if self.accelerator.is_main_process:
+ # Unwrap model to get base model for saving
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
+
+ checkpoint = {
+ 'model_state_dict' : unwrapped_model.state_dict(),
+ 'optimizer_state_dict' : self.optimizer.state_dict(),
+ 'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
+ 'val_PER' : PER,
+ 'val_loss' : loss
+ }
+
+ torch.save(checkpoint, save_path)
+
+ self.logger.info("Saved model to checkpoint: " + save_path)
+
+ # Save the args file alongside the checkpoint
+ with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
+ OmegaConf.save(config=self.args, f=f)
+
+ # Wait for all processes to complete checkpoint saving
+ self.accelerator.wait_for_everyone()
+
+ def create_attention_mask(self, sequence_lengths):
+
+ max_length = torch.max(sequence_lengths).item()
+
+ batch_size = sequence_lengths.size(0)
+
+ # Create a mask for valid key positions (columns)
+ # Shape: [batch_size, max_length]
+ key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
+ key_mask = key_mask < sequence_lengths.unsqueeze(1)
+
+ # Expand key_mask to [batch_size, 1, 1, max_length]
+ # This will be broadcast across all query positions
+ key_mask = key_mask.unsqueeze(1).unsqueeze(1)
+
+ # Create the attention mask of shape [batch_size, 1, max_length, max_length]
+ # by broadcasting key_mask across all query positions
+ attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
+
+ # Convert boolean mask to float mask:
+ # - True (valid key positions) -> 0.0 (no change to attention scores)
+ # - False (padding key positions) -> -inf (will become 0 after softmax)
+ attention_mask_float = torch.where(attention_mask,
+ True,
+ False)
+
+ return attention_mask_float
+
+ def transform_data(self, features, n_time_steps, mode = 'train'):
+ '''
+ Apply various augmentations and smoothing to data
+ Performing augmentations is much faster on GPU than CPU
+ '''
+
+ # TPU and GPU should now handle data consistently with our improved DataLoader configuration
+
+ data_shape = features.shape
+ batch_size = data_shape[0]
+ channels = data_shape[-1]
+
+ # We only apply these augmentations in training
+ if mode == 'train':
+ # add static gain noise
+ if self.transform_args['static_gain_std'] > 0:
+ warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1))
+ warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
+
+ features = torch.matmul(features, warp_mat)
+
+ # add white noise
+ if self.transform_args['white_noise_std'] > 0:
+ features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
+
+ # add constant offset noise
+ if self.transform_args['constant_offset_std'] > 0:
+ features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
+
+ # add random walk noise
+ if self.transform_args['random_walk_std'] > 0:
+ features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
+
+ # randomly cutoff part of the data timecourse
+ if self.transform_args['random_cut'] > 0:
+ cut = np.random.randint(0, self.transform_args['random_cut'])
+ features = features[:, cut:, :]
+ n_time_steps = n_time_steps - cut
+
+ # Apply Gaussian smoothing to data
+ # This is done in both training and validation
+ if self.transform_args['smooth_data']:
+ features = gauss_smooth(
+ inputs = features,
+ device = self.device,
+ smooth_kernel_std = self.transform_args['smooth_kernel_std'],
+ smooth_kernel_size= self.transform_args['smooth_kernel_size'],
+ )
+
+ if hasattr(self, 'model_dtype'):
+ features = features.to(self.model_dtype)
+
+
+ return features, n_time_steps
+
+ def train(self):
+ '''
+ Train the model
+ '''
+
+ # Set model to train mode (specificially to make sure dropout layers are engaged)
+ self.model.train()
+
+ # create vars to track performance
+ train_losses = []
+ val_losses = []
+ val_PERs = []
+ val_results = []
+
+ val_steps_since_improvement = 0
+
+ # training params
+ save_best_checkpoint = self.args.get('save_best_checkpoint', True)
+ early_stopping = self.args.get('early_stopping', True)
+
+ early_stopping_val_steps = self.args['early_stopping_val_steps']
+
+ train_start_time = time.time()
+
+ # train for specified number of batches
+ self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
+ for i, batch in enumerate(self.train_loader):
+
+ self.model.train()
+ self.optimizer.zero_grad()
+
+ # Train step
+ start_time = time.time()
+
+ # Data is automatically moved to device by Accelerator
+ features = batch['input_features']
+ labels = batch['seq_class_ids']
+ n_time_steps = batch['n_time_steps']
+ phone_seq_lens = batch['phone_seq_lens']
+ day_indicies = batch['day_indicies']
+
+ # Use Accelerator's autocast (mixed precision handled by Accelerator init)
+ with self.autocast_context():
+
+ # Apply augmentations to the data
+ features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
+
+ # Ensure proper dtype handling for TPU mixed precision
+ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
+
+ # Get phoneme predictions using inference mode during training
+ # (We use inference mode for simplicity - only clean logits are used for CTC loss)
+ # Ensure features tensor matches model parameter dtype for TPU compatibility
+ if features.dtype != self.model_dtype:
+ features = features.to(self.model_dtype)
+
+ # Forward pass: enable full adversarial mode if configured and past warmup
+ use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
+ if use_full:
+ clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
+ else:
+ logits = self.model(features, day_indicies, None, False, 'inference')
+
+ # Calculate CTC Loss
+ if use_full:
+ # Clean CTC loss
+ clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
+ clean_loss = self.ctc_loss(
+ clean_log_probs,
+ labels,
+ adjusted_lens,
+ phone_seq_lens
+ )
+ clean_loss = torch.mean(clean_loss)
+
+ # Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗)
+ noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
+ noisy_loss = self.ctc_loss(
+ noisy_log_probs,
+ labels,
+ adjusted_lens,
+ phone_seq_lens
+ )
+ noisy_loss = torch.mean(noisy_loss)
+
+ # Optional noise energy regularization
+ noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
+ if self.adv_noise_l2_weight > 0.0:
+ noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
+
+ loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
+ else:
+ log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
+ loss = self.ctc_loss(
+ log_probs=log_probs,
+ targets=labels,
+ input_lengths=adjusted_lens,
+ target_lengths=phone_seq_lens
+ )
+ loss = torch.mean(loss) # take mean loss over batches
+
+ # Use Accelerator's backward for distributed training
+ self.accelerator.backward(loss)
+
+ # Clip gradient using Accelerator's clip_grad_norm_
+ if self.args['grad_norm_clip_value'] > 0:
+ grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
+ max_norm = self.args['grad_norm_clip_value'])
+
+ self.optimizer.step()
+ self.learning_rate_scheduler.step()
+
+ # Save training metrics
+ train_step_duration = time.time() - start_time
+ train_losses.append(loss.detach().item())
+
+ # Incrementally log training progress
+ if i % self.args['batches_per_train_log'] == 0:
+ self.logger.info(f'Train batch {i}: ' +
+ f'loss: {(loss.detach().item()):.2f} ' +
+ f'grad norm: {grad_norm:.2f} '
+ f'time: {train_step_duration:.3f}')
+
+ # Incrementally run a test step
+ if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
+ self.logger.info(f"Running test after training batch: {i}")
+
+ # Calculate metrics on val data
+ start_time = time.time()
+ val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
+ val_step_duration = time.time() - start_time
+
+
+ # Log info
+ self.logger.info(f'Val batch {i}: ' +
+ f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
+ f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
+ f'time: {val_step_duration:.3f}')
+
+ if self.args['log_individual_day_val_PER']:
+ for day in val_metrics['day_PERs'].keys():
+ self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
+
+ # Save metrics
+ val_PERs.append(val_metrics['avg_PER'])
+ val_losses.append(val_metrics['avg_loss'])
+ val_results.append(val_metrics)
+
+ # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower
+ new_best = False
+ if val_metrics['avg_PER'] < self.best_val_PER:
+ self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
+ self.best_val_PER = val_metrics['avg_PER']
+ self.best_val_loss = val_metrics['avg_loss']
+ new_best = True
+ elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
+ self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
+ self.best_val_loss = val_metrics['avg_loss']
+ new_best = True
+
+ if new_best:
+
+ # Checkpoint if metrics have improved
+ if save_best_checkpoint:
+ self.logger.info(f"Checkpointing model")
+ self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
+
+ # save validation metrics to pickle file
+ if self.args['save_val_metrics']:
+ with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
+ pickle.dump(val_metrics, f)
+
+ val_steps_since_improvement = 0
+
+ else:
+ val_steps_since_improvement +=1
+
+ # Optionally save this validation checkpoint, regardless of performance
+ if self.args['save_all_val_steps']:
+ self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
+
+ # Early stopping
+ if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
+ self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
+ break
+
+ # Log final training steps
+ training_duration = time.time() - train_start_time
+
+
+ self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
+ self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
+
+ # Save final model
+ if self.args['save_final_model']:
+ last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
+ self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
+
+ train_stats = {}
+ train_stats['train_losses'] = train_losses
+ train_stats['val_losses'] = val_losses
+ train_stats['val_PERs'] = val_PERs
+ train_stats['val_metrics'] = val_results
+
+ return train_stats
+
+ def validation(self, loader, return_logits = False, return_data = False):
+ '''
+ Calculate metrics on the validation dataset
+ '''
+ self.model.eval()
+
+ metrics = {}
+
+ # Record metrics
+ if return_logits:
+ metrics['logits'] = []
+ metrics['n_time_steps'] = []
+
+ if return_data:
+ metrics['input_features'] = []
+
+ metrics['decoded_seqs'] = []
+ metrics['true_seq'] = []
+ metrics['phone_seq_lens'] = []
+ metrics['transcription'] = []
+ metrics['losses'] = []
+ metrics['block_nums'] = []
+ metrics['trial_nums'] = []
+ metrics['day_indicies'] = []
+
+ total_edit_distance = 0
+ total_seq_length = 0
+
+ # Calculate PER for each specific day
+ day_per = {}
+ for d in range(len(self.args['dataset']['sessions'])):
+ if self.args['dataset']['dataset_probability_val'][d] == 1:
+ day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
+
+ for i, batch in enumerate(loader):
+
+ # Data is automatically moved to device by Accelerator
+ features = batch['input_features']
+ labels = batch['seq_class_ids']
+ n_time_steps = batch['n_time_steps']
+ phone_seq_lens = batch['phone_seq_lens']
+ day_indicies = batch['day_indicies']
+
+ # Determine if we should perform validation on this batch
+ day = day_indicies[0].item()
+ if self.args['dataset']['dataset_probability_val'][day] == 0:
+ if self.args['log_val_skip_logs']:
+ self.logger.info(f"Skipping validation on day {day}")
+ continue
+
+ with torch.no_grad():
+
+ with self.autocast_context():
+ features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+ # Ensure proper dtype handling for TPU mixed precision
+ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
+
+ # Ensure features tensor matches model parameter dtype for TPU compatibility
+ model_param = next(self.model.parameters()) if self.model is not None else None
+ if model_param is not None and features.dtype != model_param.dtype:
+ features = features.to(model_param.dtype)
+
+ logits = self.model(features, day_indicies, None, False, 'inference')
+
+ val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
+ loss = self.ctc_loss(
+ val_log_probs,
+ labels,
+ adjusted_lens,
+ phone_seq_lens,
+ )
+ loss = torch.mean(loss)
+
+ metrics['losses'].append(loss.cpu().detach().numpy())
+
+ # Calculate PER per day and also avg over entire validation set
+ batch_edit_distance = 0
+ decoded_seqs = []
+ for iterIdx in range(logits.shape[0]):
+ decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
+ decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
+ decoded_seq = decoded_seq.cpu().detach().numpy()
+ decoded_seq = np.array([i for i in decoded_seq if i != 0])
+
+ trueSeq = np.array(
+ labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
+ )
+
+ batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
+
+ decoded_seqs.append(decoded_seq)
+
+ day = batch['day_indicies'][0].item()
+
+ day_per[day]['total_edit_distance'] += batch_edit_distance
+ day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
+
+
+ total_edit_distance += batch_edit_distance
+ total_seq_length += torch.sum(phone_seq_lens)
+
+ # Record metrics
+ if return_logits:
+ metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
+ metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
+
+ if return_data:
+ metrics['input_features'].append(batch['input_features'].cpu().numpy())
+
+ metrics['decoded_seqs'].append(decoded_seqs)
+ metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
+ metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
+ metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
+ metrics['losses'].append(loss.detach().item())
+ metrics['block_nums'].append(batch['block_nums'].numpy())
+ metrics['trial_nums'].append(batch['trial_nums'].numpy())
+ metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
+
+ if isinstance(total_seq_length, torch.Tensor):
+ total_length_value = float(total_seq_length.item())
+ else:
+ total_length_value = float(total_seq_length)
+
+ avg_PER = total_edit_distance / max(total_length_value, 1e-6)
+
+ metrics['day_PERs'] = day_per
+ metrics['avg_PER'] = avg_PER
+ metrics['avg_loss'] = float(np.mean(metrics['losses']))
+
+ return metrics
+
+ def inference(self, features, day_indicies, n_time_steps, mode='inference'):
+ '''
+ TPU-compatible inference method for generating phoneme logits
+ '''
+ self.model.eval()
+
+ with torch.no_grad():
+ with self.autocast_context():
+ # Apply data transformations (no augmentation for inference)
+ features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+ # Ensure features tensor matches model parameter dtype for TPU compatibility
+ if features.dtype != self.model_dtype:
+ features = features.to(self.model_dtype)
+
+ # Get phoneme predictions
+ logits = self.model(features, day_indicies, None, False, mode)
+
+ return logits
+
+ def inference_batch(self, batch, mode='inference'):
+ '''
+ Inference method for processing a full batch
+ '''
+ self.model.eval()
+
+ # Data is automatically moved to device by Accelerator
+ features = batch['input_features']
+ day_indicies = batch['day_indicies']
+ n_time_steps = batch['n_time_steps']
+
+ with torch.no_grad():
+ with self.autocast_context():
+ # Apply data transformations (no augmentation for inference)
+ features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+ # Calculate adjusted sequence lengths for CTC with proper dtype handling
+ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
+
+ # Ensure features tensor matches model parameter dtype for TPU compatibility
+ if features.dtype != self.model_dtype:
+ features = features.to(self.model_dtype)
+
+ # Get phoneme predictions
+ logits = self.model(features, day_indicies, None, False, mode)
+
+ return logits, adjusted_lens
\ No newline at end of file
diff --git a/model_training_nnn_tpu/start_tpu_training.sh b/model_training_nnn_tpu/start_tpu_training.sh
new file mode 100644
index 0000000..a09022a
--- /dev/null
+++ b/model_training_nnn_tpu/start_tpu_training.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# TPU XLA Multi-threading Environment Setup
+# Set these BEFORE starting Python to ensure they take effect
+
+echo "Setting up XLA multi-threading environment..."
+
+# Get CPU core count
+CPU_CORES=$(nproc)
+echo "Detected $CPU_CORES CPU cores"
+
+# Set XLA compilation flags
+export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
+export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES
+
+# Additional XLA optimizations
+export XLA_USE_BF16=1
+export TPU_CORES=8
+
+# Print current settings
+echo "XLA_FLAGS: $XLA_FLAGS"
+echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
+echo "XLA_USE_BF16: $XLA_USE_BF16"
+
+# Start training
+echo "Starting TPU training..."
+python train_model.py --config_path rnn_args.yaml
\ No newline at end of file
diff --git a/model_training_nnn_tpu/test_simple_model.py b/model_training_nnn_tpu/test_simple_model.py
new file mode 100644
index 0000000..4a18c5a
--- /dev/null
+++ b/model_training_nnn_tpu/test_simple_model.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+简化模型测试脚本 - 验证XLA编译是否正常工作
+"""
+
+import os
+import time
+import torch
+import torch.nn as nn
+
+# 设置XLA环境变量(必须在导入torch_xla之前)
+os.environ['XLA_FLAGS'] = (
+ '--xla_cpu_multi_thread_eigen=true '
+ '--xla_cpu_enable_fast_math=true '
+ f'--xla_force_host_platform_device_count={os.cpu_count()}'
+)
+os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
+os.environ['XLA_USE_BF16'] = '1'
+
+print(f"🔧 XLA环境变量设置:")
+print(f" CPU核心数: {os.cpu_count()}")
+print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
+print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
+
+import torch_xla.core.xla_model as xm
+
+class SimpleModel(nn.Module):
+ """简化的测试模型"""
+ def __init__(self):
+ super().__init__()
+ self.linear1 = nn.Linear(512, 256)
+ self.gru = nn.GRU(256, 128, batch_first=True)
+ self.linear2 = nn.Linear(128, 41) # 41个音素类别
+
+ def forward(self, x):
+ x = torch.relu(self.linear1(x))
+ x, _ = self.gru(x)
+ x = self.linear2(x)
+ return x
+
+def test_xla_compilation():
+ """测试XLA编译速度"""
+ print("\n🚀 开始简化模型XLA编译测试...")
+
+ # 检查TPU设备
+ device = xm.xla_device()
+ print(f"📱 TPU设备: {device}")
+ print(f"🌍 TPU World Size: {xm.xrt_world_size()}")
+
+ # 创建简化模型
+ model = SimpleModel().to(device)
+ print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
+
+ # 创建测试数据
+ batch_size = 8 # 小批次
+ seq_len = 100 # 短序列
+ x = torch.randn(batch_size, seq_len, 512, device=device)
+
+ print(f"📥 输入形状: {x.shape}")
+
+ # 首次前向传播 - 触发XLA编译
+ print(f"🔄 开始首次前向传播 (XLA编译)...")
+ start_time = time.time()
+
+ with torch.no_grad():
+ output = model(x)
+
+ compile_time = time.time() - start_time
+ print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒")
+ print(f"📤 输出形状: {output.shape}")
+
+ # 再次前向传播 - 使用编译后的图
+ print(f"🔄 第二次前向传播 (使用编译后的图)...")
+ start_time = time.time()
+
+ with torch.no_grad():
+ output2 = model(x)
+
+ execution_time = time.time() - start_time
+ print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒")
+
+ # 性能对比
+ speedup = compile_time / execution_time if execution_time > 0 else float('inf')
+ print(f"\n📈 性能分析:")
+ print(f" 编译时间: {compile_time:.2f}秒")
+ print(f" 执行时间: {execution_time:.4f}秒")
+ print(f" 加速比: {speedup:.1f}x")
+
+ if compile_time < 60: # 1分钟内编译完成
+ print("✅ XLA编译正常!")
+ return True
+ else:
+ print("❌ XLA编译过慢,可能有问题")
+ return False
+
+def test_training_step():
+ """测试训练步骤"""
+ print("\n🎯 测试简化训练步骤...")
+
+ device = xm.xla_device()
+ model = SimpleModel().to(device)
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ # 创建训练数据
+ x = torch.randn(4, 50, 512, device=device)
+ labels = torch.randint(0, 41, (4, 50), device=device)
+
+ print(f"🔄 开始训练步骤 (包含反向传播)...")
+ start_time = time.time()
+
+ # 前向传播
+ outputs = model(x)
+
+ # 计算损失
+ loss = criterion(outputs.view(-1, 41), labels.view(-1))
+
+ # 反向传播
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ step_time = time.time() - start_time
+ print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")
+
+ return step_time < 120 # 2分钟内完成
+
+def main():
+ print("=" * 60)
+ print("🧪 XLA编译快速测试")
+ print("=" * 60)
+
+ try:
+ # 测试1: 简单模型编译
+ compilation_ok = test_xla_compilation()
+
+ if compilation_ok:
+ # 测试2: 训练步骤
+ training_ok = test_training_step()
+
+ if training_ok:
+ print("\n✅ 所有测试通过! 可以尝试完整模型训练")
+ print("💡 建议:")
+ print(" 1. 确保有足够内存 (32GB+)")
+ print(" 2. 减小batch_size (比如从32改为16)")
+ print(" 3. 使用gradient_accumulation_steps补偿")
+ else:
+ print("\n⚠️ 训练步骤较慢,建议优化")
+ else:
+ print("\n❌ XLA编译有问题,需要检查环境")
+
+ except Exception as e:
+ print(f"\n💥 测试失败: {e}")
+ print("💡 可能的问题:")
+ print(" - TPU资源不可用")
+ print(" - PyTorch XLA安装问题")
+ print(" - 内存不足")
+
+ print("=" * 60)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/train_model.py b/model_training_nnn_tpu/train_model.py
new file mode 100644
index 0000000..81390c2
--- /dev/null
+++ b/model_training_nnn_tpu/train_model.py
@@ -0,0 +1,25 @@
+import argparse
+from omegaconf import OmegaConf
+from rnn_trainer import BrainToTextDecoder_Trainer
+
+def main():
+ parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
+ parser.add_argument('--config_path', default='rnn_args.yaml',
+ help='Path to configuration file (default: rnn_args.yaml)')
+
+ args = parser.parse_args()
+
+ # Load configuration
+ config = OmegaConf.load(args.config_path)
+
+ # Initialize trainer
+ trainer = BrainToTextDecoder_Trainer(config)
+
+ # Start training
+ trainer.train()
+
+ print("Training completed successfully!")
+ print(f"Best validation PER: {trainer.best_val_PER:.5f}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file