From b466e9746320b289b9142cb676f257fcb2cd68e5 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 15 Oct 2025 15:22:13 +0800
Subject: [PATCH] TPU test

---
 model_training_nnn_tpu/minimal_tpu_test.py | 194 +++++++++++
 model_training_nnn_tpu/mnist_tpu_simple.py | 253 ++++++++++++++
 model_training_nnn_tpu/quick_tpu_test.py   | 129 --------
 model_training_nnn_tpu/simple_tpu_model.py | 367 ---------------------
 4 files changed, 447 insertions(+), 496 deletions(-)
 create mode 100644 model_training_nnn_tpu/minimal_tpu_test.py
 create mode 100644 model_training_nnn_tpu/mnist_tpu_simple.py
 delete mode 100644 model_training_nnn_tpu/quick_tpu_test.py
 delete mode 100644 model_training_nnn_tpu/simple_tpu_model.py

diff --git a/model_training_nnn_tpu/minimal_tpu_test.py b/model_training_nnn_tpu/minimal_tpu_test.py
new file mode 100644
index 0000000..2d8e60f
--- /dev/null
+++ b/model_training_nnn_tpu/minimal_tpu_test.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+"""
+Minimal TPU test - avoids bf16 issues entirely.
+Uses float32 only, with the most basic operations.
+"""
+
+import os
+import time
+import torch
+import torch.nn as nn
+
+# Do not set any bf16-related environment variables;
+# only enable the most basic XLA optimization.
+os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'
+
+# Make sure bf16 is not used
+if 'XLA_USE_BF16' in os.environ:
+    del os.environ['XLA_USE_BF16']
+
+import torch_xla.core.xla_model as xm
+
+
+def test_basic_operations():
+    """Test the most basic TPU operations"""
+    print("🚀 Testing basic TPU operations...")
+
+    try:
+        device = xm.xla_device()
+        print(f"📱 Device: {device}")
+
+        # Test 1: basic tensor operations
+        print("🔧 Testing basic tensor operations...")
+        a = torch.randn(4, 4, device=device, dtype=torch.float32)
+        b = torch.randn(4, 4, device=device, dtype=torch.float32)
+        c = a + b
+
+        print(f"   a.shape: {a.shape}, dtype: {a.dtype}")
+        print(f"   b.shape: {b.shape}, dtype: {b.dtype}")
+        print(f"   c.shape: {c.shape}, dtype: {c.dtype}")
+
+        # Synchronize
+        xm.mark_step()
+        xm.wait_device_ops()
+        print("✅ Basic tensor operations succeeded")
+
+        # Test 2: matrix multiplication
+        print("🔧 Testing matrix multiplication...")
+        d = torch.mm(a, b)
+        xm.mark_step()
+        xm.wait_device_ops()
+        print(f"   Matmul result shape: {d.shape}, dtype: {d.dtype}")
+        print("✅ Matrix multiplication succeeded")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Basic operations failed: {e}")
+        return False
+
+
+def test_simple_model():
+    """Test the simplest possible model"""
+    print("\n🧠 Testing a minimal model...")
+
+    try:
+        device = xm.xla_device()
+
+        # Tiny linear model
+        model = nn.Sequential(
+            nn.Linear(10, 5),
+            nn.ReLU(),
+            nn.Linear(5, 2)
+        ).to(device)
+
+        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters())}")
+
+        # Make sure all parameters are float32
+        for param in model.parameters():
+            param.data = param.data.to(torch.float32)
+
+        # Create input data - explicitly float32
+        x = torch.randn(8, 10, device=device, dtype=torch.float32)
+
+        print(f"📥 Input: shape={x.shape}, dtype={x.dtype}")
+
+        # Forward pass
+        with torch.no_grad():
+            output = model(x)
+            xm.mark_step()
+            xm.wait_device_ops()
+
+        print(f"📤 Output: shape={output.shape}, dtype={output.dtype}")
+        print("✅ Minimal model forward pass succeeded")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Minimal model failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_training_step():
+    """Test the simplest possible training step"""
+    print("\n🎯 Testing a minimal training step...")
+
+    try:
+        device = xm.xla_device()
+
+        # Tiny model
+        model = nn.Linear(10, 1).to(device)
+
+        # Make sure the weights are float32
+        for param in model.parameters():
+            param.data = param.data.to(torch.float32)
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        criterion = nn.MSELoss()
+
+        # Create data - explicitly float32
+        x = torch.randn(4, 10, device=device, dtype=torch.float32)
+        y = torch.randn(4, 1, device=device, dtype=torch.float32)
+
+        print(f"📥 Input: {x.shape}, {x.dtype}")
+        print(f"📥 Labels: {y.shape}, {y.dtype}")
+
+        # One training step
+        optimizer.zero_grad()
+        output = model(x)
+        loss = criterion(output, y)
+        loss.backward()
+        optimizer.step()
+
+        # Synchronize
+        xm.mark_step()
+        xm.wait_device_ops()
+
+        print(f"🎯 Loss: {loss.item():.4f}")
+        print("✅ Training step succeeded")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Training step failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def main():
+    """Entry point"""
+    print("=" * 50)
+    print("🔬 Minimal TPU test (float32 only)")
+    print("=" * 50)
+
+    all_passed = True
+
+    # Test 1: basic operations
+    if test_basic_operations():
+        print("1️⃣ Basic operations ✅")
+    else:
+        print("1️⃣ Basic operations ❌")
+        all_passed = False
+
+    # Test 2: minimal model
+    if test_simple_model():
+        print("2️⃣ Minimal model ✅")
+    else:
+        print("2️⃣ Minimal model ❌")
+        all_passed = False
+
+    # Test 3: training step
+    if test_training_step():
+        print("3️⃣ Training step ✅")
+    else:
+        print("3️⃣ Training step ❌")
+        all_passed = False
+
+    print("=" * 50)
+
+    if all_passed:
+        print("🎉 All tests passed! The TPU is working correctly")
+        print("💡 You can now try more complex models")
+    else:
+        print("❌ Some tests failed")
+        print("💡 Suggestions:")
+        print("   1. Check whether TPU resources are available")
+        print("   2. Confirm torch_xla is installed correctly")
+        print("   3. Restart the runtime to clear state")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
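
Note: before running the float32 tests above, it can help to confirm that the XLA device is really a TPU rather than the silent CPU fallback. A minimal sketch, assuming the torch_xla helpers get_xla_supported_devices and xla_device_hw (availability may vary across torch_xla versions):

    import torch_xla.core.xla_model as xm

    # List the XLA devices visible to this process; an empty list for
    # devkind='TPU' usually means the script fell back to the XLA CPU backend.
    print(xm.get_xla_supported_devices(devkind='TPU'))

    device = xm.xla_device()
    # xla_device_hw returns a hardware string such as 'TPU' or 'CPU'.
    assert xm.xla_device_hw(device) == 'TPU', "not running on a TPU"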
diff --git a/model_training_nnn_tpu/mnist_tpu_simple.py b/model_training_nnn_tpu/mnist_tpu_simple.py
new file mode 100644
index 0000000..f1c8c29
--- /dev/null
+++ b/model_training_nnn_tpu/mnist_tpu_simple.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+"""
+Ultra-simple MNIST TPU training - sidesteps mixed-precision issues entirely.
+Uses float32 only to ensure stable runs.
+"""
+
+import os
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+
+# Clear any environment variables that could trigger bf16 issues
+for key in ['XLA_USE_BF16', 'XLA_DOWNCAST_BF16']:
+    if key in os.environ:
+        del os.environ[key]
+
+# Only set the most basic XLA optimizations
+os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=false'
+
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+
+
+class SimpleMNISTNet(nn.Module):
+    """Very simple MNIST classifier"""
+
+    def __init__(self):
+        super(SimpleMNISTNet, self).__init__()
+        self.flatten = nn.Flatten()
+        self.fc1 = nn.Linear(28 * 28, 128)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(128, 64)
+        self.relu2 = nn.ReLU()
+        self.fc3 = nn.Linear(64, 10)
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.relu1(self.fc1(x))
+        x = self.relu2(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+def get_mnist_data(batch_size=64):
+    """Load the MNIST datasets"""
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    train_dataset = torchvision.datasets.MNIST(
+        root='./mnist_data',
+        train=True,
+        download=True,
+        transform=transform
+    )
+
+    test_dataset = torchvision.datasets.MNIST(
+        root='./mnist_data',
+        train=False,
+        download=True,
+        transform=transform
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=0
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0
+    )
+
+    return train_loader, test_loader
+
+
+def train_mnist():
+    """Train the MNIST model"""
+    print("🚀 Starting MNIST TPU training...")
+
+    # Get the device
+    device = xm.xla_device()
+    print(f"📱 Device: {device}")
+
+    # Create the model
+    model = SimpleMNISTNet().to(device)
+
+    # Make sure all parameters are float32
+    for param in model.parameters():
+        param.data = param.data.to(torch.float32)
+
+    print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+    # Loss function and optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+    # Load the data
+    print("📥 Loading MNIST data...")
+    train_loader, test_loader = get_mnist_data(batch_size=64)
+
+    # Use the XLA parallel loader
+    train_device_loader = pl.MpDeviceLoader(train_loader, device)
+
+    print("🎯 Starting training...")
+
+    model.train()
+    start_time = time.time()
+
+    total_loss = 0.0
+    correct = 0
+    total = 0
+    max_batches = 100  # Train only 100 batches for a quick sanity check
+
+    for batch_idx, (data, target) in enumerate(train_device_loader):
+        if batch_idx >= max_batches:
+            break
+
+        # Make sure the dtypes are correct
+        data = data.to(torch.float32)
+        target = target.to(torch.long)
+
+        # Forward pass
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+
+        # Backward pass
+        loss.backward()
+        optimizer.step()
+
+        # Statistics
+        total_loss += loss.item()
+        pred = output.argmax(dim=1)
+        correct += pred.eq(target).sum().item()
+        total += target.size(0)
+
+        # Synchronize every 10 batches
+        if batch_idx % 10 == 0:
+            xm.mark_step()
+            current_acc = 100. * correct / total
+            avg_loss = total_loss / (batch_idx + 1)
+
+            print(f'Batch {batch_idx:3d}/{max_batches} | '
+                  f'Loss: {avg_loss:.4f} | '
+                  f'Accuracy: {current_acc:.2f}%')
+
+    # Final synchronization
+    xm.mark_step()
+    xm.wait_device_ops()
+
+    train_time = time.time() - start_time
+    final_acc = 100. * correct / total
+    final_loss = total_loss / min(batch_idx + 1, max_batches)
+
+    print(f"\n✅ Training finished!")
+    print(f"⏱️ Training time: {train_time:.2f}s")
+    print(f"🎯 Final loss: {final_loss:.4f}")
+    print(f"🎯 Training accuracy: {final_acc:.2f}%")
+
+    return model, final_loss, final_acc
+
+
+def test_mnist(model):
+    """Evaluate the MNIST model"""
+    print("\n🧪 Starting evaluation...")
+
+    device = xm.xla_device()
+    _, test_loader = get_mnist_data(batch_size=64)
+
+    test_device_loader = pl.MpDeviceLoader(test_loader, device)
+
+    model.eval()
+    correct = 0
+    total = 0
+    max_test_batches = 50  # Evaluate only 50 batches
+
+    start_time = time.time()
+
+    with torch.no_grad():
+        for batch_idx, (data, target) in enumerate(test_device_loader):
+            if batch_idx >= max_test_batches:
+                break
+
+            # Make sure the dtypes are correct
+            data = data.to(torch.float32)
+            target = target.to(torch.long)
+
+            output = model(data)
+            pred = output.argmax(dim=1)
+            correct += pred.eq(target).sum().item()
+            total += target.size(0)
+
+            if batch_idx % 10 == 0:
+                xm.mark_step()
+
+    xm.mark_step()
+    xm.wait_device_ops()
+
+    test_time = time.time() - start_time
+    accuracy = 100. * correct / total
+
+    print(f"✅ Evaluation finished!")
+    print(f"⏱️ Evaluation time: {test_time:.2f}s")
+    print(f"🎯 Test accuracy: {accuracy:.2f}%")
+
+    return accuracy
+
+
+def main():
+    """Entry point"""
+    print("=" * 60)
+    print("🔢 Ultra-simple MNIST TPU training (float32 only)")
+    print("=" * 60)
+
+    try:
+        # Train
+        model, train_loss, train_acc = train_mnist()
+
+        # Evaluate
+        test_acc = test_mnist(model)
+
+        # Save the model
+        print("\n💾 Saving the model...")
+        model_cpu = model.cpu()
+        torch.save(model_cpu.state_dict(), 'mnist_simple_model.pth')
+        print("✅ Model saved")
+
+        print("\n🎉 All done!")
+        print(f"📊 Training accuracy: {train_acc:.2f}%")
+        print(f"📊 Test accuracy: {test_acc:.2f}%")
+
+        if train_acc > 80 and test_acc > 75:
+            print("✅ Model trained successfully!")
+        else:
+            print("⚠️ Model performance is mediocre, but the TPU works")
+
+    except Exception as e:
+        print(f"❌ Training failed: {e}")
+        import traceback
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
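
Note: in the training loop above, loss.item() and pred.eq(target).sum().item() run on every batch; on XLA each .item() call forces a device-to-host transfer and therefore a graph execution, which partially defeats batching work behind xm.mark_step() every 10 steps. A sketch of one common alternative (a hypothetical restructuring, not part of this patch) that keeps running statistics on-device and converts to Python numbers only when logging:

    # Accumulate statistics as XLA tensors instead of Python floats.
    total_loss = torch.zeros((), device=device)
    correct = torch.zeros((), device=device)

    for batch_idx, (data, target) in enumerate(train_device_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.detach()
        correct += (output.argmax(dim=1) == target).sum()

        if batch_idx % 10 == 0:
            xm.mark_step()  # materialize the pending graph once per 10 batches
            # .item() transfers to the host only at logging time
            print(f"avg loss: {total_loss.item() / (batch_idx + 1):.4f}")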
diff --git a/model_training_nnn_tpu/quick_tpu_test.py b/model_training_nnn_tpu/quick_tpu_test.py
deleted file mode 100644
index 2cb93fb..0000000
--- a/model_training_nnn_tpu/quick_tpu_test.py
+++ /dev/null
@@ -1,129 +0,0 @@
-#!/usr/bin/env python3
-"""
-Quick TPU test script - verifies that a simple model can run on the TPU
-"""
-
-import os
-import time
-import torch
-import torch.nn as nn
-
-# Set environment variables
-os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
-os.environ['XLA_USE_BF16'] = '1'
-
-import torch_xla.core.xla_model as xm
-
-def quick_test():
-    """Quickly check whether the TPU works"""
-    print("🚀 Starting quick TPU test...")
-
-    try:
-        # Get the TPU device
-        device = xm.xla_device()
-        print(f"📱 TPU device: {device}")
-
-        # Create a simple model
-        model = nn.Sequential(
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.GRU(256, 128, batch_first=True),
-            nn.Linear(128, 41)
-        ).to(device)
-
-        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-        # Create test data
-        x = torch.randn(8, 50, 512, device=device)
-        print(f"📥 Input shape: {x.shape}")
-
-        # Test the forward pass
-        print("🔄 Testing forward pass...")
-        start_time = time.time()
-
-        with torch.no_grad():
-            if hasattr(model, '__getitem__'):
-                # For a Sequential model, handle the GRU layer manually
-                x_proj = model[1](model[0](x))  # Linear + ReLU
-                gru_out, _ = model[2](x_proj)  # GRU
-                output = model[3](gru_out)  # Final Linear
-            else:
-                output = model(x)
-
-        # Synchronize TPU operations
-        xm.mark_step()
-        xm.wait_device_ops()
-
-        forward_time = time.time() - start_time
-        print(f"✅ Forward pass done! Took: {forward_time:.3f}s")
-        print(f"📤 Output shape: {output.shape}")
-
-        # Test the backward pass
-        print("🔄 Testing backward pass...")
-        model.train()
-        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-
-        start_time = time.time()
-
-        # Create dummy labels
-        labels = torch.randint(0, 41, (8, 50), device=device)
-        criterion = nn.CrossEntropyLoss()
-
-        # Forward pass
-        if hasattr(model, '__getitem__'):
-            x_proj = model[1](model[0](x))
-            gru_out, _ = model[2](x_proj)
-            output = model[3](gru_out)
-        else:
-            output = model(x)
-
-        # Compute the loss
-        loss = criterion(output.view(-1, 41), labels.view(-1))
-
-        # Backward pass
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        # Synchronize TPU operations
-        xm.mark_step()
-        xm.wait_device_ops()
-
-        backward_time = time.time() - start_time
-        print(f"✅ Backward pass done! Took: {backward_time:.3f}s")
-        print(f"🎯 Loss value: {loss.item():.4f}")
-
-        # Summary
-        print(f"\n📈 Performance summary:")
-        print(f"   Forward pass: {forward_time:.3f}s")
-        print(f"   Backward pass: {backward_time:.3f}s")
-        print(f"   Total: {forward_time + backward_time:.3f}s")
-
-        if (forward_time + backward_time) < 10:  # Finished within 10 seconds
-            print("✅ TPU test passed! Full training can proceed")
-            return True
-        else:
-            print("⚠️ TPU performance is slow; optimization may be needed")
-            return False
-
-    except Exception as e:
-        print(f"❌ TPU test failed: {e}")
-        import traceback
-        traceback.print_exc()
-        return False
-
-
-if __name__ == "__main__":
-    print("=" * 50)
-    print("⚡ Quick TPU test")
-    print("=" * 50)
-
-    success = quick_test()
-
-    if success:
-        print("\n🎉 Test succeeded! You can now run:")
-        print("   python simple_tpu_model.py")
-    else:
-        print("\n❌ Test failed, please check the TPU configuration")
-
-    print("=" * 50)
\ No newline at end of file
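
Note on the deleted quick_tpu_test.py: the manual model[0]/model[2] indexing existed because nn.GRU returns an (output, h_n) tuple, so it cannot be chained inside nn.Sequential via a plain model(x) call. A small wrapper module is the usual fix; a sketch with the same layer sizes as the deleted script (TinyGRUNet is a hypothetical name):

    import torch
    import torch.nn as nn

    class TinyGRUNet(nn.Module):
        """Same layers as the deleted Sequential, but with a forward()
        that unpacks the GRU's (output, hidden) tuple."""
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(512, 256)
            self.gru = nn.GRU(256, 128, batch_first=True)
            self.head = nn.Linear(128, 41)

        def forward(self, x):
            x = torch.relu(self.proj(x))
            out, _ = self.gru(x)  # discard the final hidden state
            return self.head(out)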
diff --git a/model_training_nnn_tpu/simple_tpu_model.py b/model_training_nnn_tpu/simple_tpu_model.py
deleted file mode 100644
index 21089ed..0000000
--- a/model_training_nnn_tpu/simple_tpu_model.py
+++ /dev/null
@@ -1,367 +0,0 @@
-#!/usr/bin/env python3
-"""
-Simple TPU model training and test script.
-A simplified version based on the brain-to-text data, optimized for TPU.
-"""
-
-import os
-import time
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import numpy as np
-from typing import Dict, Any, Tuple
-
-# Set XLA environment variables
-os.environ['XLA_FLAGS'] = (
-    '--xla_cpu_multi_thread_eigen=true '
-    '--xla_cpu_enable_fast_math=true '
-    f'--xla_force_host_platform_device_count={os.cpu_count()}'
-)
-os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
-os.environ['XLA_USE_BF16'] = '1'
-
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.parallel_loader as pl
-
-
-class SimpleBrainToTextModel(nn.Module):
-    """Simplified brain-to-text model - TPU-optimized version"""
-
-    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
-        super().__init__()
-
-        # Input processing layers
-        self.input_proj = nn.Linear(input_features, hidden_size)
-        self.input_dropout = nn.Dropout(0.2)
-
-        # GRU layers - use a smaller hidden size for better TPU efficiency
-        self.gru = nn.GRU(
-            input_size=hidden_size,
-            hidden_size=hidden_size,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=0.3 if num_layers > 1 else 0
-        )
-
-        # Output layer
-        self.output_proj = nn.Linear(hidden_size, num_classes)
-
-        # Initialize weights
-        self._init_weights()
-
-    def _init_weights(self):
-        """Initialize model weights"""
-        for name, param in self.named_parameters():
-            if 'weight' in name:
-                if 'gru' in name:
-                    nn.init.orthogonal_(param)
-                else:
-                    nn.init.xavier_uniform_(param)
-            elif 'bias' in name:
-                nn.init.zeros_(param)
-
-    def forward(self, x):
-        """
-        Forward pass
-        Args:
-            x: (batch_size, seq_len, input_features)
-        Returns:
-            logits: (batch_size, seq_len, num_classes)
-        """
-        # Input projection
-        x = torch.relu(self.input_proj(x))
-        x = self.input_dropout(x)
-
-        # GRU processing
-        output, _ = self.gru(x)
-
-        # Output projection
-        logits = self.output_proj(output)
-
-        return logits
-
-
-class SimpleDataGenerator:
-    """Simple data generator - simulates brain signal data"""
-
-    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
-        self.batch_size = batch_size
-        self.seq_len = seq_len
-        self.input_features = input_features
-        self.num_classes = num_classes
-
-    def generate_batch(self, device):
-        """Generate one batch of simulated data"""
-        # Generate simulated neural signal data
-        features = torch.randn(
-            self.batch_size, self.seq_len, self.input_features,
-            device=device, dtype=torch.float32
-        )
-
-        # Generate simulated labels (phoneme sequences)
-        labels = torch.randint(
-            0, self.num_classes,
-            (self.batch_size, self.seq_len),
-            device=device
-        )
-
-        # Generate sequence lengths
-        seq_lengths = torch.randint(
-            self.seq_len // 2, self.seq_len + 1,
-            (self.batch_size,),
-            device=device
-        )
-
-        return {
-            'features': features,
-            'labels': labels,
-            'seq_lengths': seq_lengths
-        }
-
-
-class SimpleTpuTrainer:
-    """Simple TPU trainer"""
-
-    def __init__(self, model, device, learning_rate=0.001):
-        self.model = model
-        self.device = device
-        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
-
-        # Data generator
-        self.data_generator = SimpleDataGenerator()
-
-        # Training statistics
-        self.step = 0
-        self.best_loss = float('inf')
-
-    def train_step(self, batch):
-        """Single training step"""
-        self.model.train()
-        self.optimizer.zero_grad()
-
-        # Forward pass
-        features = batch['features']
-        labels = batch['labels']
-
-        logits = self.model(features)
-
-        # Compute the loss - reshape to fit CrossEntropyLoss
-        batch_size, seq_len, num_classes = logits.shape
-        loss = self.criterion(
-            logits.reshape(-1, num_classes),
-            labels.reshape(-1)
-        )
-
-        # Backward pass
-        loss.backward()
-
-        # Gradient clipping
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
-        # Update parameters
-        self.optimizer.step()
-
-        return loss.item()
-
-    def evaluate_step(self, batch):
-        """Single evaluation step"""
-        self.model.eval()
-
-        with torch.no_grad():
-            features = batch['features']
-            labels = batch['labels']
-
-            logits = self.model(features)
-
-            # Compute the loss
-            batch_size, seq_len, num_classes = logits.shape
-            loss = self.criterion(
-                logits.reshape(-1, num_classes),
-                labels.reshape(-1)
-            )
-
-            # Compute the accuracy
-            predictions = torch.argmax(logits, dim=-1)
-            correct = (predictions == labels).float()
-            accuracy = correct.mean()
-
-        return loss.item(), accuracy.item()
-
-    def train(self, num_steps=1000, eval_every=100, save_every=500):
-        """Train the model"""
-        print(f"🚀 Starting TPU training - device: {self.device}")
-        print(f"📊 Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
-
-        train_losses = []
-        eval_losses = []
-        eval_accuracies = []
-
-        start_time = time.time()
-
-        for step in range(num_steps):
-            # Generate training data
-            train_batch = self.data_generator.generate_batch(self.device)
-
-            # Training step
-            train_loss = self.train_step(train_batch)
-            train_losses.append(train_loss)
-
-            # XLA synchronization
-            if step % 10 == 0:  # Sync every 10 steps for efficiency
-                xm.mark_step()
-
-            # Evaluation
-            if step % eval_every == 0:
-                eval_batch = self.data_generator.generate_batch(self.device)
-                eval_loss, eval_acc = self.evaluate_step(eval_batch)
-                eval_losses.append(eval_loss)
-                eval_accuracies.append(eval_acc)
-
-                # Sync XLA ops to get accurate timings
-                xm.mark_step()
-                xm.wait_device_ops()
-
-                current_time = time.time()
-                elapsed = current_time - start_time
-
-                print(f"Step {step:4d}/{num_steps} | "
-                      f"Train loss: {train_loss:.4f} | "
-                      f"Eval loss: {eval_loss:.4f} | "
-                      f"Eval accuracy: {eval_acc:.4f} | "
-                      f"Elapsed: {elapsed:.1f}s")
-
-                # Keep the best model
-                if eval_loss < self.best_loss:
-                    self.best_loss = eval_loss
-                    print(f"🎯 New best model! Loss: {eval_loss:.4f}")
-
-            # Periodic saving
-            if step > 0 and step % save_every == 0:
-                self.save_checkpoint(f"checkpoint_step_{step}.pt")
-
-        # Final synchronization
-        xm.mark_step()
-        xm.wait_device_ops()
-
-        total_time = time.time() - start_time
-        print(f"\n✅ Training finished!")
-        print(f"⏱️ Total time: {total_time:.1f}s")
-        print(f"🎯 Final train loss: {train_losses[-1]:.4f}")
-        if eval_losses:
-            print(f"🎯 Final eval loss: {eval_losses[-1]:.4f}")
-            print(f"🎯 Final eval accuracy: {eval_accuracies[-1]:.4f}")
-
-        return {
-            'train_losses': train_losses,
-            'eval_losses': eval_losses,
-            'eval_accuracies': eval_accuracies,
-            'total_time': total_time
-        }
-
-    def save_checkpoint(self, filename):
-        """Save a checkpoint"""
-        checkpoint = {
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'step': self.step,
-            'best_loss': self.best_loss,
-        }
-
-        # On the TPU, move to CPU before saving
-        if 'xla' in str(self.device):
-            checkpoint = xm.send_cpu_data_to_device(checkpoint, torch.device('cpu'))
-
-        torch.save(checkpoint, filename)
-        print(f"💾 Saved checkpoint: {filename}")
-
-    def load_checkpoint(self, filename):
-        """Load a checkpoint"""
-        checkpoint = torch.load(filename, map_location='cpu')
-
-        self.model.load_state_dict(checkpoint['model_state_dict'])
-        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        self.step = checkpoint['step']
-        self.best_loss = checkpoint['best_loss']
-
-        print(f"📂 Loaded checkpoint: {filename}")
-        print(f"   Step: {self.step}, best loss: {self.best_loss:.4f}")
-
-
-def test_simple_inference():
-    """Test simple inference"""
-    print("\n🧪 Testing simple inference...")
-
-    device = xm.xla_device()
-
-    # Create the model
-    model = SimpleBrainToTextModel().to(device)
-
-    # Create test data
-    batch_size = 4
-    seq_len = 50
-    test_input = torch.randn(batch_size, seq_len, 512, device=device)
-
-    # Inference
-    model.eval()
-    with torch.no_grad():
-        start_time = time.time()
-        output = model(test_input)
-        xm.mark_step()
-        xm.wait_device_ops()
-        inference_time = time.time() - start_time
-
-    print(f"✅ Inference finished!")
-    print(f"   Input shape: {test_input.shape}")
-    print(f"   Output shape: {output.shape}")
-    print(f"   Inference time: {inference_time:.4f}s")
-
-    return True
-
-
-def main():
-    """Entry point"""
-    print("=" * 60)
-    print("🧠 Simple TPU brain-to-text model training")
-    print("=" * 60)
-
-    try:
-        # Check the TPU device
-        device = xm.xla_device()
-        print(f"📱 Using device: {device}")
-
-        # Create the model
-        model = SimpleBrainToTextModel(
-            input_features=512,
-            hidden_size=256,
-            num_classes=41,
-            num_layers=3
-        ).to(device)
-
-        # Create the trainer
-        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
-
-        # Start training
-        results = trainer.train(
-            num_steps=1000,
-            eval_every=100,
-            save_every=500
-        )
-
-        # Save the final model
-        trainer.save_checkpoint("final_simple_model.pt")
-
-        # Test inference
-        test_simple_inference()
-
-        print("\n🎉 All tests finished!")
-
-    except Exception as e:
-        print(f"❌ Training failed: {e}")
-        import traceback
-        traceback.print_exc()
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
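
Note on the deleted save_checkpoint: xm.send_cpu_data_to_device copies host data onto an XLA device, not the other way around, so routing the checkpoint through it before torch.save likely did not move anything to the CPU. torch_xla provides xm.save, which transfers XLA tensors to host memory before serializing; a minimal sketch of that approach (hypothetical standalone function, not part of this patch):

    import torch_xla.core.xla_model as xm

    def save_checkpoint(model, optimizer, step, filename):
        # xm.save moves XLA tensors to the host before calling torch.save,
        # so the checkpoint can later be loaded with map_location='cpu'.
        xm.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'step': step,
        }, filename)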