class AMPModel(nn.Module):
    """Small fully connected classifier used to exercise AMP training.

    The network is ``Linear -> ReLU -> Dropout`` twice, followed by a final
    ``Linear`` projection to the class logits.  Inputs are flattened to
    ``(batch, -1)`` before entering the network.

    Args:
        input_size: Number of features per example after flattening.
        hidden_size: Width of the first hidden layer; the second hidden
            layer is ``hidden_size // 2`` wide.
        num_classes: Number of output logits.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_size=784, hidden_size=512, num_classes=10,
                 dropout=0.2):
        super().__init__()

        # Layer order and the attribute name ``network`` are kept as-is so
        # state-dict keys ("network.0.weight", ...) remain loadable.
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes)
        )

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and return the class logits."""
        # ``reshape`` behaves like ``view`` for contiguous tensors but also
        # accepts non-contiguous ones (e.g. after ``permute``), where
        # ``view`` raises a RuntimeError.
        x = x.reshape(x.size(0), -1)
        return self.network(x)
def get_mnist_loaders(batch_size=64):
    """Build the MNIST train and test data loaders.

    Both splits are stored under ``./mnist_data`` (downloaded on demand) and
    share the same preprocessing: conversion to a tensor followed by a
    fixed normalization.  Only the training loader shuffles its data.

    Args:
        batch_size: Number of examples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Normalization constants are presumably the MNIST mean/std - confirm.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    loaders = []
    # The training split (True) is built first, then the test split (False);
    # the same flag decides whether the loader shuffles.
    for is_train in (True, False):
        split_data = torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=preprocessing
        )
        loaders.append(
            torch.utils.data.DataLoader(
                split_data,
                batch_size=batch_size,
                shuffle=is_train,
                num_workers=0
            )
        )

    train_loader, test_loader = loaders
    return train_loader, test_loader
def test_with_amp(trainer):
    """Evaluate ``trainer`` on (at most) the first 100 MNIST test batches.

    Args:
        trainer: Object exposing ``evaluate_step(data, target)`` that returns
            a ``(loss, accuracy)`` pair of Python numbers for one batch.

    Returns:
        ``(avg_loss, avg_acc)`` where ``avg_acc`` is a percentage.
    """
    print("\n🧪 开始AMP测试...")

    xla_device = xm.xla_device()
    loaders = get_mnist_loaders(batch_size=64)
    eval_loader = pl.MpDeviceLoader(loaders[1], xla_device)

    batch_limit = 100
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    began = time.time()

    for step, batch in enumerate(eval_loader):
        if step >= batch_limit:
            break

        inputs, labels = batch
        step_loss, step_acc = trainer.evaluate_step(inputs, labels)

        loss_sum += step_loss
        acc_sum += step_acc
        seen += 1

        # Flush the pending XLA graph every 20 batches.
        if not step % 20:
            xm.mark_step()

    # Final flush, then block until the device has finished all work so the
    # elapsed time below covers the actual computation.
    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - began
    avg_loss = loss_sum / seen
    avg_acc = acc_sum / seen * 100

    print(f"✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试损失: {avg_loss:.4f}")
    print(f"🎯 测试准确率: {avg_acc:.2f}%")

    return avg_loss, avg_acc
尝试减小batch_size") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/model_training_nnn_tpu/check_tpu_memory.py b/model_training_nnn_tpu/check_tpu_memory.py deleted file mode 100644 index 7356847..0000000 --- a/model_training_nnn_tpu/check_tpu_memory.py +++ /dev/null @@ -1,403 +0,0 @@ -#!/usr/bin/env python3 -""" -TPU训练内存监控工具 - 专注于训练过程中的实时内存和MXU监控 -适用于TPU v5e-8环境 -""" - -import tensorflow as tf -import time -import numpy as np - -def monitor_tpu_during_training(): - """训练过程中的TPU实时内存和MXU监控""" - print("📊 TPU训练实时监控工具") - print("=" * 50) - - # 获取TPU设备 - try: - tpu_devices = tf.config.list_logical_devices('TPU') - print(f"📍 发现TPU设备: {len(tpu_devices)}个") - if not tpu_devices: - print("❌ 未发现TPU设备") - return - except Exception as e: - print(f"❌ 无法检测TPU设备: {e}") - return - - def get_detailed_memory_snapshot(): - """获取详细的内存快照,包含所有核心信息""" - snapshot = {} - total_current = 0 - total_peak = 0 - active_cores = 0 - core_details = [] - - for i, device in enumerate(tpu_devices): - try: - memory_info = tf.config.experimental.get_memory_info(device.name) - if memory_info and 'current' in memory_info: - current_mb = memory_info['current'] // (1024 * 1024) - peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) - - if current_mb > 1: # >1MB算活跃 - active_cores += 1 - total_current += current_mb - total_peak += peak_mb - core_details.append(f"Core{i}:{current_mb}MB") - - snapshot[f'core_{i}'] = { - 'current': current_mb, - 'peak': peak_mb, - 'device': device.name - } - else: - snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name} - except Exception as e: - snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)} - - snapshot['summary'] = { - 'total_current': total_current, - 'total_peak': total_peak, - 'active_cores': active_cores, - 'total_cores': len(tpu_devices), - 'core_details': core_details - } - return snapshot - - def test_mxu_performance(): - """测试MXU性能和计算能力""" - print("\n🧮 
MXU计算性能测试:") - - mxu_results = [] - try: - with tf.device(tpu_devices[0].name): - # 测试不同规模的矩阵运算 - test_configs = [ - (2000, "2K×2K", tf.bfloat16), - (4000, "4K×4K", tf.bfloat16), - (6000, "6K×6K", tf.bfloat16), - ] - - for size, desc, dtype in test_configs: - try: - # 获取测试前内存 - pre_mem = get_detailed_memory_snapshot() - - start_time = time.time() - - # 创建矩阵并执行MXU密集型运算 - matrix_a = tf.random.normal([size, size], dtype=dtype) - matrix_b = tf.random.normal([size, size], dtype=dtype) - - @tf.function - def mxu_operation(): - # 连续矩阵运算,充分使用MXU - result = tf.matmul(matrix_a, matrix_b) - result = tf.matmul(result, matrix_a) - return tf.reduce_sum(result) - - result = mxu_operation() - # 使用result确保计算被执行 - _ = result.numpy() - end_time = time.time() - - # 获取测试后内存 - post_mem = get_detailed_memory_snapshot() - - duration = end_time - start_time - # 计算FLOPS (两次矩阵乘法) - flops = 2 * (2 * size**3) - tflops = flops / duration / 1e12 - - memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current'] - - print(f" {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, 内存+{memory_used}MB") - - mxu_results.append({ - 'size': size, - 'tflops': tflops, - 'duration': duration, - 'memory_used': memory_used - }) - - except Exception as e: - print(f" {desc}: 测试失败 - {str(e)[:50]}") - - # MXU性能分析 - if mxu_results: - max_tflops = max(r['tflops'] for r in mxu_results) - total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0) - - # TPU v5e-8单核理论性能 - theoretical_tflops = 275 # bf16峰值性能 - efficiency = (max_tflops / theoretical_tflops) * 100 - - print(f"\n 📊 MXU性能汇总:") - print(f" 峰值性能: {max_tflops:.1f} TFLOPS") - print(f" 理论峰值: {theoretical_tflops} TFLOPS") - print(f" MXU效率: {efficiency:.1f}%") - print(f" 计算内存占用: {total_memory}MB") - - if efficiency > 80: - status = "🟢 优秀" - elif efficiency > 50: - status = "🟡 良好" - elif efficiency > 20: - status = "🟠 中等" - else: - status = "🔴 需优化" - - print(f" 性能评级: {status}") - - except Exception as e: - 
print(f" MXU测试失败: {e}") - - try: - print("🎯 开始TPU训练监控...") - - # 1. 获取初始状态 - print("\n📸 初始TPU状态:") - baseline_snapshot = get_detailed_memory_snapshot() - - print(f" 总内存使用: {baseline_snapshot['summary']['total_current']}MB") - print(f" 活跃核心: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}") - - # 显示各核心详细状态 - for i in range(len(tpu_devices)): - core = baseline_snapshot[f'core_{i}'] - if core['current'] > 0 or core['peak'] > 0: - print(f" Core{i}: 当前{core['current']}MB, 峰值{core['peak']}MB") - - # 2. MXU性能基准测试 - test_mxu_performance() - - # 3. 创建分布式策略 - 使用项目验证的TPU初始化代码 - print(f"\n🔄 使用项目标准TPU初始化...") - try: - # 使用项目里验证过的TPU初始化代码 - # 禁用GPU避免冲突 - try: - tf.config.set_visible_devices([], 'GPU') - print("🚫 GPU已禁用,避免CUDA冲突") - except: - pass - - # 使用标准的TPU初始化流程 - print("🚀 使用官方TensorFlow TPU初始化...") - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - tf.tpu.experimental.initialize_tpu_system(resolver) - - # 验证TPU设备 - tpu_devices_check = tf.config.list_logical_devices('TPU') - print(f"✅ TPU设备验证: 发现 {len(tpu_devices_check)} 个设备") - - # 创建TPU策略 - strategy = tf.distribute.TPUStrategy(resolver) - print(f"✅ 成功创建TPU策略: {strategy.num_replicas_in_sync}个副本") - use_distributed = True - - except Exception as e: - print(f"⚠️ 分布式策略失败: {str(e)[:80]}") - print(" 将使用单设备模拟") - use_distributed = False - - # 4. 
模拟Brain-to-Text训练场景 - print(f"\n🧠 模拟Brain-to-Text训练场景...") - - if use_distributed: - # 分布式训练模拟 - with strategy.scope(): - print("📦 创建分布式模型参数...") - - # 创建接近真实Brain-to-Text模型的参数 (修复维度匹配) - model_components = { - # GRU层权重:第一层接收512维输入,后续层接收256维 - 'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'), - 'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'), - 'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'), - 'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'), - # 添加day-specific层模拟 (输入512维,输出512维) - 'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)] - } - - # 检查模型加载后内存 - after_model = get_detailed_memory_snapshot() - model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current'] - print(f"🧠 模型加载完成: +{model_memory}MB, {after_model['summary']['active_cores']}个活跃核心") - - # 训练循环模拟 - print(f"\n🔄 开始训练循环监控...") - - for step in range(10): - step_start_time = time.time() - - @tf.function - def distributed_training_step(): - # 模拟真实训练数据大小 - batch_size = 32 - seq_length = 1000 - features = 512 - - # 输入数据 - neural_data = tf.random.normal([batch_size, seq_length, features]) - targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32) - - # 模拟前向传播 - x = neural_data - - # Day-specific transformation (简化版本避免复杂的维度操作) - # 模拟day-specific变换:对每个时间步应用相同变换 - day_weight = model_components['day_weights'][0] # 简化:使用第一个day权重 - # 对最后一个维度进行变换: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512] - x = tf.matmul(x, day_weight) - - # 为CTC损失添加目标使用(模拟) - target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1) - # 简化的CTC相关计算 - batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32)) - - # GRU layers simulation - for i in range(3): - layer_name = f'gru_layer_{i}' - weight = model_components[layer_name] - - # 处理张量维度:第一层从3D输入,后续层从2D输入 - if i == 0: - # 第一层:取最后时间步 [batch, seq, features] -> [batch, 
features] - if len(x.shape) == 3: - x = x[:, -1, :] # 取最后时间步 - x = tf.nn.tanh(tf.matmul(x, weight)) - else: - # 后续层:直接处理2D张量 [batch, features] -> [batch, features] - x = tf.nn.tanh(tf.matmul(x, weight)) - - # 输出投影 - logits = tf.matmul(x, model_components['output_projection']) - - # CTC loss模拟(使用batch_loss_weight作为权重) - base_loss = tf.reduce_mean(tf.square(logits)) - loss = base_loss * batch_loss_weight - - return loss - - # 执行训练步骤 - per_replica_loss = strategy.run(distributed_training_step) - # 聚合分布式结果 - loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None) - step_duration = time.time() - step_start_time - - # 获取当前内存状态 - current_snapshot = get_detailed_memory_snapshot() - step_memory = current_snapshot['summary']['total_current'] - memory_delta = step_memory - baseline_snapshot['summary']['total_current'] - - # 显示详细训练状态 - active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(无活跃)" - - print(f" Step {step:2d}: loss={float(loss.numpy()):.4f}, " - f"时间={step_duration:.3f}s, " - f"内存={step_memory}MB(+{memory_delta}), " - f"活跃={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}") - - # 每5步显示峰值内存 - if step % 5 == 0: - peak_info = f"峰值: {current_snapshot['summary']['total_peak']}MB" - print(f" {peak_info}") - - time.sleep(0.2) # 短暂暂停观察 - - else: - # 单设备训练模拟(改进版) - print("🔸 单设备训练模拟...") - - with tf.device(tpu_devices[0].name): - # 创建较小的模型参数 - simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net') - - for step in range(8): - step_start = time.time() - - # 创建较大的数据批次 - batch_data = tf.random.normal([64, 1000, 512]) # 增大batch size - - # 模拟计算密集型操作 - @tf.function - def compute_step(): - x = tf.reshape(batch_data, [-1, 512]) - result = tf.matmul(x, simple_weights) - result = tf.nn.relu(result) - return tf.reduce_mean(result) - - result = compute_step() - step_duration = time.time() - 
step_start - - # 获取内存状态 - snapshot = get_detailed_memory_snapshot() - memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current'] - - print(f" Step {step}: result={result.numpy():.4f}, " - f"时间={step_duration:.3f}s, " - f"内存变化=+{memory_change}MB, " - f"峰值={snapshot['summary']['total_peak']}MB") - - # 5. 最终分析报告 - final_snapshot = get_detailed_memory_snapshot() - total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current'] - peak_usage = final_snapshot['summary']['total_peak'] - - print(f"\n📈 训练监控报告:") - print(f" 总内存增长: +{total_growth}MB") - print(f" 峰值内存使用: {peak_usage}MB ({peak_usage/1024:.2f}GB)") - print(f" 最终活跃核心: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}") - - # 各核心最终状态 - print(f" 各核心最终状态:") - has_changes = False - for i in range(len(tpu_devices)): - final_core = final_snapshot[f'core_{i}'] - baseline_core = baseline_snapshot[f'core_{i}'] - current_change = final_core['current'] - baseline_core['current'] - peak_change = final_core['peak'] - baseline_core['peak'] - - if current_change != 0 or peak_change != 0: - has_changes = True - print(f" Core{i}: 当前{final_core['current']}MB(+{current_change}), 峰值{final_core['peak']}MB(+{peak_change})") - - if not has_changes: - print(f" 所有核心内存无明显变化") - - # 分布式使用分析 - if final_snapshot['summary']['active_cores'] == 1: - print(f"\n⚠️ 分布式问题诊断:") - print(f" 只有1个核心活跃,其他7个核心空闲") - print(f" 可能原因: TPU策略配置问题或模型未正确分布") - print(f" 建议: 检查分布式策略和模型分片") - elif final_snapshot['summary']['active_cores'] > 4: - print(f"\n✅ 分布式状态良好:") - print(f" {final_snapshot['summary']['active_cores']}个核心活跃,多核心并行工作正常") - else: - print(f"\n🟡 分布式部分工作:") - print(f" {final_snapshot['summary']['active_cores']}个核心活跃,可能存在负载不均衡") - - print("✅ TPU训练监控完成") - - except Exception as e: - print(f"❌ 训练监控失败: {e}") - import traceback - print(f"详细错误: {traceback.format_exc()[:300]}") - -if __name__ == "__main__": - print("🚀 TPU训练内存监控工具") - 
print("专注于训练过程中的实时内存和性能监控") - print("适用于TPU v5e-8环境") - print() - - monitor_tpu_during_training() - - print(f"\n🎯 监控要点总结:") - print(f" 1. 确认所有8个TPU核心是否活跃") - print(f" 2. 监控内存增长模式和峰值使用") - print(f" 3. 检测MXU计算性能和效率") - print(f" 4. 验证分布式策略是否正常工作") - print(f" 5. 识别可能的内存泄漏或性能瓶颈") \ No newline at end of file diff --git a/model_training_nnn_tpu/dataset.py b/model_training_nnn_tpu/dataset.py deleted file mode 100644 index 086370e..0000000 --- a/model_training_nnn_tpu/dataset.py +++ /dev/null @@ -1,336 +0,0 @@ -import os -import torch -from torch.utils.data import Dataset -import h5py -import numpy as np -from torch.nn.utils.rnn import pad_sequence -import math - -class BrainToTextDataset(Dataset): - ''' - Dataset for brain-to-text data - - Returns an entire batch of data instead of a single example - ''' - - def __init__( - self, - trial_indicies, - n_batches, - split = 'train', - batch_size = 64, - days_per_batch = 1, - random_seed = -1, - must_include_days = None, - feature_subset = None - ): - ''' - trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values - n_batches: (int) - number of random training batches to create - split: (string) - string specifying if this is a train or test dataset - batch_size: (int) - number of examples to include in batch returned from __getitem_() - days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates - to individual day layers in the GRU are not excesively noisy. Validation data will always have 1 day per batch - random_seed: (int) - seed to set for randomly assigning trials to a batch. 
If set to -1, trial assignment will be random - must_include_days ([int]) - list of days that must be included in every batch - feature_subset ([int]) - list of neural feature indicies that should be the only features included in the neural data - ''' - - # Set random seed for reproducibility - if random_seed != -1: - np.random.seed(random_seed) - torch.manual_seed(random_seed) - - self.split = split - - # Ensure the split is valid - if self.split not in ['train', 'test']: - raise ValueError(f'split must be either "train" or "test". Received {self.split}') - - self.days_per_batch = days_per_batch - - self.batch_size = batch_size - - self.n_batches = n_batches - - self.days = {} - self.n_trials = 0 - self.trial_indicies = trial_indicies - self.n_days = len(trial_indicies.keys()) - - self.feature_subset = feature_subset - - # Calculate total number of trials in the dataset - for d in trial_indicies: - self.n_trials += len(trial_indicies[d]['trials']) - - if must_include_days is not None and len(must_include_days) > days_per_batch: - raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}') - - if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train': - raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} and but only {self.n_days} in the dataset') - - if must_include_days is not None: - # Map must_include_days to correct indicies if they are negative - for i, d in enumerate(must_include_days): - if d < 0: - must_include_days[i] = self.n_days + d - - self.must_include_days = must_include_days - - # Ensure that the days_per_batch is not greater than the number of days in the dataset. 
Raise error - if self.split == 'train' and self.days_per_batch > self.n_days: - raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.') - - - if self.split == 'train': - self.batch_index = self.create_batch_index_train() - else: - self.batch_index = self.create_batch_index_test() - self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data - - def __len__(self): - ''' - How many batches are in this dataset. - Because training data is sampled randomly, there is no fixed dataset length, - however this method is required for DataLoader to work - ''' - return self.n_batches if self.n_batches is not None else 0 - - def __getitem__(self, idx): - ''' - Gets an entire batch of data from the dataset, not just a single item - ''' - batch = { - 'input_features' : [], - 'seq_class_ids' : [], - 'n_time_steps' : [], - 'phone_seq_lens' : [], - 'day_indicies' : [], - 'transcriptions' : [], - 'block_nums' : [], - 'trial_nums' : [], - } - - index = self.batch_index[idx] - - # Iterate through each day in the index - for d in index.keys(): - - # Open the hdf5 file for that day - with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f: - - # For each trial in the selected trials in that day - for t in index[d]: - - try: - g = f[f'trial_{t:04d}'] - - # Remove features is neccessary - input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility - if self.feature_subset: - input_features = input_features[:,self.feature_subset] - - batch['input_features'].append(input_features) - - batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels - batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions - batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding - 
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding - batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers - batch['block_nums'].append(g.attrs['block_num']) - batch['trial_nums'].append(g.attrs['trial_num']) - - except Exception as e: - print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}') - continue - - # Pad data to form a cohesive batch - ensure bf16 dtype is preserved - batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16) - batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0) - - batch['n_time_steps'] = torch.tensor(batch['n_time_steps']) - batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens']) - batch['day_indicies'] = torch.tensor(batch['day_indicies']) - batch['transcriptions'] = torch.stack(batch['transcriptions']) - batch['block_nums'] = torch.tensor(batch['block_nums']) - batch['trial_nums'] = torch.tensor(batch['trial_nums']) - - return batch - - - def create_batch_index_train(self): - ''' - Create an index that maps a batch_number to batch_size number of trials - - Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days - (or as even as possible if batch_size is not divisible by days_per_batch) - ''' - - batch_index = {} - - # Precompute the days that are not in must_include_days - if self.must_include_days is not None: - non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days] - - for batch_idx in range(self.n_batches): - batch = {} - - # Which days will be used for this batch. 
Picked randomly without replacement - # TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day - - # If must_include_days is not empty, we will use those days and then randomly sample the rest - if self.must_include_days is not None and len(self.must_include_days) > 0: - - days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False))) - - # Otherwise we will select random days without replacement - else: - days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False) - - # How many trials will be sampled from each day - num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials - - for d in days: - - # Trials are sampled with replacement, so if a day has less than (self.batch_size / days_per_batch trials) trials, it won't be a problem - trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True) - batch[d] = trial_idxs - - # Remove extra trials - extra_trials = (num_trials * len(days)) - self.batch_size - - # While we still have extra trials, remove the last trial from a random day - while extra_trials > 0: - d = np.random.choice(days) - batch[d] = batch[d][:-1] - extra_trials -= 1 - - batch_index[batch_idx] = batch - - return batch_index - - def create_batch_index_test(self): - ''' - Create an index that is all validation/testing data in batches of up to self.batch_size - - If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size - - This index will ensures that every trial in the validation set is seen once and only once - ''' - batch_index = {} - batch_idx = 0 - - for d in self.trial_indicies.keys(): - - # Calculate how many batches we need for this day - num_trials = len(self.trial_indicies[d]['trials']) - 
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
    '''
    Split data from file_paths into train and test splits
    Returns two dictionaries (train, test) keyed by day index (position of the file in file_paths)
    that detail which trials in each day will be a part of that split:
    Example:
        {
            0: {'trials': [1, 2, 3], 'session_path': 'path'},
            1: {'trials': [2, 5, 6], 'session_path': 'path'},
        }

    A day whose file does not exist, or whose trials are all excluded, gets empty
    'trials' lists in both splits.

    Args:
        file_paths (list): List of file paths to the hdf5 files containing the data
        test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
        seed (int): Seed for reproducibility. If set to -1, the split will be random
        bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
            {
                'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                ...
            }
            (block numbers are looked up by their string form)
    '''
    # Set seed for reproducibility (note: seeds the global numpy generator)
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        # The session name is the first path component that looks like 't15.20xx...' / 't12.20xx...'
        session = [s for s in path_parts if s.startswith(('t15.20', 't12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    if (
                        bad_trials_dict is not None
                        and session in bad_trials_dict
                        and str(block_num) in bad_trials_dict[session]
                        and trial_num in bad_trials_dict[session][str(block_num)]
                    ):
                        # Trial is listed as bad for this session/block: skip it
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}

    # Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
    train_trials = {}
    test_trials = {}

    for day, day_info in trials_per_day.items():

        num_trials = day_info['num_trials']
        all_trial_indices = day_info['trial_indices']
        session_path = day_info['session_path']

        # If test_percentage is 0 or 1, we can just assign all trials to either train or test
        if test_percentage == 0:
            train_indices, test_indices = all_trial_indices, []

        elif test_percentage == 1:
            train_indices, test_indices = [], all_trial_indices

        elif num_trials == 0:
            # Nothing to sample from: np.random.choice would raise on an empty
            # population, so an empty day contributes no trials to either split.
            train_indices, test_indices = [], []

        else:
            # Calculate how many trials to use for testing (at least one per day)
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select indices for testing
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices go to training (set lookup keeps this linear)
            test_lookup = set(test_indices)
            train_indices = [idx for idx in all_trial_indices if idx not in test_lookup]

        # Store the split indices
        train_trials[day] = {'trials' : train_indices, 'session_path' : session_path}
        test_trials[day] = {'trials' : test_indices, 'session_path' : session_path}

    return train_trials, test_trials
' - 'If "test", ground truth is not available.') -parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv', - help='Path to the CSV file with metadata about the dataset (relative to the current working directory).') -parser.add_argument('--gpu_number', type=int, default=-1, - help='GPU number to use for RNN model inference. Set to -1 to use CPU.') -args = parser.parse_args() - -# paths to model and data directories -# Note: these paths are relative to the current working directory -model_path = args.model_path -data_dir = args.data_dir - -# define evaluation type -eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available - -# load csv file -b2txt_csv_df = pd.read_csv(args.csv_path) - -# load model args -model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml')) - -# set up gpu device -gpu_number = args.gpu_number -if torch.cuda.is_available() and gpu_number >= 0: - if gpu_number >= torch.cuda.device_count(): - raise ValueError(f'GPU number {gpu_number} is out of range. 
# define model
# All hyperparameters are read from the args.yaml stored next to the checkpoint.
model = GRUDecoder(
    neural_dim = model_args['model']['n_input_features'],
    n_units = model_args['model']['n_units'],
    n_days = len(model_args['dataset']['sessions']),
    n_classes = model_args['dataset']['n_classes'],
    rnn_dropout = model_args['model']['rnn_dropout'],
    input_dropout = model_args['model']['input_network']['input_layer_dropout'],
    n_layers = model_args['model']['n_layers'],
    patch_size = model_args['model']['patch_size'],
    patch_stride = model_args['model']['patch_stride'],
)

# load model weights
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only load
# checkpoints that come from a trusted source.
checkpoint = torch.load(
    os.path.join(model_path, 'checkpoint/best_checkpoint'),
    map_location=device,
    weights_only=False,
)

# Rename keys so they do not contain "module." (happens if the model was saved
# with DataParallel) or "_orig_mod." (presumably added by a compiled-model
# wrapper -- confirm against the training script).
# Both prefixes are stripped in a single pass: popping the same key twice, as
# the previous version did, raises KeyError as soon as the first replacement
# actually renames the key.
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    state_dict[new_key] = state_dict.pop(key)
model.load_state_dict(state_dict)

# add model to device
model.to(device)

# set model to eval mode
model.eval()
def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
    """Connect to Redis, retrying on failure.

    Each attempt creates a client and verifies it with ``ping()``. After a
    failed attempt the function waits ``retry_delay`` seconds before trying
    again; the exception from the final attempt is re-raised unchanged.

    Args:
        host: Redis server hostname or IP address.
        port: Redis server port.
        password: Redis password (passed straight to ``redis.Redis``).
        db: Redis database index.
        max_retries: Total number of connection attempts; must be >= 1.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        A ``redis.Redis`` client whose ``ping()`` succeeded.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1. Previously the loop
            was skipped and ``None`` was returned silently.
        redis.exceptions.ConnectionError: If the last attempt could not connect.
        Exception: Any other error raised by the last attempt.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(1, max_retries + 1):
        is_last_attempt = attempt == max_retries
        try:
            print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt}/{max_retries})...")
            r = redis.Redis(host=host, port=port, db=db, password=password)
            r.ping()  # Test the connection
        except redis.exceptions.ConnectionError as e:
            print(f"Redis connection failed (attempt {attempt}/{max_retries}): {e}")
            if is_last_attempt:
                print("Max retries reached. Could not connect to Redis.")
                raise
        except Exception as e:
            # Best-effort: unexpected errors are retried too, as before.
            print(f"Unexpected error connecting to Redis: {e}")
            if is_last_attempt:
                raise
        else:
            print(f"Successfully connected to Redis at {host}:{port}")
            return r

        # Only reached after a failed, non-final attempt.
        print(f"Retrying in {retry_delay} seconds...")
        time.sleep(retry_delay)
