Files
b2txt25/model_training_nnn_tpu/mnist_tpu_simple.py
2025-10-15 15:22:13 +08:00

253 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
超简单MNIST TPU训练 - 完全避开混合精度问题
只使用float32确保稳定运行
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Remove every environment variable that could cause bf16 problems.
for key in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    os.environ.pop(key, None)

# Set only the most basic XLA options.
os.environ['XLA_FLAGS'] = ' '.join((
    '--xla_cpu_multi_thread_eigen=true',
    '--xla_cpu_enable_fast_math=false',
))
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
class SimpleMNISTNet(nn.Module):
    """Minimal fully connected MNIST classifier.

    The input is flattened and passed through three linear layers
    (784 -> 128 -> 64 -> 10), with a ReLU after each of the first two.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the output of the last linear layer (10 values per sample)."""
        hidden = self.relu1(self.fc1(self.flatten(x)))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)
def get_mnist_data(batch_size=64):
    """Build the MNIST training and test data loaders.

    Both splits are stored under ``./mnist_data`` (with ``download=True``),
    converted to tensors and normalized with ``Normalize((0.1307,), (0.3081,))``.
    Only the training loader shuffles; both load in the main process.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split first, then the test split.
    train_split, test_split = (
        torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=preprocessing,
        )
        for is_train in (True, False)
    )

    def as_loader(split, shuffle):
        return torch.utils.data.DataLoader(
            split,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
        )

    return as_loader(train_split, True), as_loader(test_split, False)
def train_mnist():
    """Train ``SimpleMNISTNet`` on the XLA device for a short verification run.

    Processes at most 100 batches of 64 samples with Adam (lr=0.001) and
    cross-entropy loss, printing the running loss and accuracy every
    10 batches.

    Returns:
        A ``(model, final_loss, final_acc)`` tuple: the trained model, the
        mean loss over the processed batches, and the training accuracy in
        percent.
    """
    print("🚀 开始MNIST TPU训练...")

    # Acquire the device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Create the model and make sure every parameter is float32.
    model = SimpleMNISTNet().to(device)
    for weight in model.parameters():
        weight.data = weight.data.to(torch.float32)
    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Load the data and wrap the training loader in the XLA device loader.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=64)
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始训练...")
    model.train()
    started = time.time()
    max_batches = 100  # train only 100 batches as a quick check
    loss_sum = 0.0
    hits = 0
    seen = 0
    steps_done = 0

    for step, (inputs, labels) in enumerate(train_device_loader):
        if step >= max_batches:
            break

        # Make sure the data types are correct.
        inputs = inputs.to(torch.float32)
        labels = labels.to(torch.long)

        # Forward pass.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Running statistics.
        loss_sum += loss.item()
        steps_done += 1
        predictions = logits.argmax(dim=1)
        hits += predictions.eq(labels).sum().item()
        seen += labels.size(0)

        # Every 10 batches: mark an XLA step and report progress.
        if step % 10 != 0:
            continue
        xm.mark_step()
        running_acc = 100. * hits / seen
        running_loss = loss_sum / steps_done
        print(f'批次 {step:3d}/{max_batches} | '
              f'损失: {running_loss:.4f} | '
              f'准确率: {running_acc:.2f}%')

    # Final synchronization before the elapsed time is taken.
    xm.mark_step()
    xm.wait_device_ops()
    train_time = time.time() - started

    final_acc = 100. * hits / seen
    final_loss = loss_sum / steps_done

    print(f"\n✅ 训练完成!")
    print(f"⏱️ 训练时间: {train_time:.2f}")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")
    return model, final_loss, final_acc
def test_mnist(model):
    """Evaluate ``model`` on at most 50 MNIST test batches on the XLA device.

    Args:
        model: The network to evaluate; it is switched to eval mode.

    Returns:
        The accuracy in percent over the evaluated samples.
    """
    print("\n🧪 开始测试...")

    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=64)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    hits = 0
    seen = 0
    max_test_batches = 50  # evaluate only 50 batches
    started = time.time()

    with torch.no_grad():
        for step, (inputs, labels) in enumerate(test_device_loader):
            if step >= max_test_batches:
                break

            # Make sure the data types are correct.
            inputs = inputs.to(torch.float32)
            labels = labels.to(torch.long)

            predictions = model(inputs).argmax(dim=1)
            hits += predictions.eq(labels).sum().item()
            seen += labels.size(0)

            if step % 10 == 0:
                xm.mark_step()

    # Final synchronization before the elapsed time is taken.
    xm.mark_step()
    xm.wait_device_ops()
    test_time = time.time() - started

    accuracy = 100. * hits / seen
    print(f"✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}")
    print(f"🎯 测试准确率: {accuracy:.2f}%")
    return accuracy
def main():
    """Run training, evaluation and model saving, then print a summary.

    Any ``Exception`` raised along the way is printed together with its
    traceback instead of propagating to the caller.
    """
    banner = "=" * 60
    print(banner)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print(banner)

    try:
        # Train (the returned loss is not used here).
        model, _, train_acc = train_mnist()

        # Evaluate.
        test_acc = test_mnist(model)

        # Move the model to the CPU and save its state dict.
        print("\n💾 保存模型...")
        torch.save(model.cpu().state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        trained_well = train_acc > 80 and test_acc > 75
        print("✅ 模型训练成功!" if trained_well else "⚠️ 模型性能一般但TPU功能正常")
    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()