TPU
CLAUDE.md (114 lines changed)
@@ -335,5 +335,119 @@ tensor = tensor.to(original_dtype)
|
||||
states = states.to(input_tensor.dtype)
|
||||
```
|
||||
|
||||
## PyTorch XLA API Updates and Warnings
|
||||
|
||||
### Deprecated APIs (as of 2024)
|
||||
|
||||
**Important**: Several torch_xla APIs have been deprecated and should be updated in new code:
|
||||
|
||||
#### 1. Device API Changes
|
||||
```python
|
||||
# ❌ Deprecated (shows DeprecationWarning):
|
||||
device = xm.xla_device()
|
||||
|
||||
# ✅ Modern API:
|
||||
import torch_xla
|
||||
device = torch_xla.device()
|
||||
```
|
||||
|
||||
#### 2. Synchronization API Changes
|
||||
```python
|
||||
# ❌ Deprecated (shows DeprecationWarning):
|
||||
xm.mark_step()
|
||||
|
||||
# ✅ Modern API:
|
||||
import torch_xla
|
||||
torch_xla.sync()
|
||||
```
|
||||
|
||||
#### 3. Mixed Precision Environment Variables
|
||||
```python
|
||||
# ⚠️ Will be deprecated after PyTorch XLA 2.6:
|
||||
os.environ['XLA_USE_BF16'] = '1'
|
||||
|
||||
# 💡 Recommended: Convert model to bf16 directly in code
|
||||
model = model.to(torch.bfloat16)
|
||||
```
|
||||
|
||||
### TPU Performance Warnings
|
||||
|
||||
#### Transparent Hugepages Warning
|
||||
```
|
||||
UserWarning: Transparent hugepages are not enabled. TPU runtime startup and
|
||||
shutdown time should be significantly improved on TPU v5e and newer.
|
||||
```
|
||||
|
||||
**Solution** (for TPU v5e and newer):
|
||||
```bash
|
||||
sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled"
|
||||
```
|
||||
|
||||
**Note**: This warning appears in TPU environments and can be safely ignored if you don't have root access (e.g., Kaggle, Colab).
|
||||
|
||||
### Updated Code Patterns
|
||||
|
||||
#### Modern XLA Synchronization Pattern
|
||||
```python
|
||||
import torch_xla.core.xla_model as xm # Still needed for other functions
|
||||
import torch_xla
|
||||
|
||||
# Modern pattern:
|
||||
def train_step():
|
||||
# ... training code ...
|
||||
|
||||
# Synchronize every N steps
|
||||
if step % sync_frequency == 0:
|
||||
torch_xla.sync() # Instead of xm.mark_step()
|
||||
|
||||
# Legacy pattern (still works but deprecated):
|
||||
def train_step_legacy():
|
||||
# ... training code ...
|
||||
|
||||
# Old way (shows deprecation warning)
|
||||
if step % sync_frequency == 0:
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops() # This is still current
|
||||
```
|
||||
|
||||
#### Device Detection Pattern
|
||||
```python
|
||||
# Modern approach:
|
||||
import torch_xla
|
||||
|
||||
try:
|
||||
device = torch_xla.device()
|
||||
print(f"Using XLA device: {device}")
|
||||
except Exception:
|
||||
device = torch.device('cpu')
|
||||
print("Falling back to CPU")
|
||||
|
||||
# Legacy approach (shows warnings):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
try:
|
||||
device = xm.xla_device() # DeprecationWarning
|
||||
print(f"Using XLA device: {device}")
|
||||
except Exception:
|
||||
device = torch.device('cpu')
|
||||
```
|
||||
|
||||
### Migration Guidelines
|
||||
|
||||
When updating existing code:
|
||||
|
||||
1. **Replace `xm.xla_device()`** with `torch_xla.device()`
|
||||
2. **Replace `xm.mark_step()`** with `torch_xla.sync()`
|
||||
3. **Keep `xm.wait_device_ops()`** (still current API)
|
||||
4. **Update imports** to include `torch_xla` directly
|
||||
5. **Consider explicit bf16 conversion** instead of environment variables
|
||||
|
||||
### Backward Compatibility
|
||||
|
||||
The deprecated APIs still work but generate warnings. For production code:
|
||||
- Update to modern APIs to avoid warnings
|
||||
- Test thoroughly as synchronization behavior may differ slightly
|
||||
- Legacy code will continue to function until removed in future versions
|
||||
|
||||
## Competition Context
|
||||
This codebase also serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.
|
model_training_nnn_tpu/README_TensorFlow.md (new file, 288 lines)
@@ -0,0 +1,288 @@
|
||||
# TensorFlow Brain-to-Text Model for TPU v5e-8
|
||||
|
||||
This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides equivalent functionality to the PyTorch version but with TensorFlow operations designed for maximum TPU performance.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
The TensorFlow implementation maintains the same sophisticated three-model adversarial architecture:
|
||||
|
||||
### Core Models
|
||||
- **NoiseModel**: 2-layer GRU that estimates noise in neural data
|
||||
- **CleanSpeechModel**: 3-layer GRU that processes the denoised signal for speech recognition
- **NoisySpeechModel**: 2-layer GRU that processes the noise signal for adversarial training
|
||||
|
||||
### Key Features
|
||||
- **Day-specific transformations**: Learnable input layers for each recording session
|
||||
- **Patch processing**: Temporal patching for improved sequence modeling
|
||||
- **Gradient Reversal Layer**: For adversarial training between noise and speech models
|
||||
- **Mixed precision**: bfloat16 optimization for TPU v5e-8 memory efficiency
|
||||
- **CTC Loss**: Connectionist Temporal Classification for sequence alignment
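
To make the interaction between these pieces concrete, here is a minimal sketch (not the actual `rnn_model_tf.py` code) of how the residual denoising, the two speech branches, and the gradient reversal layer described above fit together; the names and call signatures are illustrative only.

```python
def adversarial_forward(x, noise_model, clean_model, noisy_model, grl, training=True):
    noise = noise_model(x, training=training)                 # estimated noise
    clean_logits = clean_model(x - noise, training=training)  # denoised (residual) branch
    if not training:
        return clean_logits                                   # inference uses only this path
    # Gradient reversal pushes NoiseModel toward noise that carries no speech content.
    noisy_logits = noisy_model(grl(noise), training=training)
    return clean_logits, noisy_logits, noise
```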
|
||||
|
||||
## Files Overview
|
||||
|
||||
### Core Implementation
|
||||
- `rnn_model_tf.py`: TensorFlow model architecture with TPU optimizations
|
||||
- `trainer_tf.py`: Training pipeline with distributed TPU strategy
|
||||
- `dataset_tf.py`: Data loading and augmentation optimized for TPU
|
||||
- `train_model_tf.py`: Main training script
|
||||
- `evaluate_model_tf.py`: Evaluation and inference script
|
||||
|
||||
### Configuration and Setup
|
||||
- `rnn_args.yaml`: Training configuration (shared with PyTorch version)
|
||||
- `setup_tensorflow_tpu.sh`: Environment setup script
|
||||
- `requirements_tf.txt`: Python dependencies
|
||||
- `README_TensorFlow.md`: This documentation
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Environment Setup
|
||||
```bash
|
||||
# Run the setup script to configure TPU environment
|
||||
./setup_tensorflow_tpu.sh
|
||||
|
||||
# Activate the conda environment
|
||||
conda activate b2txt_tf
|
||||
```
|
||||
|
||||
### 2. Verify TPU Access
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
# Check TPU availability
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
strategy = tf.distribute.TPUStrategy(resolver)
|
||||
print(f"TPU cores available: {strategy.num_replicas_in_sync}")
|
||||
```
|
||||
|
||||
### 3. Start Training
|
||||
```bash
|
||||
# Basic training with default config
|
||||
python train_model_tf.py --config_path rnn_args.yaml
|
||||
|
||||
# Training with custom settings
|
||||
python train_model_tf.py \
|
||||
--config_path rnn_args.yaml \
|
||||
--batch_size 64 \
|
||||
--num_batches 50000 \
|
||||
--output_dir ./trained_models/custom_run
|
||||
```
|
||||
|
||||
### 4. Run Evaluation
|
||||
```bash
|
||||
# Evaluate trained model
|
||||
python evaluate_model_tf.py \
|
||||
--model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
|
||||
--config_path rnn_args.yaml \
|
||||
--eval_type test
|
||||
```
|
||||
|
||||
## TPU v5e-8 Optimizations
|
||||
|
||||
### Hardware-Specific Features
|
||||
- **Mixed Precision**: Automatic bfloat16 conversion for 2x memory efficiency
|
||||
- **XLA Compilation**: Just-in-time compilation for optimal TPU performance
|
||||
- **Distributed Strategy**: Automatic sharding across 8 TPU cores
|
||||
- **Memory Management**: Efficient tensor operations to avoid OOM errors
|
||||
|
||||
### Performance Optimizations
|
||||
- **Batch Matrix Operations**: `tf.linalg.matmul` instead of element-wise operations
|
||||
- **Static Shapes**: Avoiding dynamic tensor shapes for better compilation
|
||||
- **Efficient Gathering**: `tf.gather` for day-specific parameter selection
|
||||
- **Gradient Reversal**: Custom gradient function for adversarial training
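
The last two items can be illustrated with short sketches: gradient reversal is typically written with `tf.custom_gradient`, and day-specific parameters can be selected with `tf.gather`. The shapes and names below are assumptions for illustration, not the exact code in `rnn_model_tf.py`.

```python
import tensorflow as tf

def make_grl(lam: float):
    """Gradient reversal: identity in the forward pass, -lam * gradient in the backward pass."""
    @tf.custom_gradient
    def grl(x):
        def grad(dy):
            return -lam * dy
        return tf.identity(x), grad
    return grl

# Day-specific parameter selection (assumed shapes):
# all_day_weights: [n_days, d, d], all_day_biases: [n_days, 1, d], day_indices: [batch]
# w = tf.gather(all_day_weights, day_indices)   # [batch, d, d]
# b = tf.gather(all_day_biases, day_indices)    # [batch, 1, d]
# x = tf.linalg.matmul(x, w) + b
```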
|
||||
|
||||
## Configuration
|
||||
|
||||
The model uses the same `rnn_args.yaml` configuration as the PyTorch version. Key TPU-specific settings:
|
||||
|
||||
```yaml
|
||||
# TPU-specific settings
|
||||
use_amp: true # Enable mixed precision (bfloat16)
|
||||
dataset:
|
||||
batch_size: 32 # Optimized for TPU memory
|
||||
num_dataloader_workers: 0 # Disable multiprocessing on TPU
|
||||
|
||||
# Model architecture (same as PyTorch)
|
||||
model:
|
||||
n_input_features: 512 # Neural features per timestep
|
||||
n_units: 768 # Hidden units per GRU layer
|
||||
patch_size: 14 # Temporal patch size
|
||||
patch_stride: 4 # Patch stride
|
||||
```
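
How `use_amp: true` is consumed is an implementation detail of the trainer, but on TPU it typically maps to the Keras mixed-precision policy. A minimal sketch, assuming a `config` dict loaded from the YAML above:

```python
import tensorflow as tf

if config.get('use_amp', True):
    # bfloat16 compute with float32 variables -- the usual TPU setup
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
```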
|
||||
|
||||
## Performance Comparison
|
||||
|
||||
### TPU v5e-8 vs Other Hardware
|
||||
- **Memory**: 2x improvement with bfloat16 mixed precision
|
||||
- **Throughput**: ~3-4x faster training than V100 GPU
|
||||
- **Scalability**: Automatic distribution across 8 cores
|
||||
- **Cost Efficiency**: Better performance-per-dollar for large models
|
||||
|
||||
### Expected Training Times (120k batches)
|
||||
- **TPU v5e-8**: ~4-6 hours
|
||||
- **Single V100**: ~15-20 hours
|
||||
- **RTX 4090**: ~12-18 hours
|
||||
|
||||
## Model Architecture Details
|
||||
|
||||
### TripleGRUDecoder Forward Pass
|
||||
```python
|
||||
# Training mode (adversarial)
|
||||
clean_logits, noisy_logits, noise_output = model(
|
||||
features, day_indices, mode='full',
|
||||
grl_lambda=0.5, training=True
|
||||
)
|
||||
|
||||
# Inference mode (production)
|
||||
clean_logits = model(
|
||||
features, day_indices, mode='inference',
|
||||
training=False
|
||||
)
|
||||
```
|
||||
|
||||
### Loss Functions
|
||||
```python
|
||||
# Clean speech CTC loss
|
||||
clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)
|
||||
|
||||
# Adversarial noisy speech loss (with gradient reversal)
|
||||
noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)
|
||||
|
||||
# Combined loss
|
||||
total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
|
||||
```
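
The `ctc_loss` helper used above is not spelled out here; a plausible implementation wraps `tf.nn.ctc_loss` with batch-major logits and dense labels. The argument order below simply mirrors the calls above and is an assumption:

```python
import tensorflow as tf

def ctc_loss(logits, labels, input_lengths, label_lengths, blank_index=0):
    # logits: [batch, time, n_classes]; labels: dense int tensor [batch, max_label_len]
    per_example = tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=input_lengths,
        logits_time_major=False,
        blank_index=blank_index,
    )
    return tf.reduce_mean(per_example)
```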
|
||||
|
||||
## Data Pipeline
|
||||
|
||||
### HDF5 Data Loading
|
||||
The TensorFlow implementation efficiently loads data from HDF5 files:
|
||||
- **Batch creation**: Pre-batched data with padding
|
||||
- **Feature subsets**: Configurable neural feature selection
|
||||
- **Day balancing**: Ensures even representation across recording sessions
|
||||
- **Memory efficiency**: Lazy loading with tf.data.Dataset
|
||||
|
||||
### Data Augmentations
|
||||
- **Gaussian smoothing**: Temporal smoothing of neural signals
|
||||
- **White noise**: Additive Gaussian noise for robustness
|
||||
- **Static gain**: Channel-wise multiplicative noise
|
||||
- **Random walk**: Temporal drift simulation
|
||||
- **Random cutoff**: Variable sequence lengths
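
As a rough sketch of how two of these augmentations (static gain and white noise) can be expressed with TensorFlow ops, mirroring the approach taken in `dataset_tf.py` (parameter values are illustrative):

```python
import tensorflow as tf

def augment(features, white_noise_std=1.0, static_gain_std=0.05):
    batch = tf.shape(features)[0]
    channels = tf.shape(features)[-1]
    # Static gain: per-sample perturbation of an identity channel-mixing matrix.
    warp = tf.eye(channels, batch_shape=[batch]) + \
           tf.random.normal([batch, channels, channels]) * static_gain_std
    features = tf.linalg.matmul(features, warp)
    # Additive white noise.
    return features + tf.random.normal(tf.shape(features)) * white_noise_std
```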
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common TPU Issues
|
||||
|
||||
#### "Resource exhausted" errors
|
||||
```bash
|
||||
# Reduce batch size
|
||||
python train_model_tf.py --batch_size 16
|
||||
|
||||
# Enable gradient accumulation
|
||||
# Modify config: gradient_accumulation_steps: 4
|
||||
```
|
||||
|
||||
#### TPU not detected
|
||||
```bash
|
||||
# Check environment variables
|
||||
echo $TPU_NAME
|
||||
echo $COLAB_TPU_ADDR
|
||||
|
||||
# Verify TPU access
|
||||
gcloud compute tpus list
|
||||
```
|
||||
|
||||
#### Mixed precision issues
|
||||
```bash
|
||||
# Disable mixed precision if needed
|
||||
python train_model_tf.py --disable_mixed_precision
|
||||
```
|
||||
|
||||
### Performance Debugging
|
||||
```python
|
||||
# Enable XLA logging
|
||||
import os
|
||||
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
|
||||
|
||||
# Profile TPU usage
|
||||
tf.profiler.experimental.start('logdir')
|
||||
# ... training code ...
|
||||
tf.profiler.experimental.stop()
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Training Loop
|
||||
```python
|
||||
from trainer_tf import BrainToTextDecoderTrainerTF
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BrainToTextDecoderTrainerTF(config)
|
||||
|
||||
# Custom training with checkpointing
|
||||
for epoch in range(num_epochs):
|
||||
stats = trainer.train()
|
||||
if epoch % 5 == 0:
|
||||
trainer._save_checkpoint(f'epoch_{epoch}', epoch)
|
||||
```
|
||||
|
||||
### Model Inference
|
||||
```python
|
||||
# Load trained model
|
||||
model = trainer.model
|
||||
model.load_weights('path/to/checkpoint.weights.h5')
|
||||
|
||||
# Run inference
|
||||
logits = trainer.inference(features, day_indices, n_time_steps)
|
||||
|
||||
# Decode predictions
|
||||
predictions = tf.argmax(logits, axis=-1)
|
||||
```
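
Note that a plain `tf.argmax` does not collapse repeated symbols or remove CTC blanks. For proper greedy CTC decoding, something like the following sketch can be used instead (it assumes time-major conversion and that the blank index matches the one used in training):

```python
import tensorflow as tf

# logits: [batch, time, n_classes]; n_time_steps: [batch] valid lengths
decoded, _ = tf.nn.ctc_greedy_decoder(
    tf.transpose(logits, [1, 0, 2]),                 # time-major input
    sequence_length=tf.cast(n_time_steps, tf.int32),
    blank_index=0,                                   # assumed blank id; adjust to match training
)
phoneme_ids = tf.sparse.to_dense(decoded[0])         # [batch, max_decoded_len]
```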
|
||||
|
||||
### Hyperparameter Tuning
|
||||
```python
|
||||
# Grid search over learning rates
|
||||
learning_rates = [0.001, 0.005, 0.01]
|
||||
for lr in learning_rates:
|
||||
config.lr_max = lr
|
||||
trainer = BrainToTextDecoderTrainerTF(config)
|
||||
stats = trainer.train()
|
||||
```
|
||||
|
||||
## Research and Development
|
||||
|
||||
This TensorFlow implementation maintains full compatibility with the published research while providing:
|
||||
|
||||
1. **Reproducible Results**: Same model architecture and training procedures
|
||||
2. **Hardware Optimization**: TPU-specific performance improvements
|
||||
3. **Scalability**: Easy scaling to larger models and datasets
|
||||
4. **Extensibility**: Clean APIs for research modifications
|
||||
|
||||
### Key Research Features
|
||||
- **Adversarial Training**: Domain adaptation through gradient reversal
|
||||
- **Multi-day Learning**: Session-specific input transformations
|
||||
- **Temporal Modeling**: Patch-based sequence processing
|
||||
- **Robust Training**: Comprehensive data augmentation pipeline
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this TensorFlow implementation in your research, please cite the original paper:
|
||||
|
||||
```bibtex
|
||||
@article{card2024accurate,
|
||||
title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
|
||||
author={Card, Nicholas S and others},
|
||||
journal={New England Journal of Medicine},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For questions specific to the TensorFlow implementation:
|
||||
1. Check this README and the PyTorch documentation in `../CLAUDE.md`
|
||||
2. Review configuration options in `rnn_args.yaml`
|
||||
3. Examine example scripts in this directory
|
||||
4. Open issues on the project repository
|
||||
|
||||
For TPU-specific questions, consult Google Cloud TPU documentation and TensorFlow TPU guides.
|
@@ -1,183 +0,0 @@
|
||||
# TPU-Optimized Brain-to-Text Model Code Summary

## Project Overview

This directory contains the Brain-to-Text RNN model code optimized specifically for TPU training, based on the paper "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" published in the New England Journal of Medicine (2024). The model converts neural signals from the speech motor cortex into text, using an RNN model together with an n-gram language model.

## Core Architecture Improvements

### Three-Model Adversarial Training Architecture (TripleGRUDecoder)

Replacing the original single-GRU model, the new architecture contains three sub-models that work together:

1. **NoiseModel** (2-layer GRU)
   - Estimates the noise in the neural data
   - Input dimension: 512 → output dimension: same as the input
   - Role: separates the noise component from the raw signal

2. **CleanSpeechModel** (3-layer GRU + classification layer)
   - Processes the denoised signal for speech recognition
   - Contains day-specific input layers
   - Output: logits over 41 phoneme classes

3. **NoisySpeechModel** (2-layer GRU + classification layer)
   - Processes the noise signal directly for speech recognition
   - Used for adversarial training to improve the robustness of the NoiseModel
   - Output: logits over 41 phoneme classes

### Adversarial Training Mechanism

- **Residual connection**: `denoised_input = x_processed - noise_output`
- **Gradient Reversal Layer (GRL)**: applies gradient reversal to the noise output during training
- **Multi-objective loss**: combines the CTC losses of the clean and noisy branches

## TPU/XLA Optimization Features

### 1. XLA-Friendly Operation Design

**Static tensor operations instead of dynamic ones**:
```python
# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)

# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)
```

**XLA primitive operations**:
```python
# Use batch matrix multiplication (bmm) instead of einsum
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
```

### 2. Dtype Consistency for Mixed Precision Training

**Comprehensive dtype-consistency handling**:
- dtype conversion in basic operations
- dtype preservation during patch processing
- dtype matching in the adversarial-training residual connection
- dtype handling in the gradient reversal layer
- dtype consistency when initializing hidden states

### 3. Memory and Compilation Optimizations

- **Autocast disabled**: automatic mixed precision is disabled inside GRU operations to avoid dtype conflicts
- **Static shapes**: avoids dynamic batch-size allocation
- **Tuple returns**: tuples instead of dictionaries for better XLA compilation performance

## Key File Structure

### Core Training Files

- **`rnn_model.py`**: Complete implementation of the TripleGRUDecoder and its three sub-models, with XLA optimizations
- **`rnn_trainer.py`**: TPU trainer integrated with the Accelerate library, supporting distributed training
- **`train_model.py`**: Minimal training launch script
- **`rnn_args.yaml`**: TPU training configuration file

### TPU-Specific Files

- **`accelerate_config_tpu.yaml`**: TPU configuration for the Accelerate library
- **`launch_tpu_training.py`**: Convenience launch script for TPU training
- **`TPU_SETUP_GUIDE.md`**: TPU environment setup guide

### Supporting Files

- **`dataset.py`**: Dataset loading and batching
- **`data_augmentations.py`**: Data augmentation utilities
- **`evaluate_model_helpers.py`**: Evaluation utility functions
|
||||
|
||||
## Training Configuration Highlights

### TPU-Specific Settings
```yaml
# TPU distributed training settings
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true  # bfloat16 mixed precision

# Optimized batch configuration
batch_size: 32  # Batch size per TPU core
num_dataloader_workers: 0  # Set to 0 on TPU to avoid multiprocessing issues
```

### Adversarial Training Configuration
```yaml
adversarial:
  enabled: true
  grl_lambda: 0.5         # Gradient reversal strength
  noisy_loss_weight: 0.2  # Noisy-branch loss weight
  noise_l2_weight: 0.0    # L2 regularization on the noise output
  warmup_steps: 0         # Adversarial-training warmup steps
```

## Model Scale

- **Total parameters**: ~687M
- **Neural input**: 512 features (2 features per electrode × 256 electrodes)
- **GRU hidden units**: 768 per layer
- **Output classes**: 41 phonemes
- **Patch processing**: patches of 14 time steps with a stride of 4

## Data Flow

1. **Input**: 512-dimensional neural features at 20 ms resolution
2. **Day-specific transformation**: per-session linear transformation plus softsign activation
3. **Patch processing**: concatenates 14 time steps into a larger input vector
4. **Three-model processing**:
   - NoiseModel estimates the noise
   - CleanSpeechModel processes the denoised signal
   - NoisySpeechModel processes the noise signal (training only)
5. **Output**: CTC-compatible phoneme logits
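
A minimal PyTorch sketch of steps 2–3 above (the day-specific transform followed by patching); the tensor shapes and helper name are assumptions for illustration, not the exact code in `rnn_model.py`:

```python
import torch
import torch.nn.functional as F

def day_transform_and_patch(x, day_w, day_b, patch_size=14, patch_stride=4):
    # x: [batch, time, 512]; day_w: [batch, 512, 512]; day_b: [batch, 1, 512]
    x = F.softsign(torch.bmm(x, day_w) + day_b)
    # Concatenate each window of patch_size steps into one feature vector.
    x = x.unfold(1, patch_size, patch_stride)      # [batch, n_patches, 512, patch_size]
    b, p = x.shape[0], x.shape[1]
    x = x.permute(0, 1, 3, 2).reshape(b, p, -1)    # [batch, n_patches, 512 * patch_size]
    return x
```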
|
||||
|
||||
## Training Flow

### Inference mode (`mode='inference'`):
- Uses only the NoiseModel + CleanSpeechModel
- Computation: `clean_logits = CleanSpeechModel(x - NoiseModel(x))`

### Full mode (`mode='full'`, used during training):
- Uses all three models
- Adversarial training with gradient reversal
- Multi-objective CTC loss

## Performance Characteristics

- **Compilation optimization**: XLA-friendly code for faster TPU compilation
- **Memory efficiency**: bfloat16 mixed precision reduces memory usage
- **Distributed training**: supports 8-core TPU parallel training
- **Data augmentation**: Gaussian smoothing, white noise, temporal jitter, and more

## Usage

### Basic Training
```bash
python train_model.py --config_path rnn_args.yaml
```

### Using the Launch Script
```bash
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
```

### Using Accelerate
```bash
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
```

## Compatibility with the Original Model

- Keeps the same mathematical operations and model architecture
- Preserves all original interfaces
- Supports both 'inference' and 'full' modes
- Backward compatible with existing training scripts

## Technical Innovations

1. **Three-model adversarial architecture**: a novel approach to noise estimation and denoising
2. **XLA optimization**: comprehensive TPU compilation optimizations
3. **Mixed-precision consistency**: resolves dtype conflicts in complex adversarial training
4. **Distributed training integration**: seamless multi-core TPU support

This TPU-optimized version preserves the accuracy of the original model while significantly improving training efficiency and scalability, making it especially well suited to large-scale neural decoding training runs.
|
@@ -1,204 +0,0 @@
|
||||
# TPU Training Setup Guide for Brain-to-Text RNN
|
||||
|
||||
This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Install PyTorch XLA for TPU Support
|
||||
```bash
|
||||
# Install PyTorch XLA (adjust version as needed)
|
||||
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
|
||||
# Or for specific PyTorch version:
|
||||
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
```
|
||||
|
||||
### 2. Install Accelerate Library
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
### 3. Verify TPU Access
|
||||
```bash
|
||||
# Check if TPU is available
|
||||
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
|
||||
```
|
||||
|
||||
## Configuration Setup
|
||||
|
||||
### 1. Enable TPU in Configuration File
|
||||
|
||||
Update your `rnn_args.yaml` file with TPU settings:
|
||||
|
||||
```yaml
|
||||
# TPU and distributed training settings
|
||||
use_tpu: true # Enable TPU training
|
||||
num_tpu_cores: 8 # Number of TPU cores (8 for v3-8 or v4-8)
|
||||
gradient_accumulation_steps: 1 # Gradient accumulation for large effective batch size
|
||||
dataloader_num_workers: 0 # Must be 0 for TPU to avoid multiprocessing issues
|
||||
use_amp: true # Enable mixed precision (bfloat16)
|
||||
|
||||
# Adjust batch size for multi-core TPU
|
||||
dataset:
|
||||
batch_size: 8 # Per-core batch size (total = 8 cores × 8 = 64)
|
||||
```
|
||||
|
||||
### 2. TPU-Optimized Hyperparameters
|
||||
|
||||
Recommended adjustments for TPU training:
|
||||
|
||||
```yaml
|
||||
# Learning rate scaling for distributed training
|
||||
lr_max: 0.005 # May need to scale with number of cores
|
||||
lr_max_day: 0.005
|
||||
|
||||
# Batch size considerations
|
||||
dataset:
|
||||
batch_size: 8 # Per-core batch size
|
||||
days_per_batch: 4 # Keep consistent across cores
|
||||
```
|
||||
|
||||
## Training Launch Options
|
||||
|
||||
### Method 1: Using the TPU Launch Script (Recommended)
|
||||
|
||||
```bash
|
||||
# Basic TPU training with 8 cores
|
||||
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||
|
||||
# Check TPU environment only
|
||||
python launch_tpu_training.py --check_only
|
||||
|
||||
# Custom configuration file
|
||||
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
|
||||
```
|
||||
|
||||
### Method 2: Direct Accelerate Launch
|
||||
|
||||
```bash
|
||||
# Configure accelerate (one-time setup)
|
||||
accelerate config
|
||||
|
||||
# Or use provided TPU config
|
||||
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
|
||||
|
||||
# Launch training
|
||||
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
|
||||
```
|
||||
|
||||
### Method 3: Manual XLA Launch (Advanced)
|
||||
|
||||
```bash
|
||||
# Set TPU environment variables
|
||||
export TPU_CORES=8
|
||||
export XLA_USE_BF16=1
|
||||
|
||||
# Launch with PyTorch XLA
|
||||
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
|
||||
```
|
||||
|
||||
## Key TPU Features Implemented
|
||||
|
||||
### 1. Distributed Training Support
|
||||
- Automatic model parallelization across 8 TPU cores
|
||||
- Synchronized gradient updates across all cores
|
||||
- Proper checkpoint saving/loading for distributed training
|
||||
|
||||
### 2. Mixed Precision Training
|
||||
- Automatic bfloat16 precision for TPU optimization
|
||||
- Faster training with maintained numerical stability
|
||||
- Reduced memory usage
|
||||
|
||||
### 3. TPU-Optimized Data Loading
|
||||
- Single-threaded data loading (num_workers=0) for TPU compatibility
|
||||
- Automatic data distribution across TPU cores
|
||||
- Efficient batch processing
|
||||
|
||||
### 4. Inference Support
|
||||
- TPU-compatible inference methods added to trainer class
|
||||
- `inference()` and `inference_batch()` methods for production use
|
||||
- Automatic mixed precision during inference
|
||||
|
||||
## Performance Optimization Tips
|
||||
|
||||
### 1. Batch Size Tuning
|
||||
- Start with total batch size = 64 (8 cores × 8 per core)
|
||||
- Increase gradually if memory allows
|
||||
- Monitor device behavior with the XLA metrics report (`torch_xla.debug.metrics.metrics_report()`) rather than CPU-oriented tools like `top`
|
||||
|
||||
### 2. Gradient Accumulation
|
||||
- Use `gradient_accumulation_steps` to simulate larger batch sizes
|
||||
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
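
For example, a quick sanity check of the formula above (values illustrative):

```python
per_core_batch_size = 8
num_cores = 8
gradient_accumulation_steps = 2
effective_batch_size = per_core_batch_size * num_cores * gradient_accumulation_steps  # 128
```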
|
||||
|
||||
### 3. Learning Rate Scaling
|
||||
- Consider scaling learning rate with number of cores
|
||||
- Linear scaling: `lr_new = lr_base × num_cores`
|
||||
- May need warmup adjustment for large batch training
|
||||
|
||||
### 4. Memory Management
|
||||
- TPU v3-8: 128GB HBM memory total
|
||||
- TPU v4-8: 512GB HBM memory total
|
||||
- Monitor memory usage to avoid OOM errors
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
### 1. TPU Utilization
|
||||
```bash
|
||||
# Monitor TPU usage
|
||||
# xr.world_size() replaces the deprecated xm.xrt_world_size()
watch -n 1 "python -c 'import torch_xla.runtime as xr; print(xr.world_size())'"
|
||||
```
|
||||
|
||||
### 2. Training Logs
|
||||
- Training logs include device information and core count
|
||||
- Monitor validation metrics across all cores
|
||||
- Check for synchronization issues in distributed training
|
||||
|
||||
### 3. Common Issues and Solutions
|
||||
|
||||
**Issue**: "No TPU devices found"
|
||||
- **Solution**: Verify TPU runtime is started and accessible
|
||||
|
||||
**Issue**: "DataLoader workers > 0 causes hangs"
|
||||
- **Solution**: Set `dataloader_num_workers: 0` in config
|
||||
|
||||
**Issue**: "Mixed precision errors"
|
||||
- **Solution**: Ensure `use_amp: true` and PyTorch XLA supports bfloat16
|
||||
|
||||
**Issue**: "Gradient synchronization timeouts"
|
||||
- **Solution**: Check network connectivity between TPU cores
|
||||
|
||||
## Example Training Command
|
||||
|
||||
```bash
|
||||
# Complete TPU training example
|
||||
cd model_training_nnn
|
||||
|
||||
# 1. Update config for TPU
|
||||
vim rnn_args.yaml # Set use_tpu: true, num_tpu_cores: 8
|
||||
|
||||
# 2. Launch TPU training
|
||||
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
|
||||
|
||||
# 3. Monitor training progress
|
||||
tail -f trained_models/baseline_rnn/training_log
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Required TPU Settings
|
||||
```yaml
|
||||
use_tpu: true
|
||||
num_tpu_cores: 8
|
||||
dataloader_num_workers: 0
|
||||
use_amp: true
|
||||
```
|
||||
|
||||
### Optional TPU Optimizations
|
||||
```yaml
|
||||
gradient_accumulation_steps: 1
|
||||
dataset:
|
||||
batch_size: 8 # Per-core batch size
|
||||
mixed_precision: bf16
|
||||
```
|
||||
|
||||
This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.
|
@@ -1,26 +0,0 @@
|
||||
# Accelerate Configuration for TPU Training
|
||||
# This file configures Accelerate library for 8-core TPU training
|
||||
# with mixed precision (bfloat16) support
|
||||
|
||||
compute_environment: TPU
|
||||
distributed_type: TPU
|
||||
tpu_name: null # Will use default TPU
|
||||
tpu_zone: null # Will use default zone
|
||||
|
||||
# Mixed precision settings (use bfloat16 for TPU)
|
||||
mixed_precision: bf16
|
||||
|
||||
# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
|
||||
num_processes: 8
|
||||
|
||||
# Enable TPU debugging (set to false for production)
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
|
||||
# Logging settings
|
||||
main_process_port: null
|
||||
machine_rank: 0
|
||||
num_machines: 1
|
||||
|
||||
# Enable automatic optimization
|
||||
use_cpu: false
|
model_training_nnn_tpu/amp_tpu_training.py (new file, 315 lines)
@@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
使用AMP的TPU训练脚本
|
||||
正确处理混合精度训练,避免dtype不匹配问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# Set AMP-related environment variables
|
||||
os.environ['XLA_FLAGS'] = (
|
||||
'--xla_cpu_multi_thread_eigen=true '
|
||||
'--xla_cpu_enable_fast_math=true'
|
||||
)
|
||||
os.environ['XLA_USE_BF16'] = '1'  # Enable bf16 (slated for deprecation after PyTorch XLA 2.6)
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
import torch_xla.amp as xla_amp
|
||||
|
||||
|
||||
class AMPModel(nn.Module):
|
||||
"""支持AMP的简单模型"""
|
||||
|
||||
def __init__(self, input_size=784, hidden_size=512, num_classes=10):
|
||||
super(AMPModel, self).__init__()
|
||||
|
||||
self.network = nn.Sequential(
|
||||
nn.Linear(input_size, hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_size, hidden_size // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_size // 2, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
        # Flatten the input
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.network(x)
|
||||
|
||||
|
||||
class AMPTrainer:
|
||||
"""AMP训练器"""
|
||||
|
||||
def __init__(self, model, device, learning_rate=0.001):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
self.criterion = nn.CrossEntropyLoss()
|
||||
|
||||
        # Initialize the AMP gradient scaler
|
||||
self.scaler = xla_amp.GradScaler()
|
||||
|
||||
print(f"✅ AMP训练器初始化完成")
|
||||
print(f" 设备: {device}")
|
||||
print(f" 模型参数: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
def train_step(self, data, target):
|
||||
"""单个AMP训练步骤"""
|
||||
self.model.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
        # Mixed-precision forward pass with autocast
|
||||
with xla_amp.autocast():
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
        # Backward pass through the gradient scaler
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
        # Gradient clipping (optional)
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||
|
||||
        # Update parameters
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
        # Compute accuracy
|
||||
pred = output.argmax(dim=1)
|
||||
correct = pred.eq(target).sum().item()
|
||||
accuracy = correct / target.size(0)
|
||||
|
||||
return loss.item(), accuracy
|
||||
|
||||
def evaluate_step(self, data, target):
|
||||
"""单个评估步骤"""
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
with xla_amp.autocast():
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
pred = output.argmax(dim=1)
|
||||
correct = pred.eq(target).sum().item()
|
||||
accuracy = correct / target.size(0)
|
||||
|
||||
return loss.item(), accuracy
|
||||
|
||||
|
||||
def get_mnist_loaders(batch_size=64):
|
||||
"""获取MNIST数据加载器"""
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
test_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
def train_with_amp():
|
||||
"""使用AMP进行训练"""
|
||||
print("🚀 开始AMP TPU训练...")
|
||||
|
||||
    # Get the XLA device
|
||||
device = xm.xla_device()
|
||||
print(f"📱 设备: {device}")
|
||||
|
||||
    # Create the model
|
||||
model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)
|
||||
|
||||
    # Create the trainer
|
||||
trainer = AMPTrainer(model, device, learning_rate=0.001)
|
||||
|
||||
    # Load the data
|
||||
print("📥 加载MNIST数据...")
|
||||
train_loader, test_loader = get_mnist_loaders(batch_size=64)
|
||||
|
||||
    # Use the XLA parallel loaders
|
||||
train_device_loader = pl.MpDeviceLoader(train_loader, device)
|
||||
test_device_loader = pl.MpDeviceLoader(test_loader, device)
|
||||
|
||||
print("🎯 开始AMP训练...")
|
||||
|
||||
    # Training loop
|
||||
num_epochs = 2
|
||||
train_losses = []
|
||||
train_accuracies = []
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
|
||||
|
||||
epoch_start = time.time()
|
||||
epoch_loss = 0.0
|
||||
epoch_acc = 0.0
|
||||
num_batches = 0
|
||||
        max_batches_per_epoch = 200  # Limit the number of batches per epoch
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_device_loader):
|
||||
if batch_idx >= max_batches_per_epoch:
|
||||
break
|
||||
|
||||
            # Training step
|
||||
loss, accuracy = trainer.train_step(data, target)
|
||||
|
||||
epoch_loss += loss
|
||||
epoch_acc += accuracy
|
||||
num_batches += 1
|
||||
|
||||
            # Synchronize every 20 batches
|
||||
if batch_idx % 20 == 0:
|
||||
xm.mark_step()
|
||||
|
||||
avg_loss = epoch_loss / num_batches
|
||||
avg_acc = epoch_acc / num_batches * 100
|
||||
|
||||
print(f" 批次 {batch_idx:3d}/{max_batches_per_epoch} | "
|
||||
f"损失: {avg_loss:.4f} | "
|
||||
f"准确率: {avg_acc:.2f}%")
|
||||
|
||||
        # Synchronize at the end of the epoch
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
epoch_time = time.time() - epoch_start
|
||||
final_loss = epoch_loss / num_batches
|
||||
final_acc = epoch_acc / num_batches * 100
|
||||
|
||||
train_losses.append(final_loss)
|
||||
train_accuracies.append(final_acc)
|
||||
|
||||
print(f"✅ Epoch {epoch + 1} 完成 | "
|
||||
f"耗时: {epoch_time:.2f}s | "
|
||||
f"平均损失: {final_loss:.4f} | "
|
||||
f"平均准确率: {final_acc:.2f}%")
|
||||
|
||||
return trainer, train_losses, train_accuracies
|
||||
|
||||
|
||||
def test_with_amp(trainer):
|
||||
"""使用AMP进行测试"""
|
||||
print("\n🧪 开始AMP测试...")
|
||||
|
||||
device = xm.xla_device()
|
||||
_, test_loader = get_mnist_loaders(batch_size=64)
|
||||
test_device_loader = pl.MpDeviceLoader(test_loader, device)
|
||||
|
||||
total_loss = 0.0
|
||||
total_acc = 0.0
|
||||
num_batches = 0
|
||||
max_test_batches = 100
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for batch_idx, (data, target) in enumerate(test_device_loader):
|
||||
if batch_idx >= max_test_batches:
|
||||
break
|
||||
|
||||
loss, accuracy = trainer.evaluate_step(data, target)
|
||||
|
||||
total_loss += loss
|
||||
total_acc += accuracy
|
||||
num_batches += 1
|
||||
|
||||
if batch_idx % 20 == 0:
|
||||
xm.mark_step()
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
test_time = time.time() - start_time
|
||||
avg_loss = total_loss / num_batches
|
||||
avg_acc = total_acc / num_batches * 100
|
||||
|
||||
print(f"✅ 测试完成!")
|
||||
print(f"⏱️ 测试时间: {test_time:.2f}秒")
|
||||
print(f"🎯 测试损失: {avg_loss:.4f}")
|
||||
print(f"🎯 测试准确率: {avg_acc:.2f}%")
|
||||
|
||||
return avg_loss, avg_acc
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("=" * 60)
|
||||
print("⚡ AMP TPU训练示例")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
        # Train
|
||||
trainer, train_losses, train_accuracies = train_with_amp()
|
||||
|
||||
        # Evaluate
|
||||
test_loss, test_acc = test_with_amp(trainer)
|
||||
|
||||
        # Save the model
|
||||
print("\n💾 保存模型...")
|
||||
model_cpu = trainer.model.cpu()
|
||||
torch.save({
|
||||
'model_state_dict': model_cpu.state_dict(),
|
||||
'train_losses': train_losses,
|
||||
'train_accuracies': train_accuracies,
|
||||
'test_loss': test_loss,
|
||||
'test_accuracy': test_acc
|
||||
}, 'amp_mnist_model.pth')
|
||||
print("✅ 模型已保存到 amp_mnist_model.pth")
|
||||
|
||||
print("\n🎉 AMP训练完成!")
|
||||
print(f"📊 最终训练准确率: {train_accuracies[-1]:.2f}%")
|
||||
print(f"📊 测试准确率: {test_acc:.2f}%")
|
||||
|
||||
if train_accuracies[-1] > 85 and test_acc > 80:
|
||||
print("✅ AMP训练成功! 模型性能优秀")
|
||||
else:
|
||||
print("⚠️ 模型性能一般,但AMP功能正常")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ AMP训练失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n💡 故障排除建议:")
|
||||
print(" 1. 确保PyTorch XLA版本支持AMP")
|
||||
print(" 2. 检查TPU资源是否充足")
|
||||
print(" 3. 尝试减小batch_size")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
model_training_nnn_tpu/dataset_tf.py (new file, 578 lines)
@@ -0,0 +1,578 @@
|
||||
import os
|
||||
import tensorflow as tf
|
||||
import h5py
|
||||
import numpy as np
|
||||
import math
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
|
||||
class BrainToTextDatasetTF:
|
||||
"""
|
||||
TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8
|
||||
|
||||
This class creates tf.data.Dataset objects that efficiently load and batch
|
||||
brain-to-text data from HDF5 files with TPU-optimized operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trial_indices: Dict[int, Dict[str, Any]],
|
||||
n_batches: Optional[int],
|
||||
split: str = 'train',
|
||||
batch_size: int = 64,
|
||||
days_per_batch: int = 1,
|
||||
random_seed: int = -1,
|
||||
must_include_days: Optional[List[int]] = None,
|
||||
feature_subset: Optional[List[int]] = None,
|
||||
prefetch_buffer: int = tf.data.AUTOTUNE,
|
||||
num_parallel_calls: int = tf.data.AUTOTUNE
|
||||
):
|
||||
"""
|
||||
Initialize TensorFlow dataset for brain-to-text data
|
||||
|
||||
Args:
|
||||
trial_indices: Dictionary with day numbers as keys and trial info as values
|
||||
n_batches: Number of training batches to create (None for validation)
|
||||
split: 'train' or 'test'
|
||||
batch_size: Number of examples per batch
|
||||
days_per_batch: Number of unique days per batch (for day-specific layers)
|
||||
random_seed: Random seed for reproducibility
|
||||
must_include_days: Days that must be included in every batch
|
||||
feature_subset: Subset of neural features to use
|
||||
prefetch_buffer: Buffer size for prefetching
|
||||
num_parallel_calls: Parallel processing threads
|
||||
"""
|
||||
|
||||
# Set random seed for reproducibility
|
||||
if random_seed != -1:
|
||||
tf.random.set_seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
self.split = split
|
||||
if self.split not in ['train', 'test']:
|
||||
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
|
||||
|
||||
self.days_per_batch = days_per_batch
|
||||
self.batch_size = batch_size
|
||||
self.n_batches = n_batches
|
||||
self.trial_indices = trial_indices
|
||||
self.n_days = len(trial_indices.keys())
|
||||
self.feature_subset = feature_subset
|
||||
self.must_include_days = must_include_days
|
||||
self.prefetch_buffer = prefetch_buffer
|
||||
self.num_parallel_calls = num_parallel_calls
|
||||
|
||||
# Calculate total number of trials
|
||||
self.n_trials = 0
|
||||
for d in trial_indices:
|
||||
self.n_trials += len(trial_indices[d]['trials'])
|
||||
|
||||
# Validation checks
|
||||
if must_include_days is not None:
|
||||
if len(must_include_days) > days_per_batch:
|
||||
raise ValueError(f'must_include_days must be <= days_per_batch')
|
||||
|
||||
# Map negative indices
|
||||
for i, d in enumerate(must_include_days):
|
||||
if d < 0:
|
||||
must_include_days[i] = self.n_days + d
|
||||
|
||||
if self.split == 'train' and self.days_per_batch > self.n_days:
|
||||
raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')
|
||||
|
||||
# Create batch indices
|
||||
if self.split == 'train':
|
||||
self.batch_indices = self._create_batch_index_train()
|
||||
else:
|
||||
self.batch_indices = self._create_batch_index_test()
|
||||
self.n_batches = len(self.batch_indices)
|
||||
|
||||
def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
|
||||
"""Create training batch indices with random sampling"""
|
||||
batch_indices = {}
|
||||
|
||||
# Precompute non-must-include days
|
||||
if self.must_include_days is not None:
|
||||
non_must_include_days = [
|
||||
d for d in self.trial_indices.keys()
|
||||
if d not in self.must_include_days
|
||||
]
|
||||
|
||||
for batch_idx in range(self.n_batches):
|
||||
batch = {}
|
||||
|
||||
# Select days for this batch
|
||||
if self.must_include_days is not None and len(self.must_include_days) > 0:
|
||||
additional_days = np.random.choice(
|
||||
non_must_include_days,
|
||||
size=self.days_per_batch - len(self.must_include_days),
|
||||
replace=False
|
||||
)
|
||||
days = np.concatenate((self.must_include_days, additional_days))
|
||||
else:
|
||||
days = np.random.choice(
|
||||
list(self.trial_indices.keys()),
|
||||
size=self.days_per_batch,
|
||||
replace=False
|
||||
)
|
||||
|
||||
# Calculate trials per day
|
||||
num_trials = math.ceil(self.batch_size / self.days_per_batch)
|
||||
|
||||
for d in days:
|
||||
# Sample trials with replacement
|
||||
trial_idxs = np.random.choice(
|
||||
self.trial_indices[d]['trials'],
|
||||
size=num_trials,
|
||||
replace=True
|
||||
)
|
||||
batch[d] = trial_idxs.tolist()
|
||||
|
||||
# Remove extra trials to match exact batch size
|
||||
extra_trials = (num_trials * len(days)) - self.batch_size
|
||||
while extra_trials > 0:
|
||||
d = np.random.choice(days)
|
||||
if len(batch[d]) > 0:
|
||||
batch[d] = batch[d][:-1]
|
||||
extra_trials -= 1
|
||||
|
||||
batch_indices[batch_idx] = batch
|
||||
|
||||
return batch_indices
|
||||
|
||||
def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
|
||||
"""Create test batch indices ensuring all trials are seen once"""
|
||||
batch_indices = {}
|
||||
batch_idx = 0
|
||||
|
||||
for d in self.trial_indices.keys():
|
||||
num_trials = len(self.trial_indices[d]['trials'])
|
||||
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
|
||||
|
||||
for i in range(num_batches):
|
||||
start_idx = i * self.batch_size
|
||||
end_idx = min((i + 1) * self.batch_size, num_trials)
|
||||
|
||||
batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
|
||||
batch_indices[batch_idx] = {d: batch_trials}
|
||||
batch_idx += 1
|
||||
|
||||
return batch_indices
|
||||
|
||||
def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
|
||||
"""Load a single trial's data from HDF5 file"""
|
||||
try:
|
||||
session_path = self.trial_indices[day]['session_path']
|
||||
|
||||
with h5py.File(session_path, 'r') as f:
|
||||
g = f[f'trial_{trial:04d}']
|
||||
|
||||
# Load neural features
|
||||
input_features = g['input_features'][:]
|
||||
if self.feature_subset:
|
||||
input_features = input_features[:, self.feature_subset]
|
||||
|
||||
# Convert to bfloat16 for TPU efficiency
|
||||
input_features = input_features.astype(np.float32) # TF will handle bfloat16 conversion
|
||||
|
||||
trial_data = {
|
||||
'input_features': input_features,
|
||||
'seq_class_ids': g['seq_class_ids'][:],
|
||||
'transcription': g['transcription'][:],
|
||||
'n_time_steps': g.attrs['n_time_steps'],
|
||||
'phone_seq_lens': g.attrs['seq_len'],
|
||||
'day_index': day,
|
||||
'block_num': g.attrs['block_num'],
|
||||
'trial_num': g.attrs['trial_num']
|
||||
}
|
||||
|
||||
return trial_data
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error loading trial {trial} from day {day}: {e}')
|
||||
# Return dummy data to maintain batch structure
|
||||
return {
|
||||
'input_features': np.zeros((100, 512), dtype=np.float32),
|
||||
'seq_class_ids': np.zeros((10,), dtype=np.int32),
|
||||
'transcription': np.zeros((50,), dtype=np.int32),
|
||||
'n_time_steps': 100,
|
||||
'phone_seq_lens': 10,
|
||||
'day_index': day,
|
||||
'block_num': 0,
|
||||
'trial_num': 0
|
||||
}
|
||||
|
||||
def _create_batch_generator(self):
|
||||
"""Generator function that yields individual batches"""
|
||||
for batch_idx in range(self.n_batches):
|
||||
batch_data = {
|
||||
'input_features': [],
|
||||
'seq_class_ids': [],
|
||||
'n_time_steps': [],
|
||||
'phone_seq_lens': [],
|
||||
'day_indices': [],
|
||||
'transcriptions': [],
|
||||
'block_nums': [],
|
||||
'trial_nums': []
|
||||
}
|
||||
|
||||
batch_index = self.batch_indices[batch_idx]
|
||||
|
||||
# Load data for each day in the batch
|
||||
for day in batch_index.keys():
|
||||
for trial in batch_index[day]:
|
||||
trial_data = self._load_trial_data(day, trial)
|
||||
|
||||
batch_data['input_features'].append(trial_data['input_features'])
|
||||
batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
|
||||
batch_data['transcriptions'].append(trial_data['transcription'])
|
||||
batch_data['n_time_steps'].append(trial_data['n_time_steps'])
|
||||
batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
|
||||
batch_data['day_indices'].append(trial_data['day_index'])
|
||||
batch_data['block_nums'].append(trial_data['block_num'])
|
||||
batch_data['trial_nums'].append(trial_data['trial_num'])
|
||||
|
||||
# Pad sequences to create uniform batch
|
||||
max_time_steps = max(batch_data['n_time_steps'])
|
||||
max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
|
||||
max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])
|
||||
|
||||
# Pad input features
|
||||
padded_features = []
|
||||
for features in batch_data['input_features']:
|
||||
if features.shape[0] < max_time_steps:
|
||||
padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
|
||||
features = np.vstack([features, padding])
|
||||
padded_features.append(features)
|
||||
|
||||
# Pad sequences
|
||||
padded_seq_ids = []
|
||||
for seq in batch_data['seq_class_ids']:
|
||||
if len(seq) < max_phone_len:
|
||||
padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
|
||||
seq = np.concatenate([seq, padding])
|
||||
padded_seq_ids.append(seq)
|
||||
|
||||
# Pad transcriptions
|
||||
padded_transcriptions = []
|
||||
for trans in batch_data['transcriptions']:
|
||||
if len(trans) < max_transcription_len:
|
||||
padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
|
||||
trans = np.concatenate([trans, padding])
|
||||
padded_transcriptions.append(trans)
|
||||
|
||||
# Create final batch tensors
|
||||
batch = {
|
||||
'input_features': np.stack(padded_features),
|
||||
'seq_class_ids': np.stack(padded_seq_ids),
|
||||
'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
|
||||
'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
|
||||
'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
|
||||
'transcriptions': np.stack(padded_transcriptions),
|
||||
'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
|
||||
'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
|
||||
}
|
||||
|
||||
yield batch
|
||||
|
||||
def create_dataset(self) -> tf.data.Dataset:
|
||||
"""Create optimized tf.data.Dataset for TPU training"""
|
||||
|
||||
# Define output signature for the dataset
|
||||
output_signature = {
|
||||
'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
|
||||
'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
|
||||
'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
|
||||
'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
|
||||
'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
|
||||
'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
|
||||
'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
|
||||
'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
|
||||
}
|
||||
|
||||
# Create dataset from generator
|
||||
dataset = tf.data.Dataset.from_generator(
|
||||
self._create_batch_generator,
|
||||
output_signature=output_signature
|
||||
)
|
||||
|
||||
# Apply TPU-optimized transformations
|
||||
if self.split == 'train':
|
||||
# For training, add shuffling
|
||||
dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))
|
||||
|
||||
# Prefetch for better performance
|
||||
dataset = dataset.prefetch(self.prefetch_buffer)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class DataAugmentationTF:
|
||||
"""
|
||||
TensorFlow data augmentation functions optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gauss_smooth(inputs: tf.Tensor,
|
||||
smooth_kernel_std: float = 2.0,
|
||||
smooth_kernel_size: int = 100) -> tf.Tensor:
|
||||
"""
|
||||
Apply Gaussian smoothing along the time axis using TensorFlow operations
|
||||
|
||||
Args:
|
||||
inputs: Input tensor [batch_size, time_steps, features]
|
||||
smooth_kernel_std: Standard deviation of Gaussian kernel
|
||||
smooth_kernel_size: Size of the Gaussian kernel
|
||||
|
||||
Returns:
|
||||
Smoothed tensor with same shape as input
|
||||
"""
|
||||
# Create Gaussian kernel using numpy (computed once)
|
||||
inp = np.zeros(smooth_kernel_size, dtype=np.float32)
|
||||
inp[smooth_kernel_size // 2] = 1
|
||||
gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
|
||||
valid_idx = np.argwhere(gauss_kernel > 0.01)
|
||||
gauss_kernel = gauss_kernel[valid_idx].flatten()
|
||||
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
|
||||
|
||||
# Convert to TensorFlow tensor
|
||||
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
|
||||
        gauss_kernel = tf.reshape(gauss_kernel, [-1, 1, 1])  # [kernel_width, in_channels=1, out_channels=1] for tf.nn.conv1d
|
||||
|
||||
# Prepare for convolution
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
time_steps = tf.shape(inputs)[1]
|
||||
num_features = tf.shape(inputs)[2]
|
||||
|
||||
        # Reshape for convolution (NWC layout): [batch_size * features, time_steps, 1]
        inputs_reshaped = tf.transpose(inputs, [0, 2, 1])  # [batch_size, features, time_steps]
        inputs_reshaped = tf.reshape(inputs_reshaped, [-1, time_steps, 1])
|
||||
|
||||
# Apply convolution
|
||||
smoothed = tf.nn.conv1d(
|
||||
inputs_reshaped,
|
||||
gauss_kernel,
|
||||
stride=1,
|
||||
padding='SAME'
|
||||
)
|
||||
|
||||
# Reshape back to original format
|
||||
smoothed = tf.reshape(smoothed, [batch_size, num_features, time_steps])
|
||||
smoothed = tf.transpose(smoothed, [0, 2, 1]) # [batch_size, time_steps, features]
|
||||
|
||||
return smoothed
|
||||
|
||||
@staticmethod
|
||||
def transform_data(features: tf.Tensor,
|
||||
n_time_steps: tf.Tensor,
|
||||
transform_args: Dict[str, Any],
|
||||
training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
|
||||
"""
|
||||
Apply data transformations optimized for TPU
|
||||
|
||||
Args:
|
||||
features: Input features [batch_size, time_steps, channels]
|
||||
n_time_steps: Number of valid time steps per sample
|
||||
transform_args: Transformation configuration
|
||||
training: Whether to apply training-only augmentations
|
||||
|
||||
Returns:
|
||||
Transformed features and updated time steps
|
||||
"""
|
||||
batch_size = tf.shape(features)[0]
|
||||
time_steps = tf.shape(features)[1]
|
||||
channels = tf.shape(features)[2]
|
||||
|
||||
# Training-only augmentations
|
||||
if training:
|
||||
# Static gain noise
|
||||
if transform_args.get('static_gain_std', 0) > 0:
|
||||
gain_std = transform_args['static_gain_std']
|
||||
# Create identity matrices for each batch
|
||||
identity_matrices = tf.eye(channels, batch_shape=[batch_size])
|
||||
# Add noise to create warp matrices
|
||||
noise = tf.random.normal([batch_size, channels, channels]) * gain_std
|
||||
warp_matrices = identity_matrices + noise
|
||||
# Apply transformation
|
||||
features = tf.linalg.matmul(features, warp_matrices)
|
||||
|
||||
# White noise
|
||||
if transform_args.get('white_noise_std', 0) > 0:
|
||||
white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
|
||||
features = features + white_noise
|
||||
|
||||
# Constant offset noise
|
||||
if transform_args.get('constant_offset_std', 0) > 0:
|
||||
offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
|
||||
features = features + offset_noise
|
||||
|
||||
# Random walk noise
|
||||
if transform_args.get('random_walk_std', 0) > 0:
|
||||
random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
|
||||
axis = transform_args.get('random_walk_axis', 1)
|
||||
random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
|
||||
features = features + random_walk_noise
|
||||
|
||||
# Random cutoff (simplified for TPU - apply to all samples in batch)
|
||||
if transform_args.get('random_cut', 0) > 0:
|
||||
max_cut = transform_args['random_cut']
|
||||
cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
|
||||
features = features[:, cut:, :]
|
||||
n_time_steps = n_time_steps - cut
|
||||
|
||||
# Apply Gaussian smoothing (both training and validation)
|
||||
if transform_args.get('smooth_data', False):
|
||||
features = DataAugmentationTF.gauss_smooth(
|
||||
features,
|
||||
smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
|
||||
smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
|
||||
)
|
||||
|
||||
return features, n_time_steps
|
||||
|
||||
|
||||
def train_test_split_indices(file_paths: List[str],
|
||||
test_percentage: float = 0.1,
|
||||
seed: int = -1,
|
||||
bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Split data from file_paths into train and test splits
|
||||
|
||||
Args:
|
||||
file_paths: List of HDF5 file paths
|
||||
test_percentage: Percentage of trials for testing
|
||||
seed: Random seed for reproducibility
|
||||
bad_trials_dict: Dictionary of trials to exclude
|
||||
|
||||
Returns:
|
||||
Tuple of (train_trials, test_trials) dictionaries
|
||||
"""
|
||||
# Set seed for reproducibility
|
||||
if seed != -1:
|
||||
np.random.seed(seed)
|
||||
|
||||
# Get trials in each day
|
||||
trials_per_day = {}
|
||||
for i, path in enumerate(file_paths):
|
||||
# Handle both Windows and Unix path separators
|
||||
path_parts = path.replace('\\', '/').split('/')
|
||||
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
|
||||
|
||||
good_trial_indices = []
|
||||
|
||||
if os.path.exists(path):
|
||||
with h5py.File(path, 'r') as f:
|
||||
num_trials = len(list(f.keys()))
|
||||
for t in range(num_trials):
|
||||
key = f'trial_{t:04d}'
|
||||
|
||||
if key not in f:
|
||||
continue
|
||||
|
||||
block_num = f[key].attrs['block_num']
|
||||
trial_num = f[key].attrs['trial_num']
|
||||
|
||||
# Check if trial should be excluded
|
||||
if (bad_trials_dict is not None
|
||||
and session in bad_trials_dict
|
||||
and str(block_num) in bad_trials_dict[session]
|
||||
and trial_num in bad_trials_dict[session][str(block_num)]):
|
||||
continue
|
||||
|
||||
good_trial_indices.append(t)
|
||||
|
||||
trials_per_day[i] = {
|
||||
'num_trials': len(good_trial_indices),
|
||||
'trial_indices': good_trial_indices,
|
||||
'session_path': path
|
||||
}
|
||||
|
||||
# Split trials into train and test
|
||||
train_trials = {}
|
||||
test_trials = {}
|
||||
|
||||
for day in trials_per_day.keys():
|
||||
num_trials = trials_per_day[day]['num_trials']
|
||||
all_trial_indices = trials_per_day[day]['trial_indices']
|
||||
|
||||
if test_percentage == 0:
|
||||
train_trials[day] = {
|
||||
'trials': all_trial_indices,
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
test_trials[day] = {
|
||||
'trials': [],
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
elif test_percentage == 1:
|
||||
train_trials[day] = {
|
||||
'trials': [],
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
test_trials[day] = {
|
||||
'trials': all_trial_indices,
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
else:
|
||||
# Calculate number of test trials
|
||||
num_test = max(1, int(num_trials * test_percentage))
|
||||
|
||||
# Randomly select test indices
|
||||
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
|
||||
|
||||
# Remaining indices for training
|
||||
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
|
||||
|
||||
train_trials[day] = {
|
||||
'trials': train_indices,
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
test_trials[day] = {
|
||||
'trials': test_indices,
|
||||
'session_path': trials_per_day[day]['session_path']
|
||||
}
|
||||
|
||||
return train_trials, test_trials
|
||||
|
||||
|
||||
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True) -> tf.data.Dataset:
"""
Create input function for TPU training with data augmentation

Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
training: Whether this is for training (applies augmentations)

Returns:
tf.data.Dataset ready for TPU training
"""
dataset = dataset_tf.create_dataset()

def apply_transforms(batch):
"""Apply data transformations to a batch"""
features = batch['input_features']
n_time_steps = batch['n_time_steps']

# Apply transformations
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, transform_args, training=training
)

# Update batch with transformed data
batch['input_features'] = features
batch['n_time_steps'] = n_time_steps

return batch

# Apply transformations
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)

return dataset
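# Illustrative sketch (not part of the original file): wiring the pipeline into a TPU
# strategy. `train_dataset_tf` is assumed to be an already-built BrainToTextDatasetTF
# and `config` the parsed training YAML:
#
#   train_ds = create_input_fn(
#       train_dataset_tf,
#       config['dataset']['data_transforms'],
#       training=True,
#   )
#   dist_train_ds = strategy.experimental_distribute_dataset(train_ds)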
|
480
model_training_nnn_tpu/evaluate_model_tf.py
Normal file
480
model_training_nnn_tpu/evaluate_model_tf.py
Normal file
@@ -0,0 +1,480 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TensorFlow Evaluation Script for Brain-to-Text RNN Model
|
||||
Optimized for TPU v5e-8
|
||||
|
||||
This script evaluates the TripleGRUDecoder model using TensorFlow and provides
|
||||
detailed metrics and analysis of model performance on test data.
|
||||
|
||||
Usage:
|
||||
python evaluate_model_tf.py --model_path path/to/model --data_dir path/to/data
|
||||
|
||||
Requirements:
|
||||
- TensorFlow >= 2.15.0
|
||||
- TPU v5e-8 environment
|
||||
- Trained model checkpoint
|
||||
- Access to brain-to-text HDF5 dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Add the current directory to Python path for imports
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from trainer_tf import BrainToTextDecoderTrainerTF
|
||||
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn
|
||||
from rnn_model_tf import create_tpu_strategy, configure_mixed_precision
|
||||
|
||||
|
||||
class BrainToTextEvaluatorTF:
|
||||
"""
|
||||
TensorFlow evaluator for brain-to-text model performance analysis
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, config: Dict[str, Any], eval_type: str = 'test'):
|
||||
"""
|
||||
Initialize evaluator
|
||||
|
||||
Args:
|
||||
model_path: Path to trained model checkpoint
|
||||
config: Configuration dictionary
|
||||
eval_type: 'test' or 'val' evaluation type
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.config = config
|
||||
self.eval_type = eval_type
|
||||
|
||||
# Initialize TPU strategy
|
||||
self.strategy = create_tpu_strategy()
|
||||
print(f"Evaluation using {self.strategy.num_replicas_in_sync} TPU cores")
|
||||
|
||||
# Configure mixed precision
|
||||
if config.get('use_amp', True):
|
||||
configure_mixed_precision()
|
||||
|
||||
# Load model
|
||||
with self.strategy.scope():
|
||||
self.trainer = BrainToTextDecoderTrainerTF(config)
|
||||
self.trainer.load_checkpoint(model_path)
|
||||
|
||||
print(f"Model loaded from: {model_path}")
|
||||
|
||||
def evaluate_dataset(self, save_results: bool = True,
|
||||
return_predictions: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate model on specified dataset
|
||||
|
||||
Args:
|
||||
save_results: Whether to save detailed results to file
|
||||
return_predictions: Whether to return individual predictions
|
||||
|
||||
Returns:
|
||||
Dictionary containing evaluation metrics and optionally predictions
|
||||
"""
|
||||
print(f"Starting {self.eval_type} evaluation...")
|
||||
|
||||
# Create evaluation dataset
|
||||
if self.eval_type == 'test':
|
||||
dataset_tf = self.trainer.val_dataset_tf # Using validation data as test
|
||||
else:
|
||||
dataset_tf = self.trainer.val_dataset_tf
|
||||
|
||||
eval_dataset = create_input_fn(
|
||||
dataset_tf,
|
||||
self.config['dataset']['data_transforms'],
|
||||
training=False
|
||||
)
|
||||
|
||||
# Distribute dataset
|
||||
eval_dist_dataset = self.strategy.experimental_distribute_dataset(eval_dataset)
|
||||
|
||||
# Run evaluation
|
||||
results = self._run_evaluation(eval_dist_dataset, return_predictions)
|
||||
|
||||
# Calculate summary metrics
|
||||
summary_metrics = self._calculate_summary_metrics(results)
|
||||
|
||||
print(f"Evaluation completed!")
|
||||
print(f"Overall PER: {summary_metrics['overall_per']:.4f}")
|
||||
print(f"Overall Loss: {summary_metrics['overall_loss']:.4f}")
|
||||
print(f"Total trials evaluated: {summary_metrics['total_trials']}")
|
||||
|
||||
# Save results if requested
|
||||
if save_results:
|
||||
self._save_results(results, summary_metrics)
|
||||
|
||||
return {
|
||||
'summary_metrics': summary_metrics,
|
||||
'detailed_results': results if return_predictions else None
|
||||
}
|
||||
|
||||
def _run_evaluation(self, eval_dataset, return_predictions: bool) -> List[Dict[str, Any]]:
|
||||
"""Run evaluation on distributed dataset"""
|
||||
all_results = []
|
||||
batch_idx = 0
|
||||
|
||||
for batch in eval_dataset:
|
||||
batch_results = self.strategy.run(self._evaluation_step, args=(batch, return_predictions))
|
||||
|
||||
# Gather results from all replicas
|
||||
gathered_results = {}
|
||||
for key in batch_results.keys():
|
||||
if key in ['logits', 'features'] and not return_predictions:
|
||||
continue # Skip large tensors if not needed
|
||||
|
||||
values = self.strategy.experimental_local_results(batch_results[key])
|
||||
if key in ['loss', 'edit_distance', 'seq_length']:
|
||||
# Scalar metrics - just take the values
|
||||
gathered_results[key] = [float(v.numpy()) for v in values]
|
||||
else:
|
||||
# Tensor data - concatenate across replicas
|
||||
gathered_results[key] = [v.numpy() for v in values]
|
||||
|
||||
all_results.append(gathered_results)
|
||||
batch_idx += 1
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
print(f"Processed {batch_idx} batches...")
|
||||
|
||||
return all_results
|
||||
|
||||
@tf.function
|
||||
def _evaluation_step(self, batch, return_predictions: bool):
|
||||
"""Single evaluation step"""
|
||||
features = batch['input_features']
|
||||
labels = batch['seq_class_ids']
|
||||
n_time_steps = batch['n_time_steps']
|
||||
phone_seq_lens = batch['phone_seq_lens']
|
||||
day_indices = batch['day_indices']
|
||||
|
||||
# Apply data transformations (no augmentation)
|
||||
from dataset_tf import DataAugmentationTF
|
||||
features_transformed, n_time_steps_transformed = DataAugmentationTF.transform_data(
|
||||
features, n_time_steps, self.config['dataset']['data_transforms'], training=False
|
||||
)
|
||||
|
||||
# Calculate adjusted lengths for CTC
|
||||
adjusted_lens = tf.cast(
|
||||
(tf.cast(n_time_steps_transformed, tf.float32) - self.config['model']['patch_size']) /
|
||||
self.config['model']['patch_stride'] + 1,
|
||||
tf.int32
|
||||
)
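# For example, with patch_size=8 and patch_stride=4 a trial with 600 raw time steps
# yields (600 - 8) / 4 + 1 = 149 patched steps, which is the sequence length the CTC
# loss sees for that trial.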
|
||||
|
||||
# Forward pass
|
||||
logits = self.trainer.model(
|
||||
features_transformed, day_indices, None, False, 'inference', training=False
|
||||
)
|
||||
|
||||
# Calculate loss
|
||||
loss_input = {
|
||||
'labels': labels,
|
||||
'input_lengths': adjusted_lens,
|
||||
'label_lengths': phone_seq_lens
|
||||
}
|
||||
loss = self.trainer.ctc_loss(loss_input, logits)
|
||||
loss = tf.reduce_mean(loss)
|
||||
|
||||
# Calculate edit distance for PER
|
||||
predicted_ids = tf.argmax(logits, axis=-1)
|
||||
batch_size = tf.shape(logits)[0]
|
||||
|
||||
# Initialize metrics
|
||||
total_edit_distance = 0
|
||||
total_seq_length = tf.reduce_sum(phone_seq_lens)
|
||||
|
||||
# Decode predictions and calculate edit distance
|
||||
predictions = []
|
||||
targets = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Get prediction for this sample
|
||||
pred_seq = predicted_ids[i, :adjusted_lens[i]]
|
||||
|
||||
# Remove consecutive duplicates using tf.py_function for simplicity
|
||||
pred_seq_unique = tf.py_function(
|
||||
func=self._remove_consecutive_duplicates,
|
||||
inp=[pred_seq],
|
||||
Tout=tf.int64
|
||||
)
|
||||
|
||||
# Remove blanks (assuming blank_index=0)
|
||||
pred_seq_clean = tf.boolean_mask(pred_seq_unique, pred_seq_unique != 0)
|
||||
|
||||
# Get true sequence
|
||||
true_seq = labels[i, :phone_seq_lens[i]]
|
||||
|
||||
# Calculate edit distance for this pair
|
||||
if tf.size(pred_seq_clean) > 0 and tf.size(true_seq) > 0:
|
||||
pred_sparse = tf.SparseTensor(
|
||||
indices=tf.expand_dims(tf.range(tf.size(pred_seq_clean), dtype=tf.int64), 1),
|
||||
values=tf.cast(pred_seq_clean, tf.int64),
|
||||
dense_shape=[tf.size(pred_seq_clean, out_type=tf.int64)]
|
||||
)
|
||||
|
||||
true_sparse = tf.SparseTensor(
|
||||
indices=tf.expand_dims(tf.range(tf.size(true_seq), dtype=tf.int64), 1),
|
||||
values=tf.cast(true_seq, tf.int64),
|
||||
dense_shape=[tf.size(true_seq, out_type=tf.int64)]
|
||||
)
|
||||
|
||||
edit_dist = tf.edit_distance(pred_sparse, true_sparse, normalize=False)
|
||||
total_edit_distance += edit_dist
|
||||
|
||||
if return_predictions:
|
||||
predictions.append(pred_seq_clean)
|
||||
targets.append(true_seq)
|
||||
|
||||
result = {
|
||||
'loss': loss,
|
||||
'edit_distance': total_edit_distance,
|
||||
'seq_length': total_seq_length,
|
||||
'day_indices': day_indices,
|
||||
'n_time_steps': n_time_steps,
|
||||
'phone_seq_lens': phone_seq_lens
|
||||
}
|
||||
|
||||
if return_predictions:
|
||||
result.update({
|
||||
'logits': logits,
|
||||
'predictions': predictions,
|
||||
'targets': targets,
|
||||
'features': features
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _remove_consecutive_duplicates(self, seq):
|
||||
"""Remove consecutive duplicate elements from sequence"""
|
||||
seq_np = seq.numpy()
|
||||
if len(seq_np) == 0:
|
||||
return tf.constant([], dtype=tf.int64)
|
||||
|
||||
unique_seq = [seq_np[0]]
|
||||
for i in range(1, len(seq_np)):
|
||||
if seq_np[i] != seq_np[i-1]:
|
||||
unique_seq.append(seq_np[i])
|
||||
|
||||
return tf.constant(unique_seq, dtype=tf.int64)
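# Example of the collapse performed above: an argmax sequence [7, 7, 0, 0, 12, 12]
# becomes [7, 0, 12] here, and the blank removal in the caller then leaves [7, 12],
# i.e. a standard CTC greedy decode.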
|
||||
|
||||
def _calculate_summary_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Calculate summary metrics from evaluation results"""
|
||||
total_loss = 0.0
|
||||
total_edit_distance = 0
|
||||
total_seq_length = 0
|
||||
total_trials = 0
|
||||
num_batches = len(results)
|
||||
|
||||
# Day-specific metrics
|
||||
day_metrics = {}
|
||||
|
||||
for batch_results in results:
|
||||
# Sum losses across replicas
|
||||
batch_loss = sum(batch_results['loss'])
|
||||
total_loss += batch_loss
|
||||
|
||||
# Sum edit distances and sequence lengths
|
||||
batch_edit_dist = sum(batch_results['edit_distance'])
|
||||
batch_seq_len = sum(batch_results['seq_length'])
|
||||
|
||||
total_edit_distance += batch_edit_dist
|
||||
total_seq_length += batch_seq_len
|
||||
|
||||
# Count trials
|
||||
for day_indices_replica in batch_results['day_indices']:
|
||||
total_trials += len(day_indices_replica)
|
||||
|
||||
# Track per-day metrics
|
||||
for i, day_idx in enumerate(day_indices_replica):
|
||||
day_idx = int(day_idx)
|
||||
if day_idx not in day_metrics:
|
||||
day_metrics[day_idx] = {'edit_distance': 0, 'seq_length': 0, 'trials': 0}
|
||||
|
||||
day_metrics[day_idx]['trials'] += 1
|
||||
|
||||
# Calculate averages
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
overall_per = total_edit_distance / max(total_seq_length, 1e-6)
|
||||
|
||||
# Calculate per-day PERs
|
||||
day_pers = {}
|
||||
for day_idx, metrics in day_metrics.items():
|
||||
day_per = metrics['edit_distance'] / max(metrics['seq_length'], 1e-6)
|
||||
day_pers[day_idx] = {
|
||||
'per': day_per,
|
||||
'edit_distance': metrics['edit_distance'],
|
||||
'seq_length': metrics['seq_length'],
|
||||
'trials': metrics['trials']
|
||||
}
|
||||
|
||||
return {
|
||||
'overall_per': float(overall_per),
|
||||
'overall_loss': float(avg_loss),
|
||||
'total_edit_distance': int(total_edit_distance),
|
||||
'total_seq_length': int(total_seq_length),
|
||||
'total_trials': total_trials,
|
||||
'num_batches': num_batches,
|
||||
'day_metrics': day_pers
|
||||
}
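# The PER reported above is aggregate edit distance over aggregate reference length,
# e.g. 150 total edits against 1000 reference phonemes gives overall_per = 0.15.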
|
||||
|
||||
def _save_results(self, detailed_results: List[Dict[str, Any]],
|
||||
summary_metrics: Dict[str, Any]):
|
||||
"""Save evaluation results to files"""
|
||||
output_dir = self.config.get('output_dir', './eval_output')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save summary metrics
|
||||
summary_path = os.path.join(output_dir, f'{self.eval_type}_summary_metrics.json')
|
||||
with open(summary_path, 'w') as f:
|
||||
json.dump(summary_metrics, f, indent=2)
|
||||
print(f"Summary metrics saved to: {summary_path}")
|
||||
|
||||
# Save detailed results
|
||||
detailed_path = os.path.join(output_dir, f'{self.eval_type}_detailed_results.pkl')
|
||||
with open(detailed_path, 'wb') as f:
|
||||
pickle.dump(detailed_results, f)
|
||||
print(f"Detailed results saved to: {detailed_path}")
|
||||
|
||||
# Save per-day breakdown
|
||||
if 'day_metrics' in summary_metrics:
|
||||
day_breakdown_path = os.path.join(output_dir, f'{self.eval_type}_day_breakdown.json')
|
||||
with open(day_breakdown_path, 'w') as f:
|
||||
json.dump(summary_metrics['day_metrics'], f, indent=2)
|
||||
print(f"Per-day breakdown saved to: {day_breakdown_path}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main evaluation function"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Evaluate Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--model_path',
|
||||
required=True,
|
||||
help='Path to trained model checkpoint (without extension)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
default='rnn_args.yaml',
|
||||
help='Path to model configuration file'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--data_dir',
|
||||
default=None,
|
||||
help='Override data directory from config'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--eval_type',
|
||||
choices=['test', 'val'],
|
||||
default='test',
|
||||
help='Type of evaluation to run'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
default='./eval_output',
|
||||
help='Directory to save evaluation results'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--save_predictions',
|
||||
action='store_true',
|
||||
help='Save individual predictions and targets'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Override batch size for evaluation'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--sessions',
|
||||
nargs='+',
|
||||
default=None,
|
||||
help='Specific sessions to evaluate (overrides config)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup TPU environment
|
||||
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
|
||||
|
||||
# Load configuration
|
||||
if not os.path.exists(args.config_path):
|
||||
raise FileNotFoundError(f"Configuration file not found: {args.config_path}")
|
||||
|
||||
config = OmegaConf.load(args.config_path)
|
||||
|
||||
# Apply overrides
|
||||
if args.data_dir:
|
||||
config.dataset.dataset_dir = args.data_dir
|
||||
if args.batch_size:
|
||||
config.dataset.batch_size = args.batch_size
|
||||
if args.sessions:
|
||||
config.dataset.sessions = args.sessions
|
||||
if args.output_dir:
|
||||
config.output_dir = args.output_dir
|
||||
|
||||
# Validate model checkpoint exists
|
||||
if not os.path.exists(args.model_path + '.weights.h5'):
|
||||
raise FileNotFoundError(f"Model checkpoint not found: {args.model_path}")
|
||||
|
||||
try:
|
||||
# Initialize evaluator
|
||||
evaluator = BrainToTextEvaluatorTF(
|
||||
model_path=args.model_path,
|
||||
config=config,
|
||||
eval_type=args.eval_type
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
results = evaluator.evaluate_dataset(
|
||||
save_results=True,
|
||||
return_predictions=args.save_predictions
|
||||
)
|
||||
|
||||
# Print results
|
||||
metrics = results['summary_metrics']
|
||||
print("\n" + "="*60)
|
||||
print("EVALUATION RESULTS")
|
||||
print("="*60)
|
||||
print(f"Overall PER: {metrics['overall_per']:.6f}")
|
||||
print(f"Overall Loss: {metrics['overall_loss']:.6f}")
|
||||
print(f"Total Edit Distance: {metrics['total_edit_distance']}")
|
||||
print(f"Total Sequence Length: {metrics['total_seq_length']}")
|
||||
print(f"Total Trials: {metrics['total_trials']}")
|
||||
print(f"Batches Processed: {metrics['num_batches']}")
|
||||
|
||||
# Print per-day results if available
|
||||
if 'day_metrics' in metrics and metrics['day_metrics']:
|
||||
print("\nPER-DAY RESULTS:")
|
||||
print("-" * 40)
|
||||
for day_idx, day_metrics in metrics['day_metrics'].items():
|
||||
session_name = config.dataset.sessions[day_idx] if day_idx < len(config.dataset.sessions) else f"Day_{day_idx}"
|
||||
print(f"{session_name}: PER={day_metrics['per']:.6f}, Trials={day_metrics['trials']}")
|
||||
|
||||
print("\nEvaluation completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Evaluation failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,194 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal TPU test - completely avoids the bf16 issue
Uses only float32 and the most basic operations
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Do not set any bf16-related environment variables at all
# Only set the most basic XLA optimizations
|
||||
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'
|
||||
|
||||
# Make sure bf16 is not used
|
||||
if 'XLA_USE_BF16' in os.environ:
|
||||
del os.environ['XLA_USE_BF16']
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def test_basic_operations():
|
"""Test the most basic TPU operations"""
print("🚀 Testing the most basic TPU operations...")
|
||||
|
||||
try:
|
||||
device = xm.xla_device()
|
||||
print(f"📱 Device: {device}")

# Test 1: basic tensor operations
print("🔧 Testing basic tensor operations...")
|
||||
a = torch.randn(4, 4, device=device, dtype=torch.float32)
|
||||
b = torch.randn(4, 4, device=device, dtype=torch.float32)
|
||||
c = a + b
|
||||
|
||||
print(f" a.shape: {a.shape}, dtype: {a.dtype}")
|
||||
print(f" b.shape: {b.shape}, dtype: {b.dtype}")
|
||||
print(f" c.shape: {c.shape}, dtype: {c.dtype}")
|
||||
|
||||
# Synchronize
xm.mark_step()
xm.wait_device_ops()
print("✅ Basic tensor operations succeeded")
|
||||
|
||||
# Test 2: matrix multiplication
print("🔧 Testing matrix multiplication...")
d = torch.mm(a, b)
xm.mark_step()
xm.wait_device_ops()
print(f" Matrix multiplication result shape: {d.shape}, dtype: {d.dtype}")
print("✅ Matrix multiplication succeeded")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Basic operations failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_simple_model():
|
"""Test the simplest possible model"""
print("\n🧠 Testing the simplest possible model...")
|
||||
|
||||
try:
|
||||
device = xm.xla_device()
|
||||
|
||||
# An extremely simple linear model
|
||||
model = nn.Sequential(
|
||||
nn.Linear(10, 5),
|
||||
nn.ReLU(),
|
||||
nn.Linear(5, 2)
|
||||
).to(device)
|
||||
|
||||
print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters())}")

# Make sure all parameters are float32
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
# Create input data - explicitly float32
x = torch.randn(8, 10, device=device, dtype=torch.float32)

print(f"📥 Input: shape={x.shape}, dtype={x.dtype}")
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
output = model(x)
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
print(f"📤 Output: shape={output.shape}, dtype={output.dtype}")
print("✅ Simple model forward pass succeeded")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Simple model failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_training_step():
|
"""Test the simplest possible training step"""
print("\n🎯 Testing the simplest possible training step...")
|
||||
|
||||
try:
|
||||
device = xm.xla_device()
|
||||
|
||||
# Extremely simple model
|
||||
model = nn.Linear(10, 1).to(device)
|
||||
|
||||
# Make sure the weights are float32
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
# Create data - explicitly float32
|
||||
x = torch.randn(4, 10, device=device, dtype=torch.float32)
|
||||
y = torch.randn(4, 1, device=device, dtype=torch.float32)
|
||||
|
||||
print(f"📥 Input: {x.shape}, {x.dtype}")
print(f"📥 Labels: {y.shape}, {y.dtype}")

# One training step
|
||||
optimizer.zero_grad()
|
||||
output = model(x)
|
||||
loss = criterion(output, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Synchronize
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
print(f"🎯 Loss: {loss.item():.4f}")
print("✅ Training step succeeded")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Training step failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
"""Main function"""
print("=" * 50)
print("🔬 Minimal TPU test (float32 only)")
|
||||
print("=" * 50)
|
||||
|
||||
all_passed = True
|
||||
|
||||
# Test 1: basic operations
if test_basic_operations():
print("1️⃣ Basic operations ✅")
else:
print("1️⃣ Basic operations ❌")
all_passed = False

# Test 2: simple model
if test_simple_model():
print("2️⃣ Simple model ✅")
else:
print("2️⃣ Simple model ❌")
all_passed = False

# Test 3: training step
if test_training_step():
print("3️⃣ Training step ✅")
else:
print("3️⃣ Training step ❌")
all_passed = False

print("=" * 50)

if all_passed:
print("🎉 All tests passed! The TPU is working correctly")
print("💡 You can now try more complex models")
else:
print("❌ Some tests failed")
print("💡 Suggestions:")
print(" 1. Check whether TPU resources are available")
print(" 2. Confirm that torch_xla is installed correctly")
print(" 3. Restart the runtime to clear its state")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Ultra-simple MNIST TPU training - completely avoids mixed-precision issues
Uses float32 only to ensure stable execution
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# Clear every environment variable that could trigger bf16 issues
|
||||
for key in ['XLA_USE_BF16', 'XLA_DOWNCAST_BF16']:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Only set the most basic XLA optimizations
|
||||
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=false'
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
|
||||
|
||||
class SimpleMNISTNet(nn.Module):
|
"""An ultra-simple MNIST classifier"""
|
||||
|
||||
def __init__(self):
|
||||
super(SimpleMNISTNet, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc1 = nn.Linear(28 * 28, 128)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.flatten(x)
|
||||
x = self.relu1(self.fc1(x))
|
||||
x = self.relu2(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_mnist_data(batch_size=64):
|
"""Load the MNIST data"""
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
test_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
def train_mnist():
|
"""Train the MNIST model"""
print("🚀 Starting MNIST TPU training...")
|
||||
|
||||
# Get the device
device = xm.xla_device()
print(f"📱 Device: {device}")
|
||||
|
||||
# Create the model
|
||||
model = SimpleMNISTNet().to(device)
|
||||
|
||||
# Make sure all parameters are float32
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# Loss function and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
# Load the data
print("📥 Loading MNIST data...")
|
||||
train_loader, test_loader = get_mnist_data(batch_size=64)
|
||||
|
||||
# Use the XLA parallel loader
|
||||
train_device_loader = pl.MpDeviceLoader(train_loader, device)
|
||||
|
||||
print("🎯 Starting training...")
|
||||
|
||||
model.train()
|
||||
start_time = time.time()
|
||||
|
||||
total_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
max_batches = 100 # Only train on 100 batches for a quick check
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_device_loader):
|
||||
if batch_idx >= max_batches:
|
||||
break
|
||||
|
||||
# Make sure the data types are correct
|
||||
data = data.to(torch.float32)
|
||||
target = target.to(torch.long)
|
||||
|
||||
# Forward pass
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Statistics
|
||||
total_loss += loss.item()
|
||||
pred = output.argmax(dim=1)
|
||||
correct += pred.eq(target).sum().item()
|
||||
total += target.size(0)
|
||||
|
||||
# Synchronize every 10 batches
|
||||
if batch_idx % 10 == 0:
|
||||
xm.mark_step()
|
||||
current_acc = 100. * correct / total
|
||||
avg_loss = total_loss / (batch_idx + 1)
|
||||
|
||||
print(f'Batch {batch_idx:3d}/{max_batches} | '
f'loss: {avg_loss:.4f} | '
f'accuracy: {current_acc:.2f}%')
|
||||
|
||||
# Final synchronization
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
train_time = time.time() - start_time
|
||||
final_acc = 100. * correct / total
|
||||
final_loss = total_loss / min(batch_idx + 1, max_batches)
|
||||
|
||||
print(f"\n✅ Training finished!")
print(f"⏱️ Training time: {train_time:.2f} seconds")
print(f"🎯 Final loss: {final_loss:.4f}")
print(f"🎯 Training accuracy: {final_acc:.2f}%")
|
||||
|
||||
return model, final_loss, final_acc
|
||||
|
||||
|
||||
def test_mnist(model):
|
"""Evaluate the MNIST model"""
print("\n🧪 Starting evaluation...")
|
||||
|
||||
device = xm.xla_device()
|
||||
_, test_loader = get_mnist_data(batch_size=64)
|
||||
|
||||
test_device_loader = pl.MpDeviceLoader(test_loader, device)
|
||||
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
max_test_batches = 50 # Only evaluate 50 batches
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, (data, target) in enumerate(test_device_loader):
|
||||
if batch_idx >= max_test_batches:
|
||||
break
|
||||
|
||||
# Make sure the data types are correct
|
||||
data = data.to(torch.float32)
|
||||
target = target.to(torch.long)
|
||||
|
||||
output = model(data)
|
||||
pred = output.argmax(dim=1)
|
||||
correct += pred.eq(target).sum().item()
|
||||
total += target.size(0)
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
xm.mark_step()
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
test_time = time.time() - start_time
|
||||
accuracy = 100. * correct / total
|
||||
|
||||
print(f"✅ Evaluation finished!")
print(f"⏱️ Evaluation time: {test_time:.2f} seconds")
print(f"🎯 Test accuracy: {accuracy:.2f}%")
|
||||
|
||||
return accuracy
|
||||
|
||||
|
||||
def main():
|
"""Main function"""
print("=" * 60)
print("🔢 Ultra-simple MNIST TPU training (float32 only)")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Train
|
||||
model, train_loss, train_acc = train_mnist()
|
||||
|
||||
# Evaluate
|
||||
test_acc = test_mnist(model)
|
||||
|
||||
# Save the model
print("\n💾 Saving model...")
|
||||
model_cpu = model.cpu()
|
||||
torch.save(model_cpu.state_dict(), 'mnist_simple_model.pth')
|
||||
print("✅ Model saved")
|
||||
|
||||
print("\n🎉 All done!")
print(f"📊 Training accuracy: {train_acc:.2f}%")
print(f"📊 Test accuracy: {test_acc:.2f}%")
|
||||
|
||||
if train_acc > 80 and test_acc > 75:
|
||||
print("✅ Model training succeeded!")
else:
print("⚠️ Model performance is mediocre, but the TPU itself works correctly")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Training failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
40
model_training_nnn_tpu/requirements_tf.txt
Normal file
40
model_training_nnn_tpu/requirements_tf.txt
Normal file
@@ -0,0 +1,40 @@
|
||||
# TensorFlow Brain-to-Text Requirements for TPU v5e-8
|
||||
# Install with: pip install -r requirements_tf.txt
|
||||
|
||||
# Core TensorFlow and TPU support
|
||||
tensorflow>=2.15.0
|
||||
tensorflow-text>=2.15.0
|
||||
|
||||
# Data processing
|
||||
numpy>=1.21.0
|
||||
h5py>=3.7.0
|
||||
scipy>=1.9.0
|
||||
|
||||
# Configuration and utilities
|
||||
omegaconf>=2.3.0
|
||||
pyyaml>=6.0
|
||||
|
||||
# Logging and monitoring
|
||||
tensorboard>=2.12.0
|
||||
wandb>=0.15.0 # Optional: for experiment tracking
|
||||
|
||||
# Audio processing (for phoneme analysis)
|
||||
librosa>=0.10.0
|
||||
|
||||
# Development and testing
|
||||
pytest>=7.0.0
|
||||
pytest-cov>=4.0.0
|
||||
black>=22.0.0
|
||||
flake8>=5.0.0
|
||||
|
||||
# Optional: Jupyter for analysis notebooks
|
||||
jupyter>=1.0.0
|
||||
matplotlib>=3.5.0
|
||||
seaborn>=0.11.0
|
||||
pandas>=1.4.0
|
||||
|
||||
# Memory optimization
|
||||
psutil>=5.8.0
|
||||
|
||||
# Note: For TPU v5e-8 environments, TensorFlow should be pre-installed
|
||||
# with TPU support. This requirements file is for additional dependencies.
|
@@ -1,94 +0,0 @@
|
||||
# Simplified TPU training config - compiles faster
model:
n_input_features: 512
n_units: 384 # reduced from 768 to 384
rnn_dropout: 0.2 # reduced dropout
rnn_trainable: true
n_layers: 3 # reduced from 5 to 3 layers
patch_size: 8 # reduced from 14 to 8
patch_stride: 4

input_network:
n_input_layers: 1
input_layer_sizes:
- 512
input_trainable: true
input_layer_dropout: 0.1 # reduced dropout

mode: train
use_amp: true

# TPU distributed training settings
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 4 # more gradient accumulation to compensate for the small batch size

output_dir: trained_models/simple_rnn
checkpoint_dir: trained_models/simple_rnn/checkpoint
init_from_checkpoint: false
save_best_checkpoint: true
save_val_metrics: true

num_training_batches: 1000 # test with 1000 batches first
lr_scheduler_type: cosine
lr_max: 0.003 # slightly lower learning rate
lr_min: 0.0001
lr_decay_steps: 1000
lr_warmup_steps: 100

lr_max_day: 0.003
lr_min_day: 0.0001
lr_decay_steps_day: 1000
lr_warmup_steps_day: 100

beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 5 # less aggressive gradient clipping

batches_per_train_log: 50 # log more frequently
batches_per_val_step: 200
log_individual_day_val_PER: true

# Disable adversarial training for quick testing
adversarial:
enabled: false # adversarial training disabled for now

dataset:
data_transforms:
white_noise_std: 0.5 # reduced data augmentation
constant_offset_std: 0.1
random_walk_std: 0.0
random_walk_axis: -1
static_gain_std: 0.0
random_cut: 1 # less random cutting
smooth_kernel_size: 50 # smaller smoothing kernel
smooth_data: true
smooth_kernel_std: 1

neural_dim: 512
batch_size: 16 # batch size reduced from 32 to 16
n_classes: 41
max_seq_elements: 300 # shorter sequence length
days_per_batch: 2 # fewer days per batch
seed: 1
num_dataloader_workers: 0
loader_shuffle: false
test_percentage: 0.1
dataset_dir: ../data/hdf5_data_final

# Use only a subset of sessions for quick testing
sessions:
- t15.2023.08.11
- t15.2023.08.13
- t15.2023.08.18
- t15.2023.08.20

dataset_probability_val:
- 0
- 1
- 1
- 1
|
781
model_training_nnn_tpu/rnn_model_tf.py
Normal file
781
model_training_nnn_tpu/rnn_model_tf.py
Normal file
@@ -0,0 +1,781 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
import numpy as np
|
||||
|
||||
|
||||
@tf.custom_gradient
|
||||
def gradient_reverse(x, lambd=1.0):
|
||||
"""
|
||||
Gradient Reversal Layer (GRL) for TensorFlow
|
||||
Forward: identity
|
||||
Backward: multiply incoming gradient by -lambda
|
||||
"""
|
||||
def grad(dy):
|
||||
return -lambd * dy, None
|
||||
|
||||
return tf.identity(x), grad
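# Sanity-check sketch (not part of the original file): the GRL is an identity in the
# forward pass and scales the incoming gradient by -lambda in the backward pass.
def _gradient_reverse_demo():
    x = tf.constant([1.0, 2.0, 3.0])
    with tf.GradientTape() as tape:
        tape.watch(x)
        loss = tf.reduce_sum(gradient_reverse(x, 0.5))
    # Without the GRL the gradient would be +1 per element; here it is -0.5.
    return tape.gradient(loss, x)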
|
||||
|
||||
|
||||
class NoiseModel(keras.Model):
|
||||
"""
|
||||
Noise Model: 2-layer GRU that learns to estimate noise in the neural data
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(NoiseModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Day-specific input layers - use Variables for TPU compatibility
|
||||
self.day_layer_activation = layers.Activation('softsign')
|
||||
|
||||
# Initialize day-specific weights and biases as Variables
|
||||
self.day_weights = []
|
||||
self.day_biases = []
|
||||
for i in range(n_days):
|
||||
weight = self.add_weight(
|
||||
name=f'day_weight_{i}',
|
||||
shape=(neural_dim, neural_dim),
|
||||
initializer='identity',
|
||||
trainable=True
|
||||
)
|
||||
bias = self.add_weight(
|
||||
name=f'day_bias_{i}',
|
||||
shape=(neural_dim,),
|
||||
initializer='zeros',
|
||||
trainable=True
|
||||
)
|
||||
self.day_weights.append(weight)
|
||||
self.day_biases.append(bias)
|
||||
|
||||
self.day_layer_dropout = layers.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 2-layer GRU for noise estimation
|
||||
# Use separate GRU layers for better TPU performance
|
||||
self.gru1 = layers.GRU(
|
||||
units=self.input_size,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0, # Avoid recurrent dropout on TPU
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noise_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=self.input_size,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noise_gru2'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, self.input_size),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, self.input_size),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, day_idx, states=None, training=None):
|
||||
"""
|
||||
Forward pass optimized for TPU compilation
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, time_steps, neural_dim]
|
||||
day_idx: Day indices [batch_size]
|
||||
states: Optional initial states
|
||||
training: Training mode flag
|
||||
"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases for efficient gathering
|
||||
all_day_weights = tf.stack(self.day_weights, axis=0) # [n_days, neural_dim, neural_dim]
|
||||
all_day_biases = tf.stack(self.day_biases, axis=0) # [n_days, neural_dim]
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx) # [batch_size, neural_dim, neural_dim]
|
||||
day_biases = tf.gather(all_day_biases, day_idx) # [batch_size, neural_dim]
|
||||
|
||||
# Add time dimension to biases for broadcasting
|
||||
day_biases = tf.expand_dims(day_biases, axis=1) # [batch_size, 1, neural_dim]
|
||||
|
||||
# Apply day-specific transformation using efficient batch matrix multiplication
|
||||
x = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
# Apply input dropout
|
||||
if training and self.input_dropout > 0:
|
||||
x = self.day_layer_dropout(x, training=training)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x = self._apply_patch_processing(x)
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1]) # [batch_size, input_size]
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1]) # [batch_size, input_size]
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# Two-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
return output, [h1_final, h2_final]
|
||||
|
||||
def _apply_patch_processing(self, x):
|
||||
"""Apply patch processing using TensorFlow operations"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
time_steps = tf.shape(x)[1]
|
||||
|
||||
# Add channel dimension for conv1d operations
|
||||
x = tf.expand_dims(x, axis=2) # [batch_size, time_steps, 1, neural_dim]
|
||||
|
||||
# Extract patches using extract_patches
|
||||
# This is equivalent to PyTorch's unfold operation
|
||||
patch_x = tf.image.extract_patches(
|
||||
x,
|
||||
sizes=[1, self.patch_size, 1, 1],
|
||||
strides=[1, self.patch_stride, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='VALID'
|
||||
)
|
||||
|
||||
# Reshape to match expected output
|
||||
new_time_steps = tf.shape(patch_x)[1]
|
||||
patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])
|
||||
|
||||
return patch_x
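# Shape example for the patching above: with patch_size=14 and patch_stride=4, an
# input of [batch, 512 time steps, 512 features] becomes
# [batch, (512 - 14) // 4 + 1 = 125 steps, 14 * 512 = 7168 features],
# matching the input_size computed in __init__.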
|
||||
|
||||
|
||||
class CleanSpeechModel(keras.Model):
|
||||
"""
|
||||
Clean Speech Model: 3-layer GRU that processes denoised signal for speech recognition
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(CleanSpeechModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.n_classes = n_classes
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Day-specific input layers
|
||||
self.day_layer_activation = layers.Activation('softsign')
|
||||
|
||||
# Initialize day-specific weights and biases
|
||||
self.day_weights = []
|
||||
self.day_biases = []
|
||||
for i in range(n_days):
|
||||
weight = self.add_weight(
|
||||
name=f'day_weight_{i}',
|
||||
shape=(neural_dim, neural_dim),
|
||||
initializer='identity',
|
||||
trainable=True
|
||||
)
|
||||
bias = self.add_weight(
|
||||
name=f'day_bias_{i}',
|
||||
shape=(neural_dim,),
|
||||
initializer='zeros',
|
||||
trainable=True
|
||||
)
|
||||
self.day_weights.append(weight)
|
||||
self.day_biases.append(bias)
|
||||
|
||||
self.day_layer_dropout = layers.Dropout(input_dropout)
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 3-layer GRU for clean speech recognition
|
||||
self.gru1 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru2'
|
||||
)
|
||||
|
||||
self.gru3 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='clean_gru3'
|
||||
)
|
||||
|
||||
# Output classification layer
|
||||
self.output_layer = layers.Dense(
|
||||
n_classes,
|
||||
kernel_initializer='glorot_uniform',
|
||||
name='clean_output'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_3 = self.add_weight(
|
||||
name='h0_3',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, day_idx, states=None, return_state=False, training=None):
|
||||
"""Forward pass optimized for TPU compilation"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases for efficient gathering
|
||||
all_day_weights = tf.stack(self.day_weights, axis=0)
|
||||
all_day_biases = tf.stack(self.day_biases, axis=0)
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx)
|
||||
day_biases = tf.gather(all_day_biases, day_idx)
|
||||
day_biases = tf.expand_dims(day_biases, axis=1)
|
||||
|
||||
# Apply day-specific transformation
|
||||
x = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
if training and self.input_dropout > 0:
|
||||
x = self.day_layer_dropout(x, training=training)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x = self._apply_patch_processing(x)
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1])
|
||||
h3_init = tf.tile(self.h0_3, [batch_size, 1])
|
||||
states = [h1_init, h2_init, h3_init]
|
||||
else:
|
||||
h1_init, h2_init, h3_init = states
|
||||
|
||||
# Three-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output2, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
output, h3_final = self.gru3(output2, initial_state=h3_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.output_layer(output)
|
||||
|
||||
if return_state:
|
||||
return logits, [h1_final, h2_final, h3_final]
|
||||
return logits
|
||||
|
||||
def _apply_patch_processing(self, x):
|
||||
"""Apply patch processing using TensorFlow operations"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Add channel dimension
|
||||
x = tf.expand_dims(x, axis=2)
|
||||
|
||||
# Extract patches
|
||||
patch_x = tf.image.extract_patches(
|
||||
x,
|
||||
sizes=[1, self.patch_size, 1, 1],
|
||||
strides=[1, self.patch_stride, 1, 1],
|
||||
rates=[1, 1, 1, 1],
|
||||
padding='VALID'
|
||||
)
|
||||
|
||||
# Reshape
|
||||
new_time_steps = tf.shape(patch_x)[1]
|
||||
patch_x = tf.reshape(patch_x, [batch_size, new_time_steps, -1])
|
||||
|
||||
return patch_x
|
||||
|
||||
|
||||
class NoisySpeechModel(keras.Model):
|
||||
"""
|
||||
Noisy Speech Model: 2-layer GRU that processes noise signal for speech recognition
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(NoisySpeechModel, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_days = n_days
|
||||
self.n_classes = n_classes
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Calculate input size after patching
|
||||
self.input_size = self.neural_dim
|
||||
if self.patch_size > 0:
|
||||
self.input_size *= self.patch_size
|
||||
|
||||
# 2-layer GRU for noisy speech recognition
|
||||
self.gru1 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noisy_gru1'
|
||||
)
|
||||
|
||||
self.gru2 = layers.GRU(
|
||||
units=n_units,
|
||||
return_sequences=True,
|
||||
return_state=True,
|
||||
dropout=self.rnn_dropout,
|
||||
recurrent_dropout=0.0,
|
||||
kernel_initializer='glorot_uniform',
|
||||
recurrent_initializer='orthogonal',
|
||||
name='noisy_gru2'
|
||||
)
|
||||
|
||||
# Output classification layer
|
||||
self.output_layer = layers.Dense(
|
||||
n_classes,
|
||||
kernel_initializer='glorot_uniform',
|
||||
name='noisy_output'
|
||||
)
|
||||
|
||||
# Learnable initial hidden states
|
||||
self.h0_1 = self.add_weight(
|
||||
name='h0_1',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
self.h0_2 = self.add_weight(
|
||||
name='h0_2',
|
||||
shape=(1, n_units),
|
||||
initializer='glorot_uniform',
|
||||
trainable=True
|
||||
)
|
||||
|
||||
def call(self, x, states=None, return_state=False, training=None):
|
||||
"""Forward pass - no day-specific layers for noise processing"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.h0_2, [batch_size, 1])
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# Two-layer GRU forward pass
|
||||
output1, h1_final = self.gru1(x, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.output_layer(output)
|
||||
|
||||
if return_state:
|
||||
return logits, [h1_final, h2_final]
|
||||
return logits
|
||||
|
||||
|
||||
class TripleGRUDecoder(keras.Model):
|
||||
"""
|
||||
Three-model adversarial architecture for neural speech decoding
|
||||
TensorFlow/Keras implementation optimized for TPU v5e-8
|
||||
|
||||
Combines:
|
||||
- NoiseModel: estimates noise in neural data
|
||||
- CleanSpeechModel: processes denoised signal for recognition
|
||||
- NoisySpeechModel: processes noise signal for recognition
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
neural_dim,
|
||||
n_units,
|
||||
n_days,
|
||||
n_classes,
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=0,
|
||||
patch_stride=0,
|
||||
**kwargs):
|
||||
super(TripleGRUDecoder, self).__init__(**kwargs)
|
||||
|
||||
self.neural_dim = neural_dim
|
||||
self.n_units = n_units
|
||||
self.n_classes = n_classes
|
||||
self.n_days = n_days
|
||||
self.rnn_dropout = rnn_dropout
|
||||
self.input_dropout = input_dropout
|
||||
self.patch_size = patch_size
|
||||
self.patch_stride = patch_stride
|
||||
|
||||
# Create the three models
|
||||
self.noise_model = NoiseModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='noise_model'
|
||||
)
|
||||
|
||||
self.clean_speech_model = CleanSpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='clean_speech_model'
|
||||
)
|
||||
|
||||
self.noisy_speech_model = NoisySpeechModel(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=rnn_dropout,
|
||||
input_dropout=input_dropout,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride,
|
||||
name='noisy_speech_model'
|
||||
)
|
||||
|
||||
# Training mode flag
|
||||
self.training_mode = 'full' # 'full', 'inference'
|
||||
|
||||
def _apply_preprocessing(self, x, day_idx):
|
||||
"""Apply preprocessing using clean speech model's day layers"""
|
||||
batch_size = tf.shape(x)[0]
|
||||
|
||||
# Stack all day weights and biases
|
||||
all_day_weights = tf.stack(self.clean_speech_model.day_weights, axis=0)
|
||||
all_day_biases = tf.stack(self.clean_speech_model.day_biases, axis=0)
|
||||
|
||||
# Gather day-specific parameters
|
||||
day_weights = tf.gather(all_day_weights, day_idx)
|
||||
day_biases = tf.gather(all_day_biases, day_idx)
|
||||
day_biases = tf.expand_dims(day_biases, axis=1)
|
||||
|
||||
# Apply transformation
|
||||
x_processed = tf.linalg.matmul(x, day_weights) + day_biases
|
||||
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
|
||||
|
||||
# Apply patch processing if enabled
|
||||
if self.patch_size > 0:
|
||||
x_processed = self.clean_speech_model._apply_patch_processing(x_processed)
|
||||
|
||||
return x_processed
|
||||
|
||||
def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None, training=None):
|
||||
"""Forward pass for CleanSpeechModel with already processed input"""
|
||||
batch_size = tf.shape(x_processed)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.clean_speech_model.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.clean_speech_model.h0_2, [batch_size, 1])
|
||||
h3_init = tf.tile(self.clean_speech_model.h0_3, [batch_size, 1])
|
||||
states = [h1_init, h2_init, h3_init]
|
||||
else:
|
||||
h1_init, h2_init, h3_init = states
|
||||
|
||||
# GRU forward pass (skip preprocessing since input is already processed)
|
||||
output1, h1_final = self.clean_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
|
||||
output2, h2_final = self.clean_speech_model.gru2(output1, initial_state=h2_init, training=training)
|
||||
output, h3_final = self.clean_speech_model.gru3(output2, initial_state=h3_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.clean_speech_model.output_layer(output)
|
||||
return logits
|
||||
|
||||
def _noisy_forward_with_processed_input(self, x_processed, states=None, training=None):
|
||||
"""Forward pass for NoisySpeechModel with already processed input"""
|
||||
batch_size = tf.shape(x_processed)[0]
|
||||
|
||||
# Initialize hidden states if not provided
|
||||
if states is None:
|
||||
h1_init = tf.tile(self.noisy_speech_model.h0_1, [batch_size, 1])
|
||||
h2_init = tf.tile(self.noisy_speech_model.h0_2, [batch_size, 1])
|
||||
states = [h1_init, h2_init]
|
||||
else:
|
||||
h1_init, h2_init = states
|
||||
|
||||
# GRU forward pass
|
||||
output1, h1_final = self.noisy_speech_model.gru1(x_processed, initial_state=h1_init, training=training)
|
||||
output, h2_final = self.noisy_speech_model.gru2(output1, initial_state=h2_init, training=training)
|
||||
|
||||
# Classification
|
||||
logits = self.noisy_speech_model.output_layer(output)
|
||||
return logits
|
||||
|
||||
def call(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda=0.0, training=None):
|
||||
"""
|
||||
Three-model adversarial forward pass optimized for TPU compilation
|
||||
|
||||
Args:
|
||||
x: Input tensor [batch_size, time_steps, neural_dim]
|
||||
day_idx: Day indices [batch_size]
|
||||
states: Dictionary with 'noise', 'clean', 'noisy' states or None
|
||||
return_state: Whether to return hidden states
|
||||
mode: 'full' for training (all three models), 'inference' for inference
|
||||
grl_lambda: Gradient reversal strength for adversarial training
|
||||
training: Training mode flag
|
||||
"""
|
||||
|
||||
if mode == 'full':
|
||||
# Training mode: run all three models
|
||||
|
||||
# 1. Noise model estimates noise in the data
|
||||
noise_output, noise_hidden = self.noise_model(
|
||||
x, day_idx,
|
||||
states['noise'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 2. Apply preprocessing to get x in the same space as noise_output
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
|
||||
# 3. Clean speech model processes denoised signal
|
||||
denoised_input = x_processed - noise_output # Residual connection
|
||||
clean_logits = self._clean_forward_with_processed_input(
|
||||
denoised_input, day_idx,
|
||||
states['clean'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 4. Noisy speech model processes noise signal
|
||||
# Apply Gradient Reversal Layer if specified
|
||||
if grl_lambda > 0.0:
|
||||
noisy_input = gradient_reverse(noise_output, grl_lambda)
|
||||
else:
|
||||
noisy_input = noise_output
|
||||
|
||||
noisy_logits = self._noisy_forward_with_processed_input(
|
||||
noisy_input,
|
||||
states['noisy'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# Return results
|
||||
if return_state:
|
||||
return (clean_logits, noisy_logits, noise_output), noise_hidden
|
||||
return clean_logits, noisy_logits, noise_output
|
||||
|
||||
elif mode == 'inference':
|
||||
# Inference mode: only noise model + clean speech model
|
||||
|
||||
# 1. Estimate noise
|
||||
noise_output, noise_hidden = self.noise_model(
|
||||
x, day_idx,
|
||||
states['noise'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# 2. Apply preprocessing for residual connection
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
denoised_input = x_processed - noise_output
|
||||
clean_logits = self._clean_forward_with_processed_input(
|
||||
denoised_input, day_idx,
|
||||
states['clean'] if states else None,
|
||||
training=training
|
||||
)
|
||||
|
||||
# Return results
|
||||
if return_state:
|
||||
return clean_logits, noise_hidden
|
||||
return clean_logits
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the operating mode"""
|
||||
self.training_mode = mode
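# Illustrative sketch (not part of the original file): exercising both modes with small
# placeholder dimensions.
#
#   model = TripleGRUDecoder(neural_dim=32, n_units=16, n_days=4, n_classes=41,
#                            patch_size=4, patch_stride=2)
#   x = tf.random.normal([2, 20, 32])          # [batch, time, neural_dim]
#   day_idx = tf.constant([0, 3])
#   clean_logits, noisy_logits, noise = model(x, day_idx, mode='full',
#                                             grl_lambda=0.5, training=True)
#   logits = model(x, day_idx, mode='inference', training=False)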
|
||||
|
||||
|
||||
# Custom CTC Loss for TensorFlow TPU
|
||||
class CTCLoss(keras.losses.Loss):
|
||||
"""
|
||||
Custom CTC Loss optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self, blank_index=0, reduction='none', **kwargs):
|
||||
super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
|
||||
self.blank_index = blank_index
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""
|
||||
Args:
|
||||
y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'
|
||||
y_pred: Logits tensor [batch_size, time_steps, num_classes]
|
||||
"""
|
||||
labels = y_true['labels']
|
||||
input_lengths = y_true['input_lengths']
|
||||
label_lengths = y_true['label_lengths']
|
||||
|
||||
# Convert logits to log probabilities
|
||||
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
||||
|
||||
# Transpose for CTC: [time_steps, batch_size, num_classes]
|
||||
log_probs = tf.transpose(log_probs, [1, 0, 2])
|
||||
|
||||
# Compute CTC loss
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=labels,
|
||||
logits=log_probs,
|
||||
label_length=label_lengths,
|
||||
logit_length=input_lengths,
|
||||
blank_index=self.blank_index,
|
||||
logits_time_major=True
|
||||
)
|
||||
|
||||
return loss
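# Illustrative sketch (not part of the original file): the trainer calls this loss with
# a dict of zero-padded dense labels plus per-trial lengths, as in evaluate_model_tf.py.
def _ctc_loss_demo():
    logits = tf.random.normal([2, 20, 41])  # [batch, time, n_classes]
    targets = {
        'labels': tf.constant([[5, 7, 9, 0], [3, 4, 0, 0]], dtype=tf.int32),
        'input_lengths': tf.constant([20, 18], dtype=tf.int32),
        'label_lengths': tf.constant([3, 2], dtype=tf.int32),
    }
    loss_fn = CTCLoss(blank_index=0)
    return loss_fn(targets, logits)  # per-trial CTC loss, shape [2] with reduction='none'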
|
||||
|
||||
|
||||
# TPU Strategy Helper Functions
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    try:
        # Initialize TPU
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)

        # Create TPU strategy
        strategy = tf.distribute.TPUStrategy(resolver)
        print(f"TPU initialized successfully. Number of replicas: {strategy.num_replicas_in_sync}")
        return strategy

    except ValueError as e:
        print(f"Failed to initialize TPU: {e}")
        print("Falling back to default strategy")
        return tf.distribute.get_strategy()


def build_model_for_tpu(config):
    """
    Build TripleGRUDecoder model optimized for TPU v5e-8

    Args:
        config: Configuration dictionary containing model parameters

    Returns:
        TripleGRUDecoder Keras model (uncompiled), ready for TPU training
    """
    model = TripleGRUDecoder(
        neural_dim=config['model']['n_input_features'],
        n_units=config['model']['n_units'],
        n_days=len(config['dataset']['sessions']),
        n_classes=config['dataset']['n_classes'],
        rnn_dropout=config['model']['rnn_dropout'],
        input_dropout=config['model']['input_network']['input_layer_dropout'],
        patch_size=config['model']['patch_size'],
        patch_stride=config['model']['patch_stride']
    )

    return model


# Mixed Precision Configuration for TPU v5e-8
def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
    print(f"Mixed precision policy set to: {policy.name}")
    return policy
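

# Putting the helpers together (illustrative sketch; the real entry point lives in
# train_model_tf.py, which loads `config` from rnn_args.yaml, so the name `config` below
# is only a placeholder for the keys build_model_for_tpu() actually reads):
#
#   configure_mixed_precision()                  # bfloat16 compute, float32 variables
#   strategy = create_tpu_strategy()             # TPUStrategy, or default strategy fallback
#   with strategy.scope():                       # variables must be created inside the scope
#       model = build_model_for_tpu(config)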
150
model_training_nnn_tpu/setup_tensorflow_tpu.sh
Normal file
@@ -0,0 +1,150 @@
#!/bin/bash
# Setup script for TensorFlow Brain-to-Text training on TPU v5e-8
#
# Usage: ./setup_tensorflow_tpu.sh
#
# This script prepares the environment for training the brain-to-text model
# using TensorFlow on TPU v5e-8 hardware.

set -e  # Exit on any error

echo "=== TensorFlow TPU v5e-8 Setup Script ==="
echo "Setting up environment for brain-to-text training..."

# Check if we're in a TPU environment
if [[ -z "${TPU_NAME}" ]] && [[ -z "${COLAB_TPU_ADDR}" ]]; then
    echo "Warning: TPU environment variables not detected."
    echo "Make sure you're running on a TPU v5e-8 instance."
fi

# Create conda environment for TensorFlow TPU
ENV_NAME="b2txt_tf"
echo "Creating conda environment: ${ENV_NAME}"

# Make `conda activate` work inside a non-interactive script
source "$(conda info --base)/etc/profile.d/conda.sh"

if conda env list | grep -q "^${ENV_NAME} "; then
    echo "Environment ${ENV_NAME} already exists. Activating..."
    conda activate ${ENV_NAME}
else
    echo "Creating new environment..."
    conda create -n ${ENV_NAME} python=3.10 -y
    conda activate ${ENV_NAME}
fi

# Install TensorFlow with TPU support (quoted so '>' is not treated as a shell redirect)
echo "Installing TensorFlow with TPU support..."
pip install "tensorflow[and-cuda]>=2.15.0"

# Install additional requirements
echo "Installing additional requirements..."
pip install -r requirements_tf.txt

# Set up TPU environment variables
echo "Configuring TPU environment variables..."

# Create or update .bashrc with TPU optimizations
cat >> ~/.bashrc << 'EOF'

# TPU v5e-8 Environment Variables
export TPU_ML_PLATFORM="TensorFlow"
export XLA_USE_BF16=1
export TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
export TPU_MEGACORE=1
export LIBTPU_INIT_ARGS="--xla_tpu_spmd_threshold_for_allgather_cse=10000"

# Disable TensorFlow warnings for cleaner output
export TF_CPP_MIN_LOG_LEVEL=2

# Memory optimizations
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_THREAD_MODE=gpu_private

EOF

# Source the updated .bashrc
source ~/.bashrc

# Test TPU connectivity
echo "Testing TPU connectivity..."
python3 << 'EOF'
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("TPU cluster initialized successfully!")
    print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
    print(f"TPU devices: {tf.config.list_logical_devices('TPU')}")
except Exception as e:
    print(f"TPU initialization failed: {e}")
    print("You may be running on CPU/GPU instead of TPU")

# Test mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision policy: {policy.name}")
EOF

# Verify data directory exists
DATA_DIR="../data/hdf5_data_final"
if [ -d "$DATA_DIR" ]; then
    echo "Data directory found: $DATA_DIR"
    # Count available sessions
    SESSION_COUNT=$(ls -d $DATA_DIR/t*.20* 2>/dev/null | wc -l)
    echo "Available sessions: $SESSION_COUNT"
else
    echo "Warning: Data directory not found at $DATA_DIR"
    echo "Please ensure the dataset is available before training."
fi

# Create output directories
echo "Creating output directories..."
mkdir -p trained_models/tensorflow_tpu
mkdir -p logs/tensorflow_tpu
mkdir -p eval_output

# Make scripts executable
echo "Setting script permissions..."
chmod +x train_model_tf.py
chmod +x evaluate_model_tf.py

# Display system information
echo "=== System Information ==="
echo "Python version: $(python --version)"
echo "Conda environment: $CONDA_DEFAULT_ENV"
echo "Available memory: $(free -h | grep '^Mem:' | awk '{print $7}')"
echo "CPU cores: $(nproc)"

# Check for GPU/TPU
echo "=== Hardware Information ==="
if nvidia-smi &> /dev/null; then
    echo "NVIDIA GPUs detected:"
    nvidia-smi --list-gpus
else
    echo "No NVIDIA GPUs detected"
fi

if [[ -n "${TPU_NAME}" ]]; then
    echo "TPU Name: $TPU_NAME"
elif [[ -n "${COLAB_TPU_ADDR}" ]]; then
    echo "Colab TPU Address: $COLAB_TPU_ADDR"
else
    echo "No TPU environment variables detected"
fi

echo ""
echo "=== Setup Complete ==="
echo "Environment '$ENV_NAME' is ready for TensorFlow TPU training."
echo ""
echo "To activate the environment:"
echo "  conda activate $ENV_NAME"
echo ""
echo "To start training:"
echo "  python train_model_tf.py --config_path rnn_args.yaml"
echo ""
echo "To run evaluation:"
echo "  python evaluate_model_tf.py --model_path path/to/checkpoint --config_path rnn_args.yaml"
echo ""
echo "For more options, use --help with any script."
560
model_training_nnn_tpu/test_tensorflow_implementation.py
Normal file
@@ -0,0 +1,560 @@
#!/usr/bin/env python3
"""
Test Script for TensorFlow Brain-to-Text Implementation
Validates model architecture, data pipeline, and training functionality

Usage:
    python test_tensorflow_implementation.py [--full_test]

This script runs comprehensive tests to ensure the TensorFlow implementation
is working correctly before starting full training runs.
"""

import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from omegaconf import OmegaConf
import tempfile
import shutil

# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model_tf import (
    TripleGRUDecoder,
    NoiseModel,
    CleanSpeechModel,
    NoisySpeechModel,
    CTCLoss,
    create_tpu_strategy,
    configure_mixed_precision
)
from dataset_tf import BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices
from trainer_tf import BrainToTextDecoderTrainerTF


class TensorFlowImplementationTester:
    """Comprehensive tester for TensorFlow brain-to-text implementation"""

    def __init__(self, use_tpu: bool = False, verbose: bool = True):
        """Initialize tester"""
        self.use_tpu = use_tpu
        self.verbose = verbose
        self.passed_tests = 0
        self.total_tests = 0

        # Create test configuration
        self.config = self._create_test_config()

        # Initialize strategy
        if use_tpu:
            self.strategy = create_tpu_strategy()
            if self.verbose:
                print(f"Using TPU strategy with {self.strategy.num_replicas_in_sync} cores")
        else:
            self.strategy = tf.distribute.get_strategy()
            if self.verbose:
                print("Using default strategy (CPU/GPU)")

    def _create_test_config(self):
        """Create minimal test configuration"""
        return {
            'model': {
                'n_input_features': 512,
                'n_units': 64,  # Smaller for testing
                'rnn_dropout': 0.1,
                'patch_size': 4,
                'patch_stride': 2,
                'input_network': {
                    'input_layer_dropout': 0.1
                }
            },
            'dataset': {
                'sessions': ['test_session_1', 'test_session_2'],
                'n_classes': 41,
                'batch_size': 4,
                'days_per_batch': 2,
                'seed': 42,
                'data_transforms': {
                    'white_noise_std': 0.1,
                    'constant_offset_std': 0.05,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 2,
                    'smooth_data': True,
                    'smooth_kernel_std': 1.0,
                    'smooth_kernel_size': 50
                }
            },
            'num_training_batches': 10,
            'lr_max': 0.001,
            'lr_min': 0.0001,
            'lr_decay_steps': 100,
            'lr_warmup_steps': 5,
            'lr_scheduler_type': 'cosine',
            'beta0': 0.9,
            'beta1': 0.999,
            'epsilon': 1e-7,
            'weight_decay': 0.001,
            'seed': 42,
            'grad_norm_clip_value': 1.0,
            'batches_per_train_log': 2,
            'batches_per_val_step': 5,
            'output_dir': tempfile.mkdtemp(),
            'checkpoint_dir': tempfile.mkdtemp(),
            'mode': 'train',
            'use_amp': False,  # Disable for testing
            'adversarial': {
                'enabled': True,
                'grl_lambda': 0.5,
                'noisy_loss_weight': 0.2,
                'noise_l2_weight': 0.001,
                'warmup_steps': 2
            }
        }

    def log_test(self, test_name: str, passed: bool, details: str = ""):
        """Log test result"""
        self.total_tests += 1
        if passed:
            self.passed_tests += 1
            status = "PASS"
        else:
            status = "FAIL"

        if self.verbose:
            print(f"[{status}] {test_name}")
            if details:
                print(f"    {details}")

    def test_model_architecture(self):
        """Test individual model components"""
        print("\n=== Testing Model Architecture ===")

        with self.strategy.scope():
            # Test NoiseModel
            try:
                noise_model = NoiseModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                # Test forward pass
                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                output, states = noise_model(x, day_idx, training=False)

                expected_time_steps = (time_steps - 4) // 2 + 1
                expected_features = 512 * 4

                assert output.shape == (batch_size, expected_time_steps, expected_features)
                assert len(states) == 2  # Two GRU layers

                self.log_test("NoiseModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("NoiseModel forward pass", False, str(e))

            # Test CleanSpeechModel
            try:
                clean_model = CleanSpeechModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                output = clean_model(x, day_idx, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("CleanSpeechModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("CleanSpeechModel forward pass", False, str(e))

            # Test NoisySpeechModel
            try:
                noisy_model = NoisySpeechModel(
                    neural_dim=expected_features,  # Takes processed input
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1
                )

                # Use processed input (same as noise model output)
                processed_input = tf.random.normal((batch_size, expected_time_steps, expected_features))
                output = noisy_model(processed_input, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("NoisySpeechModel forward pass", True,
                              f"Output shape: {output.shape}")

            except Exception as e:
                self.log_test("NoisySpeechModel forward pass", False, str(e))

    def test_triple_gru_decoder(self):
        """Test the complete TripleGRUDecoder"""
        print("\n=== Testing TripleGRUDecoder ===")

        with self.strategy.scope():
            try:
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                # Test inference mode
                clean_logits = model(x, day_idx, mode='inference', training=False)
                expected_time_steps = (time_steps - 4) // 2 + 1
                assert clean_logits.shape == (batch_size, expected_time_steps, 41)

                self.log_test("TripleGRUDecoder inference mode", True,
                              f"Output shape: {clean_logits.shape}")

                # Test full mode (adversarial training)
                clean_logits, noisy_logits, noise_output = model(
                    x, day_idx, mode='full', grl_lambda=0.5, training=True
                )

                assert clean_logits.shape == (batch_size, expected_time_steps, 41)
                assert noisy_logits.shape == (batch_size, expected_time_steps, 41)
                assert noise_output.shape[0] == batch_size

                self.log_test("TripleGRUDecoder full mode", True,
                              f"Clean: {clean_logits.shape}, Noisy: {noisy_logits.shape}")

            except Exception as e:
                self.log_test("TripleGRUDecoder", False, str(e))

    def test_ctc_loss(self):
        """Test CTC loss function"""
        print("\n=== Testing CTC Loss ===")

        try:
            ctc_loss = CTCLoss(blank_index=0, reduction='none')

            batch_size = 2
            time_steps = 10
            n_classes = 41

            # Create test data
            logits = tf.random.normal((batch_size, time_steps, n_classes))
            labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32)
            input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
            label_lengths = tf.constant([3, 2], dtype=tf.int32)

            loss_input = {
                'labels': labels,
                'input_lengths': input_lengths,
                'label_lengths': label_lengths
            }

            loss = ctc_loss(loss_input, logits)
            assert loss.shape == (batch_size,)
            assert tf.reduce_all(tf.math.is_finite(loss))

            self.log_test("CTC Loss computation", True,
                          f"Loss shape: {loss.shape}, values finite: {tf.reduce_all(tf.math.is_finite(loss))}")

        except Exception as e:
            self.log_test("CTC Loss computation", False, str(e))

    def test_data_augmentation(self):
        """Test data augmentation functions"""
        print("\n=== Testing Data Augmentation ===")

        try:
            batch_size = 2
            time_steps = 100
            features = 512

            x = tf.random.normal((batch_size, time_steps, features))
            n_time_steps = tf.constant([time_steps, time_steps], dtype=tf.int32)

            # Test Gaussian smoothing
            smoothed = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0)
            assert smoothed.shape == x.shape

            self.log_test("Gaussian smoothing", True,
                          f"Input: {x.shape}, Output: {smoothed.shape}")

            # Test full transform pipeline
            transform_args = self.config['dataset']['data_transforms']

            transformed_x, transformed_steps = DataAugmentationTF.transform_data(
                x, n_time_steps, transform_args, training=True
            )

            # Check that shapes are reasonable
            assert transformed_x.shape[0] == batch_size
            assert transformed_x.shape[2] == features
            assert len(transformed_steps) == batch_size

            self.log_test("Data augmentation pipeline", True,
                          f"Original: {x.shape}, Transformed: {transformed_x.shape}")

        except Exception as e:
            self.log_test("Data augmentation", False, str(e))

    def test_gradient_reversal(self):
        """Test gradient reversal layer"""
        print("\n=== Testing Gradient Reversal ===")

        try:
            from rnn_model_tf import gradient_reverse

            x = tf.random.normal((2, 10, 64))

            # Test forward pass (should be identity)
            y = gradient_reverse(x, lambd=0.5)
            assert tf.reduce_all(tf.equal(x, y))

            # Test gradient reversal in context
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = gradient_reverse(x, lambd=0.5)
                loss = tf.reduce_sum(y)

            grad = tape.gradient(loss, x)
            expected_grad = -0.5 * tf.ones_like(x)

            # Check if gradients are reversed and scaled
            assert tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6)

            self.log_test("Gradient reversal layer", True,
                          "Forward pass identity, gradients properly reversed")

        except Exception as e:
            self.log_test("Gradient reversal layer", False, str(e))

    def test_mixed_precision(self):
        """Test mixed precision configuration"""
        print("\n=== Testing Mixed Precision ===")

        try:
            # Configure mixed precision
            configure_mixed_precision()
            policy = tf.keras.mixed_precision.global_policy()

            assert policy.name == 'mixed_bfloat16'

            # Test model with mixed precision
            with self.strategy.scope():
                model = TripleGRUDecoder(
                    neural_dim=512, n_units=32, n_days=2, n_classes=41
                )

                x = tf.random.normal((1, 10, 512))
                day_idx = tf.constant([0], dtype=tf.int32)

                logits = model(x, day_idx, mode='inference', training=False)

                # Check that compute dtype is bfloat16 but variables are float32
                assert policy.compute_dtype == 'bfloat16'
                assert policy.variable_dtype == 'float32'

            self.log_test("Mixed precision configuration", True,
                          f"Policy: {policy.name}")

        except Exception as e:
            self.log_test("Mixed precision configuration", False, str(e))

    def test_training_step(self):
        """Test a complete training step"""
        print("\n=== Testing Training Step ===")

        try:
            with self.strategy.scope():
                # Create model
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=32,
                    n_days=2,
                    n_classes=41,
                    patch_size=4,
                    patch_stride=2
                )

                # Create optimizer and loss
                optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
                ctc_loss = CTCLoss(blank_index=0, reduction='none')

                # Create dummy batch
                batch_size = 2
                time_steps = 20

                batch = {
                    'input_features': tf.random.normal((batch_size, time_steps, 512)),
                    'seq_class_ids': tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32),
                    'n_time_steps': tf.constant([time_steps, time_steps], dtype=tf.int32),
                    'phone_seq_lens': tf.constant([3, 2], dtype=tf.int32),
                    'day_indices': tf.constant([0, 1], dtype=tf.int32)
                }

                # Training step
                with tf.GradientTape() as tape:
                    # Apply minimal transforms
                    features = batch['input_features']
                    n_time_steps = batch['n_time_steps']

                    # Calculate adjusted lengths
                    adjusted_lens = tf.cast(
                        (tf.cast(n_time_steps, tf.float32) - 4) / 2 + 1, tf.int32
                    )

                    # Forward pass
                    clean_logits = model(features, batch['day_indices'],
                                         mode='inference', training=True)

                    # Loss
                    loss_input = {
                        'labels': batch['seq_class_ids'],
                        'input_lengths': adjusted_lens,
                        'label_lengths': batch['phone_seq_lens']
                    }
                    loss = ctc_loss(loss_input, clean_logits)
                    loss = tf.reduce_mean(loss)

                # Gradients
                gradients = tape.gradient(loss, model.trainable_variables)

                # Check gradients exist and are finite
                grad_finite = all(tf.reduce_all(tf.math.is_finite(g)) for g in gradients if g is not None)

                # Apply gradients
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                self.log_test("Training step", grad_finite and tf.math.is_finite(loss),
                              f"Loss: {float(loss):.4f}, Gradients finite: {grad_finite}")

        except Exception as e:
            self.log_test("Training step", False, str(e))

    def test_full_training_loop(self):
        """Test a minimal training loop"""
        print("\n=== Testing Full Training Loop ===")

        if not hasattr(self, '_full_test') or not self._full_test:
            self.log_test("Full training loop", True, "Skipped (use --full_test to enable)")
            return

        try:
            # Create temporary directories
            temp_output = tempfile.mkdtemp()
            temp_checkpoint = tempfile.mkdtemp()

            # Minimal config for quick test
            config = self.config.copy()
            config['output_dir'] = temp_output
            config['checkpoint_dir'] = temp_checkpoint
            config['num_training_batches'] = 5
            config['batches_per_val_step'] = 3

            # Would need actual data files for this test
            # For now, just prepare the config; trainer construction stays disabled
            # trainer = BrainToTextDecoderTrainerTF(config)

            self.log_test("Full training loop", True,
                          "Config prepared (trainer initialization skipped; requires real data files)")

            # Cleanup
            shutil.rmtree(temp_output, ignore_errors=True)
            shutil.rmtree(temp_checkpoint, ignore_errors=True)

        except Exception as e:
            self.log_test("Full training loop", False, str(e))

    def run_all_tests(self, full_test: bool = False):
        """Run all tests"""
        self._full_test = full_test

        print("TensorFlow Brain-to-Text Implementation Test Suite")
        print("=" * 60)

        if self.use_tpu:
            print("Running tests on TPU")
        else:
            print("Running tests on CPU/GPU")

        # Run tests
        self.test_model_architecture()
        self.test_triple_gru_decoder()
        self.test_ctc_loss()
        self.test_data_augmentation()
        self.test_gradient_reversal()
        self.test_mixed_precision()
        self.test_training_step()
        self.test_full_training_loop()

        # Summary
        print("\n" + "=" * 60)
        print(f"TEST SUMMARY: {self.passed_tests}/{self.total_tests} tests passed")

        if self.passed_tests == self.total_tests:
            print("🎉 All tests passed! TensorFlow implementation is ready.")
            return True
        else:
            print("❌ Some tests failed. Please review the implementation.")
            return False


def main():
    """Main test function"""
    parser = argparse.ArgumentParser(description='Test TensorFlow Brain-to-Text Implementation')
    parser.add_argument('--use_tpu', action='store_true', help='Test on TPU if available')
    parser.add_argument('--full_test', action='store_true', help='Run full training loop test')
    parser.add_argument('--quiet', action='store_true', help='Reduce output verbosity')

    args = parser.parse_args()

    # Set TensorFlow logging level
    if args.quiet:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        tf.get_logger().setLevel('ERROR')
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Run tests
    tester = TensorFlowImplementationTester(
        use_tpu=args.use_tpu,
        verbose=not args.quiet
    )

    success = tester.run_all_tests(full_test=args.full_test)

    # Cleanup temporary directories
    shutil.rmtree(tester.config['output_dir'], ignore_errors=True)
    shutil.rmtree(tester.config['checkpoint_dir'], ignore_errors=True)

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
265
model_training_nnn_tpu/train_model_tf.py
Normal file
@@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
TensorFlow Training Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
It provides the same functionality as the PyTorch version but with TensorFlow
operations optimized for TPU performance.

Usage:
    python train_model_tf.py --config_path rnn_args.yaml

Requirements:
- TensorFlow >= 2.15.0
- TPU v5e-8 environment
- Access to brain-to-text HDF5 dataset
"""

import argparse
import os
import sys
import logging
from omegaconf import OmegaConf

# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF


def setup_tpu_environment():
    """Setup TPU environment variables for optimal performance"""
    # TPU v5e-8 optimizations
    os.environ.setdefault('TPU_ML_PLATFORM', 'TensorFlow')  # Match the TensorFlow TPU runtime (same value as setup_tensorflow_tpu.sh)
    os.environ.setdefault('XLA_USE_BF16', '1')  # Enable bfloat16 for memory efficiency
    os.environ.setdefault('TF_XLA_FLAGS', '--tf_xla_auto_jit=2')  # Enable XLA JIT compilation

    # TPU memory optimizations
    os.environ.setdefault('TPU_MEGACORE', '1')  # Enable megacore mode for larger models
    os.environ.setdefault('LIBTPU_INIT_ARGS', '--xla_tpu_spmd_threshold_for_allgather_cse=10000')

    # Disable warnings for cleaner output
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    print("TPU environment configured for v5e-8 optimizations")


def validate_config(config):
    """Validate configuration for TensorFlow TPU training"""
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.sessions',
        'dataset.n_classes',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir'
    ]

    for field in required_fields:
        keys = field.split('.')
        value = config
        try:
            for key in keys:
                value = value[key]
        except KeyError:
            raise ValueError(f"Missing required configuration field: {field}")

    # TPU-specific validations
    if config.get('use_tpu', True):
        if config['dataset']['batch_size'] < 8:
            logging.warning("Small batch size may not utilize TPU efficiently. Consider batch_size >= 32")

        if not config.get('use_amp', True):
            logging.warning("Mixed precision disabled. Consider enabling for better TPU performance")

    # Dataset validation
    dataset_dir = config['dataset']['dataset_dir']
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # Check if at least one session file exists
    session_found = False
    for session in config['dataset']['sessions']:
        train_path = os.path.join(dataset_dir, session, 'data_train.hdf5')
        if os.path.exists(train_path):
            session_found = True
            break

    if not session_found:
        raise FileNotFoundError("No valid session data files found in dataset directory")

    print("Configuration validation passed")


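# Illustrative sketch of the keys validate_config() checks (values below are placeholders
# only; the real settings live in rnn_args.yaml and include many more fields consumed by
# the trainer):
#
#   model:
#     n_input_features: 512
#     n_units: ...
#   dataset:
#     dataset_dir: ../data/hdf5_data_final
#     sessions: [...]
#     n_classes: 41
#     batch_size: 32
#   num_training_batches: ...
#   output_dir: trained_models/tensorflow_tpu
#   checkpoint_dir: trained_models/tensorflow_tpu/checkpoint

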
def main():
    """Main training function"""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration YAML file'
    )

    parser.add_argument(
        '--output_dir',
        default=None,
        help='Override output directory from config'
    )

    parser.add_argument(
        '--checkpoint_dir',
        default=None,
        help='Override checkpoint directory from config'
    )

    parser.add_argument(
        '--resume_from',
        default=None,
        help='Path to checkpoint to resume training from'
    )

    parser.add_argument(
        '--num_batches',
        type=int,
        default=None,
        help='Override number of training batches'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size'
    )

    parser.add_argument(
        '--mixed_precision',
        action='store_true',
        default=None,
        help='Enable mixed precision training (bfloat16)'
    )

    parser.add_argument(
        '--disable_mixed_precision',
        action='store_true',
        help='Disable mixed precision training'
    )

    parser.add_argument(
        '--validate_only',
        action='store_true',
        help='Only run validation, do not train'
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )

    args = parser.parse_args()

    # Setup logging
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Setup TPU environment
    setup_tpu_environment()

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    # Apply command line overrides
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    if args.num_batches:
        config.num_training_batches = args.num_batches
    if args.batch_size:
        config.dataset.batch_size = args.batch_size
    if args.mixed_precision:
        config.use_amp = True
    if args.disable_mixed_precision:
        config.use_amp = False

    # Validate configuration
    validate_config(config)

    try:
        # Initialize trainer
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        # Load checkpoint if specified
        if args.resume_from:
            if os.path.exists(args.resume_from + '.weights.h5'):
                trainer.load_checkpoint(args.resume_from)
                print(f"Resumed training from checkpoint: {args.resume_from}")
            else:
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")

        if args.validate_only:
            print("Running validation only...")
            # Create validation dataset
            from dataset_tf import create_input_fn
            val_dataset = create_input_fn(
                trainer.val_dataset_tf,
                trainer.args['dataset']['data_transforms'],
                training=False
            )
            val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)

            # Run validation
            val_metrics = trainer._validate(val_dist_dataset)

            print("Validation Results:")
            print(f"  Average Loss: {val_metrics['avg_loss']:.4f}")
            print(f"  Average PER: {val_metrics['avg_per']:.4f}")
            print(f"  Total Edit Distance: {val_metrics['total_edit_distance']}")
            print(f"  Total Sequence Length: {val_metrics['total_seq_length']}")

        else:
            # Start training
            print("Starting training...")
            train_stats = trainer.train()

            print("\nTraining completed successfully!")
            print(f"Best validation PER: {trainer.best_val_per:.5f}")
            print(f"Final training loss: {train_stats['train_losses'][-1]:.4f}")
            print(f"Final validation loss: {train_stats['val_losses'][-1]:.4f}")
            print(f"Total training batches: {len(train_stats['train_losses'])}")

            # Save final training statistics
            import pickle
            stats_path = os.path.join(config.output_dir, 'training_stats.pkl')
            with open(stats_path, 'wb') as f:
                pickle.dump(train_stats, f)
            print(f"Training statistics saved to: {stats_path}")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\nTraining failed with error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()