# if using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0

    lm_results['edit_distance'] = []
    lm_results['num_words'] = []

    for i, (true_raw, pred_raw) in enumerate(zip(lm_results['true_sentence'], lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(true_raw).strip()
        pred_sentence = remove_punctuation(pred_raw).strip()

        # Word-level edit distance between reference and prediction.
        true_words = true_sentence.split()
        num_words = len(true_words)
        ed = editdistance.eval(true_words, pred_sentence.split())

        total_true_length += num_words
        total_edit_distance += ed

        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(num_words)

        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence: {true_sentence}')
        print(f'Predicted sentence: {pred_sentence}')
        # The factor of 100 belongs on the ratio (to make it a percentage),
        # not on the denominator that is displayed.
        if num_words > 0:
            print(f'WER: {ed} / {num_words} = {100 * ed / num_words:.2f}%')
        else:
            # An empty reference has no defined WER; avoid ZeroDivisionError.
            print(f'WER: {ed} / 0 = undefined (empty reference sentence)')
        print()

    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    if total_true_length > 0:
        print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
    else:
        print('Aggregate Word Error Rate (WER): undefined (no reference words)')


# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_file, index=False)
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]) - self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]) - self.day_layer_dropout = nn.Dropout(input_dropout) - - # Calculate input size after patching - self.input_size = self.neural_dim - if self.patch_size > 0: - self.input_size *= self.patch_size - - # 2-layer GRU for noise estimation - self.gru = nn.GRU( - input_size=self.input_size, - hidden_size=self.input_size, # Output same dimension as input - num_layers=2, - dropout=self.rnn_dropout, - batch_first=True, - bidirectional=False, - ) - - # Initialize GRU parameters - for name, param in self.gru.named_parameters(): - if "weight_hh" in name: - nn.init.orthogonal_(param) - if "weight_ih" in name: - nn.init.xavier_uniform_(param) - - # Learnable initial hidden state - let Accelerator handle dtype - self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size))) - - def forward(self, x, day_idx, states=None): - # XLA-friendly day-specific transformation using gather instead of dynamic indexing - batch_size = x.size(0) - - # Stack all day weights and biases upfront for static indexing - all_day_weights = torch.stack(list(self.day_weights), dim=0) # [n_days, neural_dim, neural_dim] - all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0) # [n_days, neural_dim] - - # XLA-friendly gather operation - day_weights = torch.index_select(all_day_weights, 0, day_idx) # [batch_size, neural_dim, neural_dim] - day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim] - - # Use bmm (batch matrix multiply) which is highly optimized in XLA - # Ensure dtype consistency for mixed precision training - x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) - x = self.day_layer_activation(x) - - # XLA-friendly conditional dropout - if self.input_dropout > 0: - x = 
class CleanSpeechModel(nn.Module):
    '''
    Clean Speech Model: GRU stack (3 layers by default) that processes the
    denoised signal for speech recognition.

    neural_dim   (int)   - number of channels in a single timestep
    n_units      (int)   - number of hidden units in each recurrent layer
    n_days       (int)   - number of day-specific input layers to create
    n_classes    (int)   - number of output classes
    rnn_dropout  (float) - dropout passed to the GRU
    input_dropout(float) - dropout applied after the day-specific layer
    patch_size   (int)   - number of timesteps concatenated per GRU input (0 disables patching)
    patch_stride (int)   - stride between consecutive patches
    n_layers     (int)   - number of stacked GRU layers (default 3, the previous fixed value)
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 n_layers=3):
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_days = n_days
        self.n_classes = n_classes
        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.n_layers = n_layers

        # Day-specific input layers: identity weights and zero biases at init.
        self.day_layer_activation = nn.Softsign()
        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
        self.day_layer_dropout = nn.Dropout(input_dropout)

        # Calculate input size after patching
        self.input_size = self.neural_dim
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        # GRU stack for clean speech recognition
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.n_units,
            num_layers=self.n_layers,
            dropout=self.rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Initialize GRU parameters
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Output classification layer
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state, shared by every layer and batch element
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, day_idx, states=None, return_state=False):
        '''
        x            (tensor) - batch of trials, indexed as (batch, time, neural_dim) by the bmm below
        day_idx      (tensor) - 1-D index tensor selecting the day layer for each batch element
        states       (tensor) - optional initial GRU hidden state; the learned h0 is used when None
        return_state (bool)   - when True, also return the final GRU hidden states

        Returns logits, or (logits, hidden_states) when return_state is True.
        '''
        batch_size = x.size(0)

        # Stack all day weights and biases upfront so the per-example layer
        # can be picked with index_select instead of Python-level indexing.
        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]

        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]

        # Cast the day parameters to the input dtype for mixed precision training
        x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x = self.day_layer_activation(x)

        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        # Concatenate patch_size consecutive timesteps into one GRU input step
        if self.patch_size > 0:
            original_dtype = x.dtype
            x = x.unsqueeze(1)
            x = x.permute(0, 3, 1, 2)
            x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
            x = x.to(original_dtype)

        # The GRU runs with autocast disabled, so inputs must match its parameter dtype
        gru_dtype = next(self.gru.parameters()).dtype
        if x.dtype != gru_dtype:
            x = x.to(gru_dtype)

        if states is None:
            states = self.h0.expand(self.n_layers, batch_size, self.n_units).contiguous()
        if states.dtype != gru_dtype:
            states = states.to(gru_dtype)

        device_type = x.device.type
        with torch.autocast(device_type=device_type, enabled=False):
            output, hidden_states = self.gru(x, states)

        # Classification
        logits = self.out(output)

        if return_state:
            return logits, hidden_states
        return logits
class TripleGRUDecoder(nn.Module):
    '''
    Three-model adversarial architecture for neural speech decoding

    Combines:
    - NoiseModel: estimates noise in neural data
    - CleanSpeechModel: processes denoised signal for recognition
    - NoisySpeechModel: processes noise signal for recognition
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout=0.0,
                 input_dropout=0.0,
                 patch_size=0,
                 patch_stride=0,
                 ):
        '''
        neural_dim  (int)      - number of channels in a single timestep (e.g. 512)
        n_units     (int)      - number of hidden units in each recurrent layer
        n_days      (int)      - number of days in the dataset
        n_classes   (int)      - number of classes (phonemes)
        rnn_dropout    (float) - percentage of units to dropout during training
        input_dropout (float)  - percentage of input units to dropout during training
        patch_size  (int)      - number of timesteps to concat on initial input layer
        patch_stride(int)      - number of timesteps to stride over when concatenating initial input
        '''
        super().__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout
        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Create the three models
        self.noise_model = NoiseModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.clean_speech_model = CleanSpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        self.noisy_speech_model = NoisySpeechModel(
            neural_dim=neural_dim,
            n_units=n_units,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=rnn_dropout,
            input_dropout=input_dropout,
            patch_size=patch_size,
            patch_stride=patch_stride
        )

        # Training mode flag ('full' or 'inference'). Note that forward() selects
        # its behaviour from its own `mode` argument, not from this attribute.
        self.training_mode = 'full'

    def _apply_preprocessing(self, x, day_idx):
        '''Apply the clean model's day-specific layer and patching to x (input dropout is not applied here)'''
        batch_size = x.size(0)

        # Stack the per-day parameters so each example's layer is picked with index_select
        all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)

        day_weights = torch.index_select(all_day_weights, 0, day_idx)
        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

        # Cast the day parameters to the input dtype for mixed precision training
        x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
        x_processed = self.clean_speech_model.day_layer_activation(x_processed)

        # Concatenate patch_size consecutive timesteps into one step
        if self.patch_size > 0:
            original_dtype = x_processed.dtype
            x_processed = x_processed.unsqueeze(1)
            x_processed = x_processed.permute(0, 3, 1, 2)
            x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
            x_unfold = x_unfold.squeeze(2)
            x_unfold = x_unfold.permute(0, 2, 3, 1)
            x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
            x_processed = x_processed.to(original_dtype)

        return x_processed

    @staticmethod
    def _run_gru_head(sub_model, x_processed, states=None):
        '''Run sub_model.gru followed by sub_model.out on already preprocessed input and return logits'''
        batch_size = x_processed.size(0)

        # The GRU runs with autocast disabled, so inputs must match its parameter dtype
        gru_dtype = next(sub_model.gru.parameters()).dtype
        if x_processed.dtype != gru_dtype:
            x_processed = x_processed.to(gru_dtype)

        # Default to the learned initial state, one copy per GRU layer
        if states is None:
            states = sub_model.h0.expand(sub_model.gru.num_layers, batch_size, sub_model.n_units).contiguous()
        if states.dtype != gru_dtype:
            states = states.to(gru_dtype)

        device_type = x_processed.device.type
        with torch.autocast(device_type=device_type, enabled=False):
            output, _ = sub_model.gru(x_processed, states)

        return sub_model.out(output)

    def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
        '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
        # day_idx is unused here; it is kept so the signature stays unchanged.
        return self._run_gru_head(self.clean_speech_model, x_processed, states)

    def _noisy_forward_with_processed_input(self, x_processed, states=None):
        '''Forward pass for NoisySpeechModel with already processed input'''
        return self._run_gru_head(self.noisy_speech_model, x_processed, states)

    def _estimate_noise_and_denoise(self, x, day_idx, states=None):
        '''Return (noise_output, noise_hidden, denoised_input); shared by both forward modes'''
        # 1. Noise model estimates noise in the data
        noise_output, noise_hidden = self.noise_model(x, day_idx,
                                                      states['noise'] if states else None)

        # 2. The residual needs x in the same space as noise_output, so apply
        #    the same day-layer/patch preprocessing and align dtypes.
        x_processed = self._apply_preprocessing(x, day_idx)
        clean_dtype = next(self.clean_speech_model.parameters()).dtype
        if x_processed.dtype != clean_dtype:
            x_processed = x_processed.to(clean_dtype)
        if noise_output.dtype != clean_dtype:
            noise_output = noise_output.to(clean_dtype)

        # 3. Residual connection in processed space
        denoised_input = x_processed - noise_output
        return noise_output, noise_hidden, denoised_input

    def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0):
        '''
        Three-model adversarial forward pass

        x        (tensor)  - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx  (tensor)  - tensor of day indices for each example in the batch
        states   (dict)    - dictionary with 'noise', 'clean', 'noisy' states or None
        return_state (bool) - when True, additionally return the noise model's hidden state
        mode     (str)     - 'full' for training (all three models), 'inference' for inference (noise + clean only)
        grl_lambda (float) - when nonzero and mode='full', applies Gradient Reversal to the noise branch input

        Returns:
            mode='inference': clean_logits, or (clean_logits, noise_hidden) when return_state
            mode='full': (clean_logits, noisy_logits, noise_output), or
                         ((clean_logits, noisy_logits, noise_output), noise_hidden) when return_state

        Raises ValueError for any other mode.
        '''
        if mode not in ('full', 'inference'):
            raise ValueError(f"Unknown mode: {mode}. Use 'full' or 'inference'")

        noise_output, noise_hidden, denoised_input = self._estimate_noise_and_denoise(x, day_idx, states)

        # Clean speech model processes the denoised signal, bypassing its own preprocessing
        clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                states['clean'] if states else None)

        if mode == 'inference':
            if return_state:
                return clean_logits, noise_hidden
            return clean_logits

        # Training mode: the noisy speech model processes the noise signal directly
        # (no day layers), optionally through Gradient Reversal for adversarial training.
        noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
        noisy_input = cast(torch.Tensor, noisy_input)
        noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
        if noisy_input.dtype != noisy_dtype:
            noisy_input = noisy_input.to(noisy_dtype)
        noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
                                                                states['noisy'] if states else None)

        # Tuple (not dict) return, as in the original implementation
        if return_state:
            return (clean_logits, noisy_logits, noise_output), noise_hidden
        return clean_logits, noisy_logits, noise_output

    def apply_gradient_combination(self, clean_grad, noisy_grad, learning_rate=1e-3):
        '''
        Apply combined gradients to noise model parameters

        clean_grad (tensor) - gradients from clean speech model output layer
        noisy_grad (tensor) - gradients from noisy speech model output layer
        learning_rate (float) - step size of the update

        NOTE: placeholder update rule. Every element of every noise-model parameter
        that has a gradient is shifted by the same scalar (mean of the combined gradient).
        '''
        # Combine gradients: negative from clean model, positive from noisy model
        combined_grad = -clean_grad + noisy_grad

        with torch.no_grad():
            for param in self.noise_model.parameters():
                if param.grad is not None:
                    param.data -= learning_rate * combined_grad.mean() * torch.ones_like(param.data)

    def set_mode(self, mode):
        '''Set the operating mode'''
        self.training_mode = mode
