#!/usr/bin/env python3
"""Ultra-simple MNIST training on TPU, float32 only.

Deliberately avoids mixed precision (bf16) so the run is stable: bf16-related
XLA environment variables are cleared before torch_xla is imported, and the
model and input batches are kept in float32.
"""
import itertools
import os
import time
import traceback

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Clear every environment variable that could switch XLA to bf16. This is done
# before the torch_xla imports below, presumably because torch_xla reads these
# variables when it is imported/initialized.
for _key in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    os.environ.pop(_key, None)

# Only the most basic XLA options. Note: this REPLACES any XLA_FLAGS value that
# was already set in the environment.
# NOTE(review): both flags are `--xla_cpu_*` options, which presumably only
# affect XLA's CPU backend rather than the TPU -- confirm they are needed.
os.environ['XLA_FLAGS'] = ('--xla_cpu_multi_thread_eigen=true '
                           '--xla_cpu_enable_fast_math=false')

import torch_xla.core.xla_model as xm  # noqa: E402  (must follow env setup)
import torch_xla.distributed.parallel_loader as pl  # noqa: E402


class SimpleMNISTNet(nn.Module):
    """Tiny fully-connected MNIST classifier: 784 -> 128 -> 64 -> 10 logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10).

        ``x`` is flattened from dim 1 onward, so each sample must contain
        28 * 28 = 784 values (e.g. a (batch, 1, 28, 28) image batch).
        """
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x


def get_mnist_data(batch_size=64, root='./mnist_data'):
    """Build MNIST train/test DataLoaders, downloading the data if needed.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory the MNIST files are read from / downloaded to.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles, the test
        loader does not; both load in the main process (``num_workers=0``).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Commonly cited MNIST mean / std for normalization.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_loader, test_loader


def train_mnist(max_batches=100, batch_size=64, log_every=10):
    """Train a SimpleMNISTNet on the XLA device as a quick smoke test.

    Args:
        max_batches: Stop after this many training batches (default 100 --
            enough to verify the pipeline works, not to converge).
        batch_size: Training batch size.
        log_every: Print running loss / accuracy every this many batches.

    Returns:
        ``(model, final_loss, final_acc)``: the trained model (still on the
        XLA device), the mean per-batch loss, and the training accuracy in
        percent over the batches actually run.

    Raises:
        RuntimeError: If the training loader yields no batches.
    """
    print("🚀 开始MNIST TPU训练...")

    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Move to the device and make sure every floating-point parameter is float32.
    model = SimpleMNISTNet().to(device=device, dtype=torch.float32)
    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=batch_size)

    # MpDeviceLoader transfers batches to the device in the background and
    # marks an XLA step between batches, so no per-batch mark_step is needed.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始训练...")
    model.train()
    start_time = time.time()

    # Running totals live on the device: reading a value with .item() forces a
    # device->host sync, so that is done only at logging points.
    total_loss = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    num_batches = 0

    # islice stops after max_batches without fetching an extra batch.
    for batch_idx, (data, target) in enumerate(
            itertools.islice(train_device_loader, max_batches)):
        # Defensive casts: keep inputs float32 and labels int64.
        data = data.to(torch.float32)
        target = target.to(torch.long)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.detach()
        correct += output.argmax(dim=1).eq(target).sum()
        total += target.size(0)
        num_batches += 1

        if batch_idx % log_every == 0:
            # Execute the pending graph first, so the .item() reads below only
            # copy already-computed values back to the host.
            xm.mark_step()
            current_acc = 100. * correct.item() / total
            avg_loss = total_loss.item() / num_batches
            print(f'批次 {batch_idx:3d}/{max_batches} | '
                  f'损失: {avg_loss:.4f} | '
                  f'准确率: {current_acc:.2f}%')

    if num_batches == 0:
        raise RuntimeError("训练数据为空: 没有可用的批次")

    # Flush the last step and wait for the device before stopping the clock.
    xm.mark_step()
    xm.wait_device_ops()

    train_time = time.time() - start_time
    final_acc = 100. * correct.item() / total
    final_loss = total_loss.item() / num_batches

    print("\n✅ 训练完成!")
    print(f"⏱️ 训练时间: {train_time:.2f}秒")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")

    return model, final_loss, final_acc


def test_mnist(model, max_test_batches=50, batch_size=64):
    """Evaluate ``model`` on the first ``max_test_batches`` MNIST test batches.

    Args:
        model: A SimpleMNISTNet-compatible module, expected to already be on
            the XLA device (as returned by ``train_mnist``).
        max_test_batches: Number of test batches to evaluate (default 50).
        batch_size: Test batch size.

    Returns:
        Test accuracy in percent over the evaluated batches.

    Raises:
        RuntimeError: If the test loader yields no batches.
    """
    print("\n🧪 开始测试...")

    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=batch_size)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    # Accumulate on the device; read back once at the end.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    start_time = time.time()

    with torch.no_grad():
        for data, target in itertools.islice(test_device_loader, max_test_batches):
            data = data.to(torch.float32)
            target = target.to(torch.long)

            output = model(data)
            correct += output.argmax(dim=1).eq(target).sum()
            total += target.size(0)

    if total == 0:
        raise RuntimeError("测试数据为空: 没有可用的批次")

    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - start_time
    accuracy = 100. * correct.item() / total

    print("✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试准确率: {accuracy:.2f}%")

    return accuracy


def main():
    """Train, evaluate and save the model.

    On any exception, prints the error and traceback and exits with status 1
    so shells / CI can detect the failure.
    """
    print("=" * 60)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print("=" * 60)

    try:
        model, train_loss, train_acc = train_mnist()
        test_acc = test_mnist(model)

        print("\n💾 保存模型...")
        # Move the weights to the CPU so the checkpoint loads without an XLA device.
        model_cpu = model.cpu()
        torch.save(model_cpu.state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        if train_acc > 80 and test_acc > 75:
            print("✅ 模型训练成功!")
        else:
            print("⚠️ 模型性能一般,但TPU功能正常")

    except Exception as e:
        print(f"❌ 训练失败: {e}")
        traceback.print_exc()
        # Previously the failure was swallowed and the script exited 0.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()