253 lines
6.4 KiB
Python
253 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
超简单MNIST TPU训练 - 完全避开混合精度问题
|
||
只使用float32,确保稳定运行
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import torchvision
|
||
import torchvision.transforms as transforms
|
||
|
||
# Remove every environment variable that could switch torch_xla into bf16 mode.
# Kept above the torch_xla imports below — presumably torch_xla consults these
# when it initializes, so they must be gone before it is imported (confirm).
for _bf16_env_var in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    # pop() with a default is a no-op when the variable is absent.
    os.environ.pop(_bf16_env_var, None)

# Set only the most basic XLA options. Note: this overwrites any XLA_FLAGS value
# already present in the environment rather than appending to it.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=false'
|
||
|
||
import torch_xla.core.xla_model as xm
|
||
import torch_xla.distributed.parallel_loader as pl
|
||
|
||
|
||
class SimpleMNISTNet(nn.Module):
    """Minimal fully-connected MNIST classifier (784 -> 128 -> 64 -> 10).

    Each sample is flattened (``nn.Flatten`` keeps dim 0 as the batch) and
    must therefore contain 28 * 28 = 784 values, e.g. an (N, 1, 28, 28)
    batch. Two ReLU hidden layers follow, and the output is one raw logit
    per class, shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names define the state_dict keys (fc1.weight, ...), and
        # construction order determines how the random init stream is
        # consumed — keep both stable so saved checkpoints stay compatible.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        hidden = self.relu1(self.fc1(self.flatten(x)))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)
|
||
|
||
|
||
def get_mnist_data(batch_size=64):
    """Build and return ``(train_loader, test_loader)`` for MNIST.

    Both splits are stored under ``./mnist_data`` (downloaded if missing)
    and share one preprocessing pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean 0.1307 and std 0.3081. Only the training loader
    shuffles; both load data in the main process (``num_workers=0``).

    Args:
        batch_size: samples per batch, used for both loaders.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Datasets are created train-first, then test, so any downloads happen
    # in that order.
    train_set, test_set = (
        torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train_split,
            download=True,
            transform=preprocess,
        )
        for is_train_split in (True, False)
    )

    def _make_loader(dataset, shuffle):
        # Identical loader settings for both splits except for shuffling.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
        )

    return _make_loader(train_set, True), _make_loader(test_set, False)
|
||
|
||
|
||
def train_mnist(max_batches=100, batch_size=64, lr=0.001):
    """Run a short float32 training pass of SimpleMNISTNet on the XLA device.

    Only the first ``max_batches`` training batches are used, which makes
    this a quick smoke test rather than a full epoch. Running metrics are
    printed at batch 0, 10, 20, ...

    Loss and correct-prediction counts are accumulated as on-device tensors
    and copied to the host only when printing and once at the end. Calling
    ``.item()`` on every step would force a device->host sync (and a
    separate graph execution) per batch, defeating XLA's lazy execution.

    Args:
        max_batches: maximum number of training batches to run.
        batch_size: batch size passed to ``get_mnist_data``.
        lr: Adam learning rate.

    Returns:
        ``(model, final_loss, final_acc)``: the trained model (still on the
        XLA device), the mean training loss over the batches run, and the
        training accuracy in percent.

    Raises:
        RuntimeError: if no training batch was run (empty loader or
            ``max_batches <= 0``).
    """
    print("🚀 开始MNIST TPU训练...")

    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Explicitly keep every parameter in float32 — the script avoids bf16.
    model = SimpleMNISTNet().to(device).float()
    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=batch_size)
    # MpDeviceLoader moves each batch to `device` and calls xm.mark_step()
    # between batches, so every training step is cut into its own graph.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始训练...")
    model.train()
    start_time = time.time()

    # On-device running sums; read back to the host only when printed.
    loss_sum = torch.zeros((), dtype=torch.float32, device=device)
    correct_sum = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    num_batches = 0

    for batch_idx, (data, target) in enumerate(train_device_loader):
        if batch_idx >= max_batches:
            break

        # Defensive dtype normalization (no-op when already float32 / int64).
        data = data.to(torch.float32)
        target = target.to(torch.long)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # detach() so the running sum does not keep the autograd graph alive.
        loss_sum += loss.detach()
        correct_sum += output.argmax(dim=1).eq(target).sum()
        total += target.size(0)
        num_batches += 1

        if batch_idx % 10 == 0:
            # Cut the pending graph first so the .item() reads below only
            # fetch already-materialized values.
            xm.mark_step()
            avg_loss = loss_sum.item() / num_batches
            current_acc = 100. * correct_sum.item() / total
            print(f'批次 {batch_idx:3d}/{max_batches} | '
                  f'损失: {avg_loss:.4f} | '
                  f'准确率: {current_acc:.2f}%')

    # Flush remaining work and wait so the timing includes device execution.
    xm.mark_step()
    xm.wait_device_ops()
    train_time = time.time() - start_time

    if num_batches == 0:
        raise RuntimeError("no training batches were run; cannot compute loss/accuracy")

    final_loss = loss_sum.item() / num_batches
    final_acc = 100. * correct_sum.item() / total

    print(f"\n✅ 训练完成!")
    print(f"⏱️ 训练时间: {train_time:.2f}秒")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")

    return model, final_loss, final_acc