-import pathlib -import logging -import sys -import json -import pickle -from contextlib import nullcontext - -from dataset import BrainToTextDataset, train_test_split_indicies -from data_augmentations import gauss_smooth - -import torchaudio.functional as F # for edit distance -from omegaconf import OmegaConf - -# Import Accelerate for TPU support -from accelerate import Accelerator, DataLoaderConfiguration -from accelerate.utils import set_seed - -# Import XLA after setting environment variables -import torch_xla.core.xla_model as xm - -torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs -torch.backends.cudnn.deterministic = True # makes training more reproducible -torch._dynamo.config.cache_size_limit = 64 - -from rnn_model import TripleGRUDecoder - -class BrainToTextDecoder_Trainer: - """ - This class will initialize and train a brain-to-text phoneme decoder - - Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function - """ - - def __init__(self, args): - ''' - args : dictionary of training arguments - ''' - - # Configure DataLoader behavior for TPU compatibility - dataloader_config = DataLoaderConfiguration( - even_batches=False # Required for batch_size=None DataLoaders on TPU - ) - - # Initialize Accelerator for TPU/multi-device support - self.use_xla = bool(xm.get_xla_supported_devices()) - self.amp_requested = args.get('use_amp', True) - mixed_precision_mode = 'bf16' if self.amp_requested else 'no' - - self.accelerator = Accelerator( - mixed_precision=mixed_precision_mode, - gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1), - log_with=None, # We'll use our own logging - project_dir=args.get('output_dir', './output'), - dataloader_config=dataloader_config, - ) - - - # Trainer fields - self.args = args - self.logger = None - self.device = self.accelerator.device # Use accelerator device instead of manual device selection - self.model = None - self.optimizer = None - 
self.learning_rate_scheduler = None - self.ctc_loss = None - - self.best_val_PER = torch.inf # track best PER for checkpointing - self.best_val_loss = torch.inf # track best loss for checkpointing - - self.train_dataset = None - self.val_dataset = None - self.train_loader = None - self.val_loader = None - - self.transform_args = self.args['dataset']['data_transforms'] - - # Adversarial training config (safe defaults if not provided) - adv_cfg = self.args.get('adversarial', {}) - self.adv_enabled = adv_cfg.get('enabled', False) - self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength - self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC - self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output - self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps - - # Create output directory - if args['mode'] == 'train': - os.makedirs(self.args['output_dir'], exist_ok=True) - - # Create checkpoint directory - if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']: - os.makedirs(self.args['checkpoint_dir'], exist_ok=True) - - # Set up logging - self.logger = logging.getLogger(__name__) - for handler in self.logger.handlers[:]: # make a copy of the list - self.logger.removeHandler(handler) - self.logger.setLevel(logging.INFO) - formatter = logging.Formatter(fmt='%(asctime)s: %(message)s') - - if args['mode']=='train': - # During training, save logs to file in output directory - fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log'))) - fh.setFormatter(formatter) - self.logger.addHandler(fh) - - # Always print logs to stdout - sh = logging.StreamHandler(sys.stdout) - sh.setFormatter(formatter) - self.logger.addHandler(sh) - - # Log device information (managed by Accelerator) - self.logger.info(f'Using device: {self.device}') - self.logger.info(f'Accelerator 
state: {self.accelerator.state}') - if self.accelerator.num_processes > 1: - self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes') - if self.use_xla and self.amp_requested: - self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.') - - # Set seed if provided (using Accelerator's set_seed for proper distributed seeding) - if self.args['seed'] != -1: - set_seed(self.args['seed']) - - # Initialize the model - self.model = TripleGRUDecoder( - neural_dim = self.args['model']['n_input_features'], - n_units = self.args['model']['n_units'], - n_days = len(self.args['dataset']['sessions']), - n_classes = self.args['dataset']['n_classes'], - rnn_dropout = self.args['model']['rnn_dropout'], - input_dropout = self.args['model']['input_network']['input_layer_dropout'], - patch_size = self.args['model']['patch_size'], - patch_stride = self.args['model']['patch_stride'], - ) - - if self.use_xla and self.amp_requested: - self.model = self.model.to(torch.bfloat16) - self.logger.info('Converted model parameters to bfloat16 for TPU training.') - - self.model_dtype = next(self.model.parameters()).dtype - - # Temporarily disable torch.compile for compatibility with new model architecture - # TODO: Re-enable torch.compile once model is stable - # self.logger.info("Using torch.compile") - # self.model = torch.compile(self.model) - self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility") - - self.logger.info(f"Initialized RNN decoding model") - - self.logger.info(self.model) - - # Log how many parameters are in the model - total_params = sum(p.numel() for p in self.model.parameters()) - self.logger.info(f"Model has {total_params:,} parameters") - - # Determine how many day-specific parameters are in the model - day_params = 0 - for name, param in self.model.named_parameters(): - if 'day' in name: - day_params += param.numel() - - self.logger.info(f"Model has {day_params:,} day-specific 
parameters | {((day_params / total_params) * 100):.2f}% of total parameters") - - # Create datasets and dataloaders - train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']] - val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']] - - # Ensure that there are no duplicate days - if len(set(train_file_paths)) != len(train_file_paths): - raise ValueError("There are duplicate sessions listed in the train dataset") - if len(set(val_file_paths)) != len(val_file_paths): - raise ValueError("There are duplicate sessions listed in the val dataset") - - # Split trials into train and test sets - train_trials, _ = train_test_split_indicies( - file_paths = train_file_paths, - test_percentage = 0, - seed = self.args['dataset']['seed'], - bad_trials_dict = None, - ) - _, val_trials = train_test_split_indicies( - file_paths = val_file_paths, - test_percentage = 1, - seed = self.args['dataset']['seed'], - bad_trials_dict = None, - ) - - # Save dictionaries to output directory to know which trials were train vs val - with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f: - json.dump({'train' : train_trials, 'val': val_trials}, f) - - # Determine if a only a subset of neural features should be used - feature_subset = None - if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None: - feature_subset = self.args['dataset']['feature_subset'] - self.logger.info(f'Using only a subset of features: {feature_subset}') - - # train dataset and dataloader - self.train_dataset = BrainToTextDataset( - trial_indicies = train_trials, - split = 'train', - days_per_batch = self.args['dataset']['days_per_batch'], - n_batches = self.args['num_training_batches'], - batch_size = self.args['dataset']['batch_size'], - must_include_days = None, - random_seed = self.args['dataset']['seed'], - 
feature_subset = feature_subset - ) - # Custom collate function that handles pre-batched data from our dataset - def collate_fn(batch): - # Our dataset returns full batches, so batch will be a list of single batch dict - # Extract the first (and only) element since our dataset.__getitem__() returns a full batch - if len(batch) == 1 and isinstance(batch[0], dict): - return batch[0] - else: - # Fallback for unexpected batch structure - return batch - - # DataLoader configuration compatible with Accelerate - self.train_loader = DataLoader( - self.train_dataset, - batch_size = 1, # Use batch_size=1 since dataset returns full batches - shuffle = self.args['dataset']['loader_shuffle'], - num_workers = self.args['dataset']['num_dataloader_workers'], - pin_memory = True, - collate_fn = collate_fn - ) - - # val dataset and dataloader - self.val_dataset = BrainToTextDataset( - trial_indicies = val_trials, - split = 'test', - days_per_batch = None, - n_batches = None, - batch_size = self.args['dataset']['batch_size'], - must_include_days = None, - random_seed = self.args['dataset']['seed'], - feature_subset = feature_subset - ) - # Validation DataLoader with same collate function - self.val_loader = DataLoader( - self.val_dataset, - batch_size = 1, # Use batch_size=1 since dataset returns full batches - shuffle = False, - num_workers = 0, # Keep validation dataloader single-threaded for consistency - pin_memory = True, - collate_fn = collate_fn # Use same collate function - ) - - self.logger.info("Successfully initialized datasets") - - # Create optimizer, learning rate scheduler, and loss - self.optimizer = self.create_optimizer() - - if self.args['lr_scheduler_type'] == 'linear': - self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer = self.optimizer, - start_factor = 1.0, - end_factor = self.args['lr_min'] / self.args['lr_max'], - total_iters = self.args['lr_decay_steps'], - ) - elif self.args['lr_scheduler_type'] == 'cosine': - 
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer) - - else: - raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}") - - self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False) - - # If a checkpoint is provided, then load from checkpoint - if self.args['init_from_checkpoint']: - self.load_model_checkpoint(self.args['init_checkpoint_path']) - - # Set rnn and/or input layers to not trainable if specified - for name, param in self.model.named_parameters(): - if not self.args['model']['rnn_trainable'] and 'gru' in name: - param.requires_grad = False - - elif not self.args['model']['input_network']['input_trainable'] and 'day' in name: - param.requires_grad = False - - # Prepare model, optimizer, scheduler, and dataloaders for distributed training - # Let Accelerator handle everything automatically for both GPU and TPU - ( - self.model, - self.optimizer, - self.learning_rate_scheduler, - self.train_loader, - self.val_loader, - ) = self.accelerator.prepare( - self.model, - self.optimizer, - self.learning_rate_scheduler, - self.train_loader, - self.val_loader, - ) - - self.model_dtype = next(self.model.parameters()).dtype - - self.logger.info("Prepared model and dataloaders with Accelerator") - if self.adv_enabled: - self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}") - - def autocast_context(self): - """Return appropriate autocast context; disable on XLA to avoid dtype mismatches.""" - if self.device.type == 'xla': - return nullcontext() - return self.accelerator.autocast() - - def create_optimizer(self): - ''' - Create the optimizer with special param groups - - Biases and day weights should not be decayed - - Day weights should have a separate learning rate - ''' - bias_params = [p for name, p in 
self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name] - day_params = [p for name, p in self.model.named_parameters() if 'day_' in name] - other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name] - - if len(day_params) != 0: - param_groups = [ - {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'}, - {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'}, - {'params' : other_params, 'group_type' : 'other'} - ] - else: - param_groups = [ - {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'}, - {'params' : other_params, 'group_type' : 'other'} - ] - - optim = torch.optim.AdamW( - param_groups, - lr = self.args['lr_max'], - betas = (self.args['beta0'], self.args['beta1']), - eps = self.args['epsilon'], - weight_decay = self.args['weight_decay'], - fused = True - ) - - return optim - - def create_cosine_lr_scheduler(self, optim): - lr_max = self.args['lr_max'] - lr_min = self.args['lr_min'] - lr_decay_steps = self.args['lr_decay_steps'] - - lr_max_day = self.args['lr_max_day'] - lr_min_day = self.args['lr_min_day'] - lr_decay_steps_day = self.args['lr_decay_steps_day'] - - lr_warmup_steps = self.args['lr_warmup_steps'] - lr_warmup_steps_day = self.args['lr_warmup_steps_day'] - - def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps): - ''' - Create lr lambdas for each param group that implement cosine decay - - Different lr lambda decaying for day params vs rest of the model - ''' - # Warmup phase - if current_step < warmup_steps: - return float(current_step) / float(max(1, warmup_steps)) - - # Cosine decay phase - if current_step < decay_steps: - progress = float(current_step - warmup_steps) / float( - max(1, decay_steps - warmup_steps) - ) - cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) - # Scale from 1.0 to min_lr_ratio - return 
max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay) - - # After cosine decay is complete, maintain min_lr_ratio - return min_lr_ratio - - if len(optim.param_groups) == 3: - lr_lambdas = [ - lambda step: lr_lambda( - step, - lr_min / lr_max, - lr_decay_steps, - lr_warmup_steps), # biases - lambda step: lr_lambda( - step, - lr_min_day / lr_max_day, - lr_decay_steps_day, - lr_warmup_steps_day, - ), # day params - lambda step: lr_lambda( - step, - lr_min / lr_max, - lr_decay_steps, - lr_warmup_steps), # rest of model weights - ] - elif len(optim.param_groups) == 2: - lr_lambdas = [ - lambda step: lr_lambda( - step, - lr_min / lr_max, - lr_decay_steps, - lr_warmup_steps), # biases - lambda step: lr_lambda( - step, - lr_min / lr_max, - lr_decay_steps, - lr_warmup_steps), # rest of model weights - ] - else: - raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}") - - return LambdaLR(optim, lr_lambdas, -1) - - def load_model_checkpoint(self, load_path): - ''' - Load a training checkpoint for distributed training - ''' - # Load checkpoint on CPU first to avoid OOM issues - checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict - - # Get unwrapped model for loading state dict - unwrapped_model = self.accelerator.unwrap_model(self.model) - unwrapped_model.load_state_dict(checkpoint['model_state_dict']) - - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate - self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf - - # Device handling is managed by Accelerator, no need to manually move to device - - self.logger.info("Loaded model from checkpoint: " + load_path) - - def save_model_checkpoint(self, save_path, PER, loss): - ''' - Save a training checkpoint using Accelerator 
for distributed training - ''' - # Only save on main process to avoid conflicts - if self.accelerator.is_main_process: - # Unwrap model to get base model for saving - unwrapped_model = self.accelerator.unwrap_model(self.model) - - checkpoint = { - 'model_state_dict' : unwrapped_model.state_dict(), - 'optimizer_state_dict' : self.optimizer.state_dict(), - 'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(), - 'val_PER' : PER, - 'val_loss' : loss - } - - torch.save(checkpoint, save_path) - - self.logger.info("Saved model to checkpoint: " + save_path) - - # Save the args file alongside the checkpoint - with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f: - OmegaConf.save(config=self.args, f=f) - - # Wait for all processes to complete checkpoint saving - self.accelerator.wait_for_everyone() - - def create_attention_mask(self, sequence_lengths): - - max_length = torch.max(sequence_lengths).item() - - batch_size = sequence_lengths.size(0) - - # Create a mask for valid key positions (columns) - # Shape: [batch_size, max_length] - key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length) - key_mask = key_mask < sequence_lengths.unsqueeze(1) - - # Expand key_mask to [batch_size, 1, 1, max_length] - # This will be broadcast across all query positions - key_mask = key_mask.unsqueeze(1).unsqueeze(1) - - # Create the attention mask of shape [batch_size, 1, max_length, max_length] - # by broadcasting key_mask across all query positions - attention_mask = key_mask.expand(batch_size, 1, max_length, max_length) - - # Convert boolean mask to float mask: - # - True (valid key positions) -> 0.0 (no change to attention scores) - # - False (padding key positions) -> -inf (will become 0 after softmax) - attention_mask_float = torch.where(attention_mask, - True, - False) - - return attention_mask_float - - def transform_data(self, features, n_time_steps, mode = 'train'): - ''' - Apply various augmentations and 
smoothing to data - Performing augmentations is much faster on GPU than CPU - ''' - - # TPU and GPU should now handle data consistently with our improved DataLoader configuration - - data_shape = features.shape - batch_size = data_shape[0] - channels = data_shape[-1] - - # We only apply these augmentations in training - if mode == 'train': - # add static gain noise - if self.transform_args['static_gain_std'] > 0: - warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1)) - warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std'] - - features = torch.matmul(features, warp_mat) - - # add white noise - if self.transform_args['white_noise_std'] > 0: - features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std'] - - # add constant offset noise - if self.transform_args['constant_offset_std'] > 0: - features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std'] - - # add random walk noise - if self.transform_args['random_walk_std'] > 0: - features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis']) - - # randomly cutoff part of the data timecourse - if self.transform_args['random_cut'] > 0: - cut = np.random.randint(0, self.transform_args['random_cut']) - features = features[:, cut:, :] - n_time_steps = n_time_steps - cut - - # Apply Gaussian smoothing to data - # This is done in both training and validation - if self.transform_args['smooth_data']: - features = gauss_smooth( - inputs = features, - device = self.device, - smooth_kernel_std = self.transform_args['smooth_kernel_std'], - smooth_kernel_size= self.transform_args['smooth_kernel_size'], - ) - - if hasattr(self, 'model_dtype'): - features = features.to(self.model_dtype) - - - return features, n_time_steps - - def train(self): - ''' - Train the model - ''' - - 
# Set model to train mode (specificially to make sure dropout layers are engaged) - self.model.train() - - # create vars to track performance - train_losses = [] - val_losses = [] - val_PERs = [] - val_results = [] - - val_steps_since_improvement = 0 - - # training params - save_best_checkpoint = self.args.get('save_best_checkpoint', True) - early_stopping = self.args.get('early_stopping', True) - - early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20) - - train_start_time = time.time() - - # train for specified number of batches - self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...") - for i, batch in enumerate(self.train_loader): - - self.model.train() - self.optimizer.zero_grad() - - # Train step - start_time = time.time() - - # Data is automatically moved to device by Accelerator - features = batch['input_features'] - labels = batch['seq_class_ids'] - n_time_steps = batch['n_time_steps'] - phone_seq_lens = batch['phone_seq_lens'] - day_indicies = batch['day_indicies'] - - # Use Accelerator's autocast (mixed precision handled by Accelerator init) - with self.autocast_context(): - - # Apply augmentations to the data - features, n_time_steps = self.transform_data(features, n_time_steps, 'train') - - # Ensure proper dtype handling for TPU mixed precision - adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) - - # Get phoneme predictions using inference mode during training - # (We use inference mode for simplicity - only clean logits are used for CTC loss) - # Ensure features tensor matches model parameter dtype for TPU compatibility - if features.dtype != self.model_dtype: - features = features.to(self.model_dtype) - - # Forward pass: enable full adversarial mode if configured and past warmup - use_full = self.adv_enabled and (i >= self.adv_warmup_steps) - if use_full: - clean_logits, noisy_logits, noise_output = 
self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda) - else: - logits = self.model(features, day_indicies, None, False, 'inference') - - # Calculate CTC Loss - if use_full: - # Clean CTC loss - clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2) - clean_loss = self.ctc_loss( - clean_log_probs, - labels, - adjusted_lens, - phone_seq_lens - ) - clean_loss = torch.mean(clean_loss) - - # Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗) - noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2) - noisy_loss = self.ctc_loss( - noisy_log_probs, - labels, - adjusted_lens, - phone_seq_lens - ) - noisy_loss = torch.mean(noisy_loss) - - # Optional noise energy regularization - noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype) - if self.adv_noise_l2_weight > 0.0: - noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype) - - loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 - else: - log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2) - loss = self.ctc_loss( - log_probs=log_probs, - targets=labels, - input_lengths=adjusted_lens, - target_lengths=phone_seq_lens - ) - loss = torch.mean(loss) # take mean loss over batches - - # Use Accelerator's backward for distributed training - self.accelerator.backward(loss) - - # Clip gradient using Accelerator's clip_grad_norm_ - if self.args['grad_norm_clip_value'] > 0: - grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(), - max_norm = self.args['grad_norm_clip_value']) - - self.optimizer.step() - self.learning_rate_scheduler.step() - - # Save training metrics - train_step_duration = time.time() - start_time - train_losses.append(loss.detach().item()) - - # Incrementally log training progress - if i % self.args['batches_per_train_log'] == 0: - self.logger.info(f'Train batch {i}: ' + - f'loss: {(loss.detach().item()):.2f} ' + 
- f'grad norm: {grad_norm:.2f} ' - f'time: {train_step_duration:.3f}') - - # Incrementally run a test step - if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)): - self.logger.info(f"Running test after training batch: {i}") - - # Calculate metrics on val data - start_time = time.time() - val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data']) - val_step_duration = time.time() - start_time - - - # Log info - self.logger.info(f'Val batch {i}: ' + - f'PER (avg): {val_metrics["avg_PER"]:.4f} ' + - f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' + - f'time: {val_step_duration:.3f}') - - if self.args['log_individual_day_val_PER']: - for day in val_metrics['day_PERs'].keys(): - self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}") - - # Save metrics - val_PERs.append(val_metrics['avg_PER']) - val_losses.append(val_metrics['avg_loss']) - val_results.append(val_metrics) - - # Determine if new best day. 
Based on if PER is lower, or in the case of a PER tie, if loss is lower - new_best = False - if val_metrics['avg_PER'] < self.best_val_PER: - self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}") - self.best_val_PER = val_metrics['avg_PER'] - self.best_val_loss = val_metrics['avg_loss'] - new_best = True - elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss): - self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}") - self.best_val_loss = val_metrics['avg_loss'] - new_best = True - - if new_best: - - # Checkpoint if metrics have improved - if save_best_checkpoint: - self.logger.info(f"Checkpointing model") - self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss) - - # save validation metrics to pickle file - if self.args['save_val_metrics']: - with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f: - pickle.dump(val_metrics, f) - - val_steps_since_improvement = 0 - - else: - val_steps_since_improvement +=1 - - # Optionally save this validation checkpoint, regardless of performance - if self.args['save_all_val_steps']: - self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss']) - - # Early stopping - if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps): - self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. 
Stopping training early at batch: {i}') - break - - # Log final training steps - training_duration = time.time() - train_start_time - - - self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}') - self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes') - - # Save final model - if self.args['save_final_model']: - last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf') - self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss) - - train_stats = {} - train_stats['train_losses'] = train_losses - train_stats['val_losses'] = val_losses - train_stats['val_PERs'] = val_PERs - train_stats['val_metrics'] = val_results - - return train_stats - - def validation(self, loader, return_logits = False, return_data = False): - ''' - Calculate metrics on the validation dataset - ''' - self.model.eval() - - metrics = {} - - # Record metrics - if return_logits: - metrics['logits'] = [] - metrics['n_time_steps'] = [] - - if return_data: - metrics['input_features'] = [] - - metrics['decoded_seqs'] = [] - metrics['true_seq'] = [] - metrics['phone_seq_lens'] = [] - metrics['transcription'] = [] - metrics['losses'] = [] - metrics['block_nums'] = [] - metrics['trial_nums'] = [] - metrics['day_indicies'] = [] - - total_edit_distance = 0 - total_seq_length = 0 - - # Calculate PER for each specific day - day_per = {} - for d in range(len(self.args['dataset']['sessions'])): - if self.args['dataset']['dataset_probability_val'][d] == 1: - day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0} - - for i, batch in enumerate(loader): - - # Data is automatically moved to device by Accelerator - features = batch['input_features'] - labels = batch['seq_class_ids'] - n_time_steps = batch['n_time_steps'] - phone_seq_lens = batch['phone_seq_lens'] - day_indicies = batch['day_indicies'] - - # Determine if we should perform validation on this batch - day = day_indicies[0].item() - if 
self.args['dataset']['dataset_probability_val'][day] == 0: - if self.args['log_val_skip_logs']: - self.logger.info(f"Skipping validation on day {day}") - continue - - with torch.no_grad(): - - with self.autocast_context(): - features, n_time_steps = self.transform_data(features, n_time_steps, 'val') - - # Ensure proper dtype handling for TPU mixed precision - adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) - - # Ensure features tensor matches model parameter dtype for TPU compatibility - model_param = next(self.model.parameters()) if self.model is not None else None - if model_param is not None and features.dtype != model_param.dtype: - features = features.to(model_param.dtype) - - logits = self.model(features, day_indicies, None, False, 'inference') - - val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2) - loss = self.ctc_loss( - val_log_probs, - labels, - adjusted_lens, - phone_seq_lens, - ) - loss = torch.mean(loss) - - metrics['losses'].append(loss.cpu().detach().numpy()) - - # Calculate PER per day and also avg over entire validation set - batch_edit_distance = 0 - decoded_seqs = [] - for iterIdx in range(logits.shape[0]): - decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1) - decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1) - decoded_seq = decoded_seq.cpu().detach().numpy() - decoded_seq = np.array([i for i in decoded_seq if i != 0]) - - trueSeq = np.array( - labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach() - ) - - batch_edit_distance += F.edit_distance(decoded_seq, trueSeq) - - decoded_seqs.append(decoded_seq) - - day = batch['day_indicies'][0].item() - - day_per[day]['total_edit_distance'] += batch_edit_distance - day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item() - - - total_edit_distance += batch_edit_distance - total_seq_length += torch.sum(phone_seq_lens) - - # 
Record metrics - if return_logits: - metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32 - metrics['n_time_steps'].append(adjusted_lens.cpu().numpy()) - - if return_data: - metrics['input_features'].append(batch['input_features'].cpu().numpy()) - - metrics['decoded_seqs'].append(decoded_seqs) - metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy()) - metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy()) - metrics['transcription'].append(batch['transcriptions'].cpu().numpy()) - metrics['losses'].append(loss.detach().item()) - metrics['block_nums'].append(batch['block_nums'].numpy()) - metrics['trial_nums'].append(batch['trial_nums'].numpy()) - metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy()) - - if isinstance(total_seq_length, torch.Tensor): - total_length_value = float(total_seq_length.item()) - else: - total_length_value = float(total_seq_length) - - avg_PER = total_edit_distance / max(total_length_value, 1e-6) - - metrics['day_PERs'] = day_per - metrics['avg_PER'] = avg_PER - metrics['avg_loss'] = float(np.mean(metrics['losses'])) - - return metrics - - def inference(self, features, day_indicies, n_time_steps, mode='inference'): - ''' - TPU-compatible inference method for generating phoneme logits - ''' - self.model.eval() - - with torch.no_grad(): - with self.autocast_context(): - # Apply data transformations (no augmentation for inference) - features, n_time_steps = self.transform_data(features, n_time_steps, 'val') - - # Ensure features tensor matches model parameter dtype for TPU compatibility - if features.dtype != self.model_dtype: - features = features.to(self.model_dtype) - - # Get phoneme predictions - logits = self.model(features, day_indicies, None, False, mode) - - return logits - - def inference_batch(self, batch, mode='inference'): - ''' - Inference method for processing a full batch - ''' - self.model.eval() - - # Data is 
automatically moved to device by Accelerator - features = batch['input_features'] - day_indicies = batch['day_indicies'] - n_time_steps = batch['n_time_steps'] - - with torch.no_grad(): - with self.autocast_context(): - # Apply data transformations (no augmentation for inference) - features, n_time_steps = self.transform_data(features, n_time_steps, 'val') - - # Calculate adjusted sequence lengths for CTC with proper dtype handling - adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) - - # Ensure features tensor matches model parameter dtype for TPU compatibility - if features.dtype != self.model_dtype: - features = features.to(self.model_dtype) - - # Get phoneme predictions - logits = self.model(features, day_indicies, None, False, mode) - - return logits, adjusted_lens \ No newline at end of file diff --git a/model_training_nnn_tpu/setup_tensorflow_tpu.sh b/model_training_nnn_tpu/setup_tensorflow_tpu.sh deleted file mode 100644 index 19da6c3..0000000 --- a/model_training_nnn_tpu/setup_tensorflow_tpu.sh +++ /dev/null @@ -1,150 +0,0 @@ -#!/bin/bash -# Setup script for TensorFlow Brain-to-Text training on TPU v5e-8 -# -# Usage: ./setup_tensorflow_tpu.sh -# -# This script prepares the environment for training the brain-to-text model -# using TensorFlow on TPU v5e-8 hardware. - -set -e # Exit on any error - -echo "=== TensorFlow TPU v5e-8 Setup Script ===" -echo "Setting up environment for brain-to-text training..." - -# Check if we're in a TPU environment -if [[ -z "${TPU_NAME}" ]] && [[ -z "${COLAB_TPU_ADDR}" ]]; then - echo "Warning: TPU environment variables not detected." - echo "Make sure you're running on a TPU v5e-8 instance." -fi - -# Create conda environment for TensorFlow TPU -ENV_NAME="b2txt_tf" -echo "Creating conda environment: ${ENV_NAME}" - -if conda env list | grep -q "^${ENV_NAME} "; then - echo "Environment ${ENV_NAME} already exists. Activating..." 
- conda activate ${ENV_NAME} -else - echo "Creating new environment..." - conda create -n ${ENV_NAME} python=3.10 -y - conda activate ${ENV_NAME} -fi - -# Install TensorFlow with TPU support -echo "Installing TensorFlow with TPU support..." -pip install tensorflow[and-cuda]>=2.15.0 - -# Install additional requirements -echo "Installing additional requirements..." -pip install -r requirements_tf.txt - -# Set up TPU environment variables -echo "Configuring TPU environment variables..." - -# Create or update .bashrc with TPU optimizations -cat >> ~/.bashrc << 'EOF' - -# TPU v5e-8 Environment Variables -export TPU_ML_PLATFORM="TensorFlow" -export XLA_USE_BF16=1 -export TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" -export TPU_MEGACORE=1 -export LIBTPU_INIT_ARGS="--xla_tpu_spmd_threshold_for_allgather_cse=10000" - -# Disable TensorFlow warnings for cleaner output -export TF_CPP_MIN_LOG_LEVEL=2 - -# Memory optimizations -export TF_FORCE_GPU_ALLOW_GROWTH=true -export TF_GPU_THREAD_MODE=gpu_private - -EOF - -# Source the updated .bashrc -source ~/.bashrc - -# Test TPU connectivity -echo "Testing TPU connectivity..." 
-python3 << 'EOF' -import tensorflow as tf -print("TensorFlow version:", tf.__version__) - -try: - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - tf.config.experimental_connect_to_cluster(resolver) - tf.tpu.experimental.initialize_tpu_system(resolver) - strategy = tf.distribute.TPUStrategy(resolver) - print(f"TPU cluster initialized successfully!") - print(f"Number of TPU cores: {strategy.num_replicas_in_sync}") - print(f"TPU devices: {tf.config.list_logical_devices('TPU')}") -except Exception as e: - print(f"TPU initialization failed: {e}") - print("You may be running on CPU/GPU instead of TPU") - -# Test mixed precision -policy = tf.keras.mixed_precision.Policy('mixed_bfloat16') -tf.keras.mixed_precision.set_global_policy(policy) -print(f"Mixed precision policy: {policy.name}") -EOF - -# Verify data directory exists -DATA_DIR="../data/hdf5_data_final" -if [ -d "$DATA_DIR" ]; then - echo "Data directory found: $DATA_DIR" - # Count available sessions - SESSION_COUNT=$(ls -d $DATA_DIR/t*.20* 2>/dev/null | wc -l) - echo "Available sessions: $SESSION_COUNT" -else - echo "Warning: Data directory not found at $DATA_DIR" - echo "Please ensure the dataset is available before training." -fi - -# Create output directories -echo "Creating output directories..." -mkdir -p trained_models/tensorflow_tpu -mkdir -p logs/tensorflow_tpu -mkdir -p eval_output - -# Make scripts executable -echo "Setting script permissions..." 
-chmod +x train_model_tf.py -chmod +x evaluate_model_tf.py - -# Display system information -echo "=== System Information ===" -echo "Python version: $(python --version)" -echo "Conda environment: $CONDA_DEFAULT_ENV" -echo "Available memory: $(free -h | grep '^Mem:' | awk '{print $7}')" -echo "CPU cores: $(nproc)" - -# Check for GPU/TPU -echo "=== Hardware Information ===" -if nvidia-smi &> /dev/null; then - echo "NVIDIA GPUs detected:" - nvidia-smi --list-gpus -else - echo "No NVIDIA GPUs detected" -fi - -if [[ -n "${TPU_NAME}" ]]; then - echo "TPU Name: $TPU_NAME" -elif [[ -n "${COLAB_TPU_ADDR}" ]]; then - echo "Colab TPU Address: $COLAB_TPU_ADDR" -else - echo "No TPU environment variables detected" -fi - -echo "" -echo "=== Setup Complete ===" -echo "Environment '$ENV_NAME' is ready for TensorFlow TPU training." -echo "" -echo "To activate the environment:" -echo " conda activate $ENV_NAME" -echo "" -echo "To start training:" -echo " python train_model_tf.py --config_path rnn_args.yaml" -echo "" -echo "To run evaluation:" -echo " python evaluate_model_tf.py --model_path path/to/checkpoint --config_path rnn_args.yaml" -echo "" -echo "For more options, use --help with any script." 
\ No newline at end of file diff --git a/model_training_nnn_tpu/tpu_memory_monitor.py b/model_training_nnn_tpu/tpu_memory_monitor.py deleted file mode 100644 index 970d6f3..0000000 --- a/model_training_nnn_tpu/tpu_memory_monitor.py +++ /dev/null @@ -1,236 +0,0 @@ -#!/usr/bin/env python3 -""" -TPU内存监控工具 - 专门用于训练过程 -解决tf.config.experimental.get_memory_info()在TPU上无法工作的问题 -""" - -import tensorflow as tf -import time -import psutil -import os - -class TPUMemoryMonitor: - """TPU内存监控类""" - - def __init__(self): - self.tpu_devices = tf.config.list_logical_devices('TPU') - self.baseline_memory = None - self.peak_allocations = {} - - def get_tpu_status(self) -> str: - """获取TPU状态 - 实用版本,不依赖get_memory_info""" - try: - if not self.tpu_devices: - return "TPU: No devices" - - num_cores = len(self.tpu_devices) - - # 测试TPU响应性 - try: - with tf.device('/TPU:0'): - test_tensor = tf.constant([1.0, 2.0, 3.0]) - result = tf.reduce_sum(test_tensor) - _ = result.numpy() # 强制执行 - activity = "active" - except Exception: - activity = "inactive" - - # 获取主机内存作为参考 - try: - memory = psutil.virtual_memory() - host_mem = f"Host:{memory.percent:.1f}%" - except: - host_mem = "Host:unknown" - - return f"TPU: {num_cores}cores {activity} {host_mem}" - - except Exception as e: - return f"TPU: error({str(e)[:20]})" - - def estimate_tensor_memory(self, tensor_shape, dtype=tf.float32): - """估算张量内存使用量""" - if dtype == tf.float32: - bytes_per_element = 4 - elif dtype == tf.float16 or dtype == tf.bfloat16: - bytes_per_element = 2 - elif dtype == tf.int32: - bytes_per_element = 4 - elif dtype == tf.int64: - bytes_per_element = 8 - else: - bytes_per_element = 4 # 默认 - - total_elements = 1 - for dim in tensor_shape: - total_elements *= dim - - total_bytes = total_elements * bytes_per_element - return total_bytes / (1024 * 1024) # 返回MB - - def track_allocation(self, name: str, tensor_shape, dtype=tf.float32): - """跟踪内存分配""" - mb = self.estimate_tensor_memory(tensor_shape, dtype) - self.peak_allocations[name] = 
self.peak_allocations.get(name, 0) + mb - return mb - - def get_allocation_summary(self) -> str: - """获取分配汇总""" - if not self.peak_allocations: - return "No allocations tracked" - - total_mb = sum(self.peak_allocations.values()) - top_3 = sorted(self.peak_allocations.items(), key=lambda x: x[1], reverse=True)[:3] - - summary = f"Tracked:{total_mb:.1f}MB " - summary += f"Top:({top_3[0][0]}:{top_3[0][1]:.1f}MB)" - - return summary - - def test_memory_allocation_across_cores(self): - """测试8个核心的内存分配""" - print("🧪 测试所有TPU核心内存分配") - print("=" * 40) - - allocations_per_core = [] - - for i, device in enumerate(self.tpu_devices): - print(f"核心 {i+1}: {device.name}") - - try: - with tf.device(device.name): - # 创建不同大小的测试张量 - test_sizes = [ - ([1000, 1000], "1K×1K"), - ([3000, 3000], "3K×3K"), - ([5000, 5000], "5K×5K"), - ([7000, 7000], "7K×7K"), - ] - - core_total = 0 - successful_allocs = [] - - for shape, desc in test_sizes: - try: - tensor = tf.ones(shape, dtype=tf.float32) - mb = self.estimate_tensor_memory(shape) - core_total += mb - successful_allocs.append(f"{desc}({mb:.1f}MB)") - - # 实际使用张量防止被优化 - _ = tf.reduce_mean(tensor) - - except Exception as e: - print(f" {desc} 失败: {str(e)[:30]}") - break - - allocations_per_core.append(core_total) - print(f" 成功分配: {' + '.join(successful_allocs)}") - print(f" 核心总计: {core_total:.1f}MB") - - except Exception as e: - print(f" 核心{i+1}失败: {e}") - allocations_per_core.append(0) - - # 汇总结果 - total_all_cores = sum(allocations_per_core) - avg_per_core = total_all_cores / len(self.tpu_devices) if self.tpu_devices else 0 - - print(f"\n📊 汇总结果:") - print(f" 总分配: {total_all_cores:.1f}MB ({total_all_cores/1024:.2f}GB)") - print(f" 平均每核: {avg_per_core:.1f}MB ({avg_per_core/1024:.2f}GB)") - - # 推测内存配置 - if avg_per_core > 8000: # > 8GB - print(" 推测: 每核心≥16GB (高端配置)") - elif avg_per_core > 4000: # > 4GB - print(" 推测: 每核心8-16GB (标准配置)") - elif avg_per_core > 1000: # > 1GB - print(" 推测: 每核心2-8GB (受限或共享)") - else: - print(" 推测: 每核心<2GB (严重受限)") - - 
return allocations_per_core - -def test_training_memory_pattern(): - """测试模拟训练的内存模式""" - print("\n🏋️ 模拟训练内存模式测试") - print("=" * 30) - - monitor = TPUMemoryMonitor() - - # 模拟典型的brain-to-text模型内存使用 - with tf.device('/TPU:0'): - print("创建模拟模型组件...") - - # 1. 输入数据 (batch_size=32, seq_len=1000, features=512) - batch_size, seq_len, features = 32, 1000, 512 - input_data = tf.random.normal([batch_size, seq_len, features]) - input_mb = monitor.track_allocation("input_data", [batch_size, seq_len, features]) - print(f" 输入数据: {input_mb:.1f}MB") - - # 2. GRU权重 (假设3层, 每层256单元) - n_layers, n_units = 3, 256 - for layer in range(n_layers): - # GRU有3个门,每个门需要权重矩阵 - weight_shape = [features if layer == 0 else n_units, n_units * 3] - weights = tf.random.normal(weight_shape) - weight_mb = monitor.track_allocation(f"gru_layer_{layer}", weight_shape) - print(f" GRU层{layer+1}权重: {weight_mb:.1f}MB") - - # 3. 输出投影层 (n_units -> n_classes=41) - n_classes = 41 - output_weights = tf.random.normal([n_units, n_classes]) - output_mb = monitor.track_allocation("output_projection", [n_units, n_classes]) - print(f" 输出投影: {output_mb:.1f}MB") - - # 4. 中间激活值 (前向传播) - hidden_states = tf.random.normal([batch_size, seq_len, n_units]) - hidden_mb = monitor.track_allocation("hidden_states", [batch_size, seq_len, n_units]) - print(f" 隐藏状态: {hidden_mb:.1f}MB") - - # 5. 
梯度 (反向传播时会翻倍内存) - total_params_mb = sum([v for k, v in monitor.peak_allocations.items() if 'layer' in k or 'projection' in k]) - gradient_mb = total_params_mb # 梯度内存约等于参数内存 - print(f" 梯度内存: {gradient_mb:.1f}MB (估算)") - - print(f"\n模型总内存估算: {monitor.get_allocation_summary()}") - - # 实际执行一些操作确保内存被分配 - result = tf.reduce_mean(input_data) + tf.reduce_mean(hidden_states) - print(f"验证计算结果: {result.numpy():.4f}") - -if __name__ == "__main__": - print("🚀 TPU内存监控工具启动") - - monitor = TPUMemoryMonitor() - - # 基础状态检查 - print(f"当前TPU状态: {monitor.get_tpu_status()}") - - # 测试所有核心 - print("\n" + "="*50) - core_allocations = monitor.test_memory_allocation_across_cores() - - # 训练内存模式测试 - print("\n" + "="*50) - test_training_memory_pattern() - - print(f"\n🎯 关键发现:") - if core_allocations: - max_core = max(core_allocations) - min_core = min([x for x in core_allocations if x > 0]) - print(f" 最大单核分配: {max_core:.1f}MB") - print(f" 最小单核分配: {min_core:.1f}MB") - - if max_core > 9000: # 你之前测试到9.4GB - print(" ✅ 内存充足,可支持大模型训练") - elif max_core > 5000: - print(" ⚠️ 内存中等,建议优化模型大小") - else: - print(" ❌ 内存不足,需要大幅减少模型参数") - - print(f"\n💡 针对你的训练卡顿问题:") - print(f" - SetPriority错误通常是XLA编译问题,不是内存问题") - print(f" - 你的9.4GB测试说明TPU内存工作正常") - print(f" - 建议检查模型是否有导致XLA编译卡顿的操作") - print(f" - 考虑使用更简单的操作或关闭某些XLA优化") \ No newline at end of file diff --git a/model_training_nnn_tpu/train_model.py b/model_training_nnn_tpu/train_model.py deleted file mode 100644 index 81390c2..0000000 --- a/model_training_nnn_tpu/train_model.py +++ /dev/null @@ -1,25 +0,0 @@ -import argparse -from omegaconf import OmegaConf -from rnn_trainer import BrainToTextDecoder_Trainer - -def main(): - parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model') - parser.add_argument('--config_path', default='rnn_args.yaml', - help='Path to configuration file (default: rnn_args.yaml)') - - args = parser.parse_args() - - # Load configuration - config = OmegaConf.load(args.config_path) - - # Initialize trainer - trainer = 
BrainToTextDecoder_Trainer(config) - - # Start training - trainer.train() - - print("Training completed successfully!") - print(f"Best validation PER: {trainer.best_val_PER:.5f}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/model_training_nnn_tpu/train_model_tf.py b/model_training_nnn_tpu/train_model_tf.py index 1e10e6f..d0f397c 100644 --- a/model_training_nnn_tpu/train_model_tf.py +++ b/model_training_nnn_tpu/train_model_tf.py @@ -8,7 +8,7 @@ It provides the same functionality as the PyTorch version but with TensorFlow operations optimized for TPU performance. Usage: - python train_model_tf.py --config_path rnn_args.yaml + python train_model_tf.py -config_path rnn_args.yaml Requirements: - TensorFlow >= 2.15.0