#!/usr/bin/env python3
"""
Simple TPU model training and test script.

A simplified brain-to-text setup that trains a small GRU model on randomly
generated data through PyTorch/XLA.
"""

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, Any, Tuple

# XLA environment variables. They are assigned before torch_xla is imported
# below; keep that ordering when editing this header.
# os.cpu_count() may return None, which would otherwise end up as the literal
# text "None" inside the flag values, so fall back to 1.
_CPU_COUNT = os.cpu_count() or 1
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={_CPU_COUNT}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(_CPU_COUNT)
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl


class SimpleBrainToTextModel(nn.Module):
    """Simplified brain-to-text model: linear projection -> GRU -> linear head.

    Args:
        input_features: Size of the last dimension of the input tensor.
        hidden_size: Width of the input projection and of every GRU layer.
        num_classes: Number of output classes per time step.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        super().__init__()

        # Input processing
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)

        # Recurrent stack. nn.GRU only applies dropout between layers, so it
        # is switched off for a single-layer stack.
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0
        )

        # Output head
        self.output_proj = nn.Linear(hidden_size, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise parameters by name.

        GRU weights get orthogonal init, all other weights get Xavier-uniform
        init, and every bias is zeroed.
        """
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'gru' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        """Run the model.

        Args:
            x: Tensor of shape (batch_size, seq_len, input_features).

        Returns:
            Logits of shape (batch_size, seq_len, num_classes).
        """
        # Input projection
        x = torch.relu(self.input_proj(x))
        x = self.input_dropout(x)

        # Recurrent stack; the final hidden state is not needed.
        output, _ = self.gru(x)

        # Per-time-step classification head
        return self.output_proj(output)


class SimpleDataGenerator:
    """Generator of random batches that stand in for neural-signal data.

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Number of time steps per sequence.
        input_features: Feature dimension of each time step.
        num_classes: Labels are drawn from ``[0, num_classes)``.
    """

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device) -> Dict[str, torch.Tensor]:
        """Create one random batch directly on ``device``.

        Returns:
            A dict with:
              * ``'features'``: float32, (batch_size, seq_len, input_features)
              * ``'labels'``: integer class ids, (batch_size, seq_len)
              * ``'seq_lengths'``: integers in [seq_len // 2, seq_len],
                (batch_size,)

            ``'seq_lengths'`` is not read by the trainer in this file, so
            loss and accuracy there cover every time step.
        """
        # Simulated neural-signal features
        features = torch.randn(
            self.batch_size, self.seq_len, self.input_features,
            device=device, dtype=torch.float32
        )

        # Simulated labels (phoneme sequence)
        labels = torch.randint(
            0, self.num_classes,
            (self.batch_size, self.seq_len),
            device=device
        )

        # Simulated valid lengths (upper bound of randint is exclusive)
        seq_lengths = torch.randint(
            self.seq_len // 2, self.seq_len + 1,
            (self.batch_size,),
            device=device
        )

        return {
            'features': features,
            'labels': labels,
            'seq_lengths': seq_lengths
        }