|
||
|
||
|
||
def test_mnist(model, max_test_batches=50, batch_size=64):
    """Evaluate ``model`` on a prefix of the MNIST test split on the XLA device.

    Only the first ``max_test_batches`` batches are scored. With the
    defaults that is at most 50 * 64 = 3200 images, so the result is a
    partial-test-set accuracy. The model is left in eval mode.

    Correct predictions are counted in an on-device tensor and read back
    once at the end. This avoids a per-batch ``.item()`` device->host sync.

    Args:
        model: the model to evaluate; assumed to already live on the XLA
            device (as returned by ``train_mnist``) — TODO confirm for
            other callers.
        max_test_batches: maximum number of test batches to score.
        batch_size: batch size passed to ``get_mnist_data``.

    Returns:
        Accuracy in percent over the scored images.

    Raises:
        RuntimeError: if no test batch was scored.
    """
    print("\n🧪 开始测试...")

    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=batch_size)
    # MpDeviceLoader already calls xm.mark_step() between batches, so no
    # manual graph cuts are needed inside the loop.
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    correct_sum = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    start_time = time.time()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_device_loader):
            if batch_idx >= max_test_batches:
                break

            # Defensive dtype normalization (no-op when already float32 / int64).
            data = data.to(torch.float32)
            target = target.to(torch.long)

            pred = model(data).argmax(dim=1)
            correct_sum += pred.eq(target).sum()
            total += target.size(0)

    xm.mark_step()
    xm.wait_device_ops()
    test_time = time.time() - start_time

    if total == 0:
        raise RuntimeError("no test batches were scored; cannot compute accuracy")

    accuracy = 100. * correct_sum.item() / total

    print(f"✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试准确率: {accuracy:.2f}%")

    return accuracy
|
||
|
||
|
||
def main():
    """Train, evaluate and save the MNIST model, then print a summary.

    The trained weights are written to ``mnist_simple_model.pth`` in the
    current working directory. Any exception raised along the way is
    reported with its traceback.

    Returns:
        0 on success, 1 if any step raised an exception. The script entry
        point passes this to ``sys.exit``, so failures produce a non-zero
        process exit status that shells and CI can detect.
    """
    print("=" * 60)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print("=" * 60)

    try:
        model, train_loss, train_acc = train_mnist()
        test_acc = test_mnist(model)

        print("\n💾 保存模型...")
        # Module.cpu() moves the parameters in place; the model is not used
        # on the XLA device after this point.
        model_cpu = model.cpu()
        torch.save(model_cpu.state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        # Heuristic thresholds for this short smoke-test run.
        if train_acc > 80 and test_acc > 75:
            print("✅ 模型训练成功!")
        else:
            print("⚠️ 模型性能一般,但TPU功能正常")

    except Exception as e:
        # Top-level boundary: report the failure, then signal it via the
        # return value instead of silently finishing with success.
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    import sys
    sys.exit(main())