Files
b2txt25/model_training_nnn_tpu/amp_tpu_training.py
Zchen 7965f7dbfe TPU
2025-10-15 16:55:52 +08:00

315 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
TPU training script using AMP (automatic mixed precision).

Handles mixed-precision training so that dtype mismatches are avoided.
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Set the AMP-related environment variables. They are assigned before the
# torch_xla imports below (presumably so the XLA runtime sees them when it
# initialises -- TODO confirm).
# NOTE(review): this assignment replaces any XLA_FLAGS value already present
# in the process environment rather than appending to it.
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true'
)
# NOTE(review): XLA_USE_BF16 is believed to be deprecated in recent torch_xla
# releases in favour of autocast -- confirm against the installed version.
os.environ['XLA_USE_BF16'] = '1' # enable bf16
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.amp as xla_amp
class AMPModel(nn.Module):
    """Small fully connected classifier used by the AMP training example.

    The network is three ``nn.Linear`` layers separated by in-place ReLU and
    dropout, stored as ``self.network`` (an ``nn.Sequential``), so parameter
    names in ``state_dict()`` are ``network.0.*``, ``network.3.*`` and
    ``network.6.*``.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the first hidden layer; the second hidden layer
            uses ``hidden_size // 2``.
        num_classes: Number of output logits.
        dropout: Drop probability used by both dropout layers.
    """

    def __init__(self, input_size=784, hidden_size=512, num_classes=10,
                 dropout=0.2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, x):
        """Flatten every sample to one dimension and return the logits.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example transposed or sliced batches) are accepted; for
        contiguous inputs the result is the same view as before.
        """
        flat = x.reshape(x.size(0), -1)
        return self.network(flat)
class AMPTrainer:
    """Runs mixed-precision training and evaluation steps for one model.

    Holds the model, an Adam optimizer, a cross-entropy criterion and a
    ``torch_xla.amp.GradScaler``.

    Args:
        model: The ``nn.Module`` to optimise; expected to already live on
            ``device``.
        device: Device handed to ``xla_amp.autocast`` for every forward pass.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        # Gradient scaler used around backward() and optimizer.step().
        self.scaler = xla_amp.GradScaler()
        print("✅ AMP训练器初始化完成")
        print(f" 设备: {device}")
        print(f" 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    @staticmethod
    def _accuracy(output, target):
        """Return the fraction of rows in ``output`` whose argmax equals ``target``."""
        correct = output.argmax(dim=1).eq(target).sum().item()
        return correct / target.size(0)

    def train_step(self, data, target):
        """Run one optimisation step under autocast.

        Args:
            data: Input batch passed to the model.
            target: Class indices for the batch.

        Returns:
            ``(loss, accuracy)`` as Python floats for this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        # torch_xla's autocast takes the target device as its first argument;
        # calling it with no arguments fails before the forward pass runs.
        with xla_amp.autocast(self.device):
            output = self.model(data)
            loss = self.criterion(output, target)
        # Backward pass on the scaled loss.
        self.scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to real gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Apply the update and let the scaler adjust its scale factor.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item(), self._accuracy(output, target)

    def evaluate_step(self, data, target):
        """Evaluate one batch under autocast without tracking gradients.

        Args:
            data: Input batch passed to the model.
            target: Class indices for the batch.

        Returns:
            ``(loss, accuracy)`` as Python floats for this batch.
        """
        self.model.eval()
        with torch.no_grad():
            with xla_amp.autocast(self.device):
                output = self.model(data)
                loss = self.criterion(output, target)
            accuracy = self._accuracy(output, target)
        return loss.item(), accuracy