class SimpleTpuTrainer:
    """Minimal trainer for running the model on an XLA (TPU) device.

    Args:
        model: Module to train; expected to be on ``device`` already.
        device: Device on which batches are generated.
        learning_rate: Adam learning rate.
        data_generator: Optional batch source exposing
            ``generate_batch(device)``. Defaults to ``SimpleDataGenerator()``,
            whose default sizes match the model defaults in this file.
    """

    def __init__(self, model, device, learning_rate=0.001, data_generator=None):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # Labels equal to -1 are ignored by the loss. The built-in generator
        # never emits -1, so this only matters for custom data generators.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

        # Batch source
        self.data_generator = (
            SimpleDataGenerator() if data_generator is None else data_generator
        )

        # Training statistics
        self.step = 0  # optimizer steps completed, restored by load_checkpoint
        self.best_loss = float('inf')

    @property
    def _on_xla(self) -> bool:
        """Whether ``self.device`` looks like an XLA device (checked by name)."""
        return 'xla' in str(self.device)

    def _sync(self, wait=False):
        """Cut the pending XLA graph; with ``wait`` also block until the device is idle.

        Does nothing when the trainer is not on an XLA device.
        """
        if not self._on_xla:
            return
        xm.mark_step()
        if wait:
            xm.wait_device_ops()

    def train_step(self, batch) -> float:
        """Run one optimisation step and return the loss as a Python float.

        Args:
            batch: Dict with ``'features'`` and ``'labels'`` as produced by
                ``SimpleDataGenerator.generate_batch``.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        features = batch['features']
        labels = batch['labels']
        logits = self.model(features)

        # CrossEntropyLoss wants (N, C) logits and (N,) targets, so batch and
        # time are flattened together.
        num_classes = logits.shape[-1]
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            labels.reshape(-1)
        )

        # Backward pass
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Parameter update
        self.optimizer.step()

        # Cut the lazy graph once per step, before reading the loss on the
        # host. Otherwise the pending graph keeps growing across steps and
        # each .item() has to evaluate it again.
        self._sync()

        return loss.item()

    def evaluate_step(self, batch) -> Tuple[float, float]:
        """Evaluate one batch without gradients.

        Returns:
            ``(loss, accuracy)`` as Python floats. Accuracy is the fraction of
            all (batch, time) positions whose argmax equals the label.
        """
        self.model.eval()

        with torch.no_grad():
            features = batch['features']
            labels = batch['labels']

            logits = self.model(features)

            # Loss
            num_classes = logits.shape[-1]
            loss = self.criterion(
                logits.reshape(-1, num_classes),
                labels.reshape(-1)
            )

            # Accuracy
            predictions = torch.argmax(logits, dim=-1)
            correct = (predictions == labels).float()
            accuracy = correct.mean()

        # Materialise both scalars in one graph execution before the host reads.
        self._sync()

        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500) -> Dict[str, Any]:
        """Train on freshly generated batches.

        Args:
            num_steps: Number of optimisation steps to run.
            eval_every: Evaluate and log every this many steps (step 0
                included). A falsy value disables evaluation.
            save_every: Write ``checkpoint_step_{step}.pt`` every this many
                steps (step 0 excluded). A falsy value disables saving.

        Returns:
            Dict with ``'train_losses'``, ``'eval_losses'``,
            ``'eval_accuracies'`` (lists of floats) and ``'total_time'``
            (seconds).
        """
        print(f"🚀 开始TPU训练 - 设备: {self.device}")
        print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}")

        train_losses = []
        eval_losses = []
        eval_accuracies = []

        start_time = time.time()

        for step in range(num_steps):
            # Fresh training batch
            train_batch = self.data_generator.generate_batch(self.device)

            # One optimisation step (synchronises with the device internally)
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)
            self.step += 1  # keeps the value stored in checkpoints meaningful

            # Evaluation
            if eval_every and step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)

                # Drain the device so the elapsed time is accurate
                self._sync(wait=True)

                current_time = time.time()
                elapsed = current_time - start_time

                print(f"步骤 {step:4d}/{num_steps} | "
                      f"训练损失: {train_loss:.4f} | "
                      f"验证损失: {eval_loss:.4f} | "
                      f"验证准确率: {eval_acc:.4f} | "
                      f"耗时: {elapsed:.1f}s")

                # Track the best evaluation loss (nothing is written to disk here)
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}")

            # Periodic checkpoint
            if save_every and step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")

        # Final synchronisation
        self._sync(wait=True)

        total_time = time.time() - start_time
        print(f"\n✅ 训练完成!")
        print(f"⏱️ 总耗时: {total_time:.1f}秒")
        # The lists are empty when num_steps == 0 or evaluation is disabled.
        if train_losses:
            print(f"🎯 最终训练损失: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}")
            print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}")

        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Write model state, optimizer state, step and best loss to ``filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }

        if self._on_xla:
            # xm.save moves the data to CPU before serialising it.
            # xm.send_cpu_data_to_device goes the other way (CPU -> device)
            # and is not suitable for this.
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)

        print(f"💾 保存检查点: {filename}")

    def load_checkpoint(self, filename):
        """Restore model state, optimizer state, step and best loss from ``filename``.

        The file is deserialised onto the CPU; ``load_state_dict`` then copies
        the values into the existing model and optimizer.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename, map_location='cpu')

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']

        print(f"📂 加载检查点: {filename}")
        print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}")


def test_simple_inference():
    """Smoke-test one forward pass of a freshly initialised model on the XLA device.

    Returns:
        True once the forward pass has completed.
    """
    print("\n🧪 测试简单推理...")

    device = xm.xla_device()

    # Model
    model = SimpleBrainToTextModel().to(device)

    # Test data
    batch_size = 4
    seq_len = 50
    test_input = torch.randn(batch_size, seq_len, 512, device=device)

    # Inference. Only a single, first call is timed, so the figure presumably
    # includes one-off graph compilation rather than steady-state latency.
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output = model(test_input)
        xm.mark_step()
        xm.wait_device_ops()
        inference_time = time.time() - start_time

    print(f"✅ 推理完成!")
    print(f" 输入形状: {test_input.shape}")
    print(f" 输出形状: {output.shape}")
    print(f" 推理时间: {inference_time:.4f}秒")

    return True


def main():
    """Script entry point: train, save the final checkpoint, then run the inference smoke test.

    Any exception is reported with its traceback and not re-raised.
    """
    print("=" * 60)
    print("🧠 简单TPU大脑到文本模型训练")
    print("=" * 60)

    try:
        # Acquire the XLA device
        device = xm.xla_device()
        print(f"📱 使用设备: {device}")

        # Model
        model = SimpleBrainToTextModel(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3
        ).to(device)

        # Trainer
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)

        # Training
        trainer.train(
            num_steps=1000,
            eval_every=100,
            save_every=500
        )

        # Final checkpoint
        trainer.save_checkpoint("final_simple_model.pt")

        # Inference smoke test
        test_simple_inference()

        print("\n🎉 所有测试完成!")

    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()