def get_mnist_loaders(batch_size=64):
    """Build the MNIST training and test data loaders.

    Both datasets are stored under ``./mnist_data`` (downloaded on demand),
    converted to tensors and normalised with mean 0.1307 and standard
    deviation 0.3081. Only the training loader shuffles; both load in the
    main process (``num_workers=0``).

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Options shared by the two dataset splits and by the two loaders.
    dataset_options = {
        'root': './mnist_data',
        'download': True,
        'transform': preprocessing,
    }
    loader_options = {'batch_size': batch_size, 'num_workers': 0}

    training_set = torchvision.datasets.MNIST(train=True, **dataset_options)
    held_out_set = torchvision.datasets.MNIST(train=False, **dataset_options)

    training_loader = torch.utils.data.DataLoader(
        training_set, shuffle=True, **loader_options
    )
    held_out_loader = torch.utils.data.DataLoader(
        held_out_set, shuffle=False, **loader_options
    )
    return training_loader, held_out_loader
def train_with_amp(num_epochs=2, max_batches_per_epoch=200, batch_size=64,
                   learning_rate=0.001):
    """Train an ``AMPModel`` on MNIST with mixed precision on the XLA device.

    Args:
        num_epochs: Number of passes over the (truncated) training loader.
        max_batches_per_epoch: Upper bound on the batches consumed per epoch.
        batch_size: Batch size passed to ``get_mnist_loaders``.
        learning_rate: Learning rate passed to ``AMPTrainer``.

    Returns:
        ``(trainer, train_losses, train_accuracies)``; the two lists hold one
        mean loss and one mean accuracy (in percent) per epoch.

    Raises:
        RuntimeError: If an epoch processes no batches, since the per-epoch
            averages would otherwise fail with a bare ``ZeroDivisionError``.
    """
    print("🚀 开始AMP TPU训练...")
    # Acquire the XLA device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")
    # Build the model and its trainer on that device.
    model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)
    trainer = AMPTrainer(model, device, learning_rate=learning_rate)
    # Only the training split is needed here.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_loaders(batch_size=batch_size)
    # Wrap the loader so batches are delivered on the XLA device.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    print("🎯 开始AMP训练...")
    train_losses = []
    train_accuracies = []
    for epoch in range(num_epochs):
        print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0
        for batch_idx, (data, target) in enumerate(train_device_loader):
            if batch_idx >= max_batches_per_epoch:
                break
            loss, accuracy = trainer.train_step(data, target)
            epoch_loss += loss
            epoch_acc += accuracy
            num_batches += 1
            # Every 20 batches: synchronise and report the running averages.
            if batch_idx % 20 == 0:
                xm.mark_step()
                avg_loss = epoch_loss / num_batches
                avg_acc = epoch_acc / num_batches * 100
                print(f" 批次 {batch_idx:3d}/{max_batches_per_epoch} | "
                      f"损失: {avg_loss:.4f} | "
                      f"准确率: {avg_acc:.2f}%")
        # Synchronise at the end of the epoch so the timing covers device work.
        xm.mark_step()
        xm.wait_device_ops()
        if num_batches == 0:
            # Message (zh): "no training batches were processed this epoch,
            # so the average loss and accuracy cannot be computed".
            raise RuntimeError("本轮没有处理任何训练批次,无法计算平均损失和准确率")
        epoch_time = time.time() - epoch_start
        final_loss = epoch_loss / num_batches
        final_acc = epoch_acc / num_batches * 100
        train_losses.append(final_loss)
        train_accuracies.append(final_acc)
        print(f"✅ Epoch {epoch + 1} 完成 | "
              f"耗时: {epoch_time:.2f}s | "
              f"平均损失: {final_loss:.4f} | "
              f"平均准确率: {final_acc:.2f}%")
    return trainer, train_losses, train_accuracies
def test_with_amp(trainer, max_test_batches=100, batch_size=64):
    """Evaluate ``trainer`` on the MNIST test split under autocast.

    Args:
        trainer: Object exposing ``device`` and ``evaluate_step(data, target)``
            (an ``AMPTrainer``).
        max_test_batches: Upper bound on the number of test batches evaluated.
        batch_size: Batch size passed to ``get_mnist_loaders``.

    Returns:
        ``(avg_loss, avg_acc)``: mean batch loss and mean batch accuracy in
        percent over the evaluated batches.

    Raises:
        RuntimeError: If no batch was evaluated, since the averages would
            otherwise fail with a bare ``ZeroDivisionError``.
    """
    print("\n🧪 开始AMP测试...")
    # Feed batches to the device the trainer's model lives on instead of
    # re-querying the default XLA device.
    device = trainer.device
    _, test_loader = get_mnist_loaders(batch_size=batch_size)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0
    start_time = time.time()
    for batch_idx, (data, target) in enumerate(test_device_loader):
        if batch_idx >= max_test_batches:
            break
        loss, accuracy = trainer.evaluate_step(data, target)
        total_loss += loss
        total_acc += accuracy
        num_batches += 1
        # Periodic synchronisation, matching the training loop's cadence.
        if batch_idx % 20 == 0:
            xm.mark_step()
    # Final synchronisation so the measured time covers device work.
    xm.mark_step()
    xm.wait_device_ops()
    if num_batches == 0:
        # Message (zh): "no test batches were processed, so the average loss
        # and accuracy cannot be computed".
        raise RuntimeError("没有处理任何测试批次,无法计算平均损失和准确率")
    test_time = time.time() - start_time
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches * 100
    print("✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}")
    print(f"🎯 测试损失: {avg_loss:.4f}")
    print(f"🎯 测试准确率: {avg_acc:.2f}%")
    return avg_loss, avg_acc
def main():
    """Run training, evaluation and checkpoint saving for the AMP example.

    On success the model weights (moved to CPU) and the collected metrics are
    written to ``amp_mnist_model.pth``. Any exception is reported with a
    traceback and troubleshooting hints instead of propagating.

    Returns:
        ``0`` when the whole pipeline completed, ``1`` when it failed, so the
        entry point can turn the outcome into a process exit status.
    """
    print("=" * 60)
    print("⚡ AMP TPU训练示例")
    print("=" * 60)
    try:
        # Train.
        trainer, train_losses, train_accuracies = train_with_amp()
        # Evaluate.
        test_loss, test_acc = test_with_amp(trainer)
        # Save the model from CPU together with the recorded metrics.
        print("\n💾 保存模型...")
        model_cpu = trainer.model.cpu()
        torch.save({
            'model_state_dict': model_cpu.state_dict(),
            'train_losses': train_losses,
            'train_accuracies': train_accuracies,
            'test_loss': test_loss,
            'test_accuracy': test_acc
        }, 'amp_mnist_model.pth')
        print("✅ 模型已保存到 amp_mnist_model.pth")
        print("\n🎉 AMP训练完成!")
        print(f"📊 最终训练准确率: {train_accuracies[-1]:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")
        if train_accuracies[-1] > 85 and test_acc > 80:
            print("✅ AMP训练成功! 模型性能优秀")
        else:
            print("⚠️ 模型性能一般但AMP功能正常")
    except Exception as e:
        # Top-level boundary: report the failure, then signal it to the caller
        # through the return value rather than silently finishing.
        print(f"❌ AMP训练失败: {e}")
        import traceback
        traceback.print_exc()
        print("\n💡 故障排除建议:")
        print(" 1. 确保PyTorch XLA版本支持AMP")
        print(" 2. 检查TPU资源是否充足")
        print(" 3. 尝试减小batch_size")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed run exits non-zero.
    raise SystemExit(main())