TPU
This commit is contained in:
315
model_training_nnn_tpu/amp_tpu_training.py
Normal file
315
model_training_nnn_tpu/amp_tpu_training.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
使用AMP的TPU训练脚本
|
||||
正确处理混合精度训练,避免dtype不匹配问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# 设置AMP相关的环境变量
|
||||
os.environ['XLA_FLAGS'] = (
|
||||
'--xla_cpu_multi_thread_eigen=true '
|
||||
'--xla_cpu_enable_fast_math=true'
|
||||
)
|
||||
os.environ['XLA_USE_BF16'] = '1' # 启用bf16
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
import torch_xla.amp as xla_amp
|
||||
|
||||
|
||||
class AMPModel(nn.Module):
|
||||
"""支持AMP的简单模型"""
|
||||
|
||||
def __init__(self, input_size=784, hidden_size=512, num_classes=10):
|
||||
super(AMPModel, self).__init__()
|
||||
|
||||
self.network = nn.Sequential(
|
||||
nn.Linear(input_size, hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_size, hidden_size // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(hidden_size // 2, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# 展平输入
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.network(x)
|
||||
|
||||
|
||||
class AMPTrainer:
|
||||
"""AMP训练器"""
|
||||
|
||||
def __init__(self, model, device, learning_rate=0.001):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
self.criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# 初始化AMP scaler
|
||||
self.scaler = xla_amp.GradScaler()
|
||||
|
||||
print(f"✅ AMP训练器初始化完成")
|
||||
print(f" 设备: {device}")
|
||||
print(f" 模型参数: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
def train_step(self, data, target):
|
||||
"""单个AMP训练步骤"""
|
||||
self.model.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# 使用autocast进行混合精度前向传播
|
||||
with xla_amp.autocast():
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
# 使用scaler进行反向传播
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
# 梯度裁剪(可选)
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||
|
||||
# 更新参数
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
# 计算准确率
|
||||
pred = output.argmax(dim=1)
|
||||
correct = pred.eq(target).sum().item()
|
||||
accuracy = correct / target.size(0)
|
||||
|
||||
return loss.item(), accuracy
|
||||
|
||||
def evaluate_step(self, data, target):
|
||||
"""单个评估步骤"""
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
with xla_amp.autocast():
|
||||
output = self.model(data)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
pred = output.argmax(dim=1)
|
||||
correct = pred.eq(target).sum().item()
|
||||
accuracy = correct / target.size(0)
|
||||
|
||||
return loss.item(), accuracy
|
||||
|
||||
|
||||
def get_mnist_loaders(batch_size=64):
|
||||
"""获取MNIST数据加载器"""
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
test_dataset = torchvision.datasets.MNIST(
|
||||
root='./mnist_data',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
def train_with_amp():
|
||||
"""使用AMP进行训练"""
|
||||
print("🚀 开始AMP TPU训练...")
|
||||
|
||||
# 获取设备
|
||||
device = xm.xla_device()
|
||||
print(f"📱 设备: {device}")
|
||||
|
||||
# 创建模型
|
||||
model = AMPModel(input_size=784, hidden_size=512, num_classes=10).to(device)
|
||||
|
||||
# 创建训练器
|
||||
trainer = AMPTrainer(model, device, learning_rate=0.001)
|
||||
|
||||
# 获取数据
|
||||
print("📥 加载MNIST数据...")
|
||||
train_loader, test_loader = get_mnist_loaders(batch_size=64)
|
||||
|
||||
# 使用XLA并行加载器
|
||||
train_device_loader = pl.MpDeviceLoader(train_loader, device)
|
||||
test_device_loader = pl.MpDeviceLoader(test_loader, device)
|
||||
|
||||
print("🎯 开始AMP训练...")
|
||||
|
||||
# 训练循环
|
||||
num_epochs = 2
|
||||
train_losses = []
|
||||
train_accuracies = []
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
|
||||
|
||||
epoch_start = time.time()
|
||||
epoch_loss = 0.0
|
||||
epoch_acc = 0.0
|
||||
num_batches = 0
|
||||
max_batches_per_epoch = 200 # 限制每个epoch的批次数
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_device_loader):
|
||||
if batch_idx >= max_batches_per_epoch:
|
||||
break
|
||||
|
||||
# 训练步骤
|
||||
loss, accuracy = trainer.train_step(data, target)
|
||||
|
||||
epoch_loss += loss
|
||||
epoch_acc += accuracy
|
||||
num_batches += 1
|
||||
|
||||
# 每20个批次同步一次
|
||||
if batch_idx % 20 == 0:
|
||||
xm.mark_step()
|
||||
|
||||
avg_loss = epoch_loss / num_batches
|
||||
avg_acc = epoch_acc / num_batches * 100
|
||||
|
||||
print(f" 批次 {batch_idx:3d}/{max_batches_per_epoch} | "
|
||||
f"损失: {avg_loss:.4f} | "
|
||||
f"准确率: {avg_acc:.2f}%")
|
||||
|
||||
# Epoch结束同步
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
epoch_time = time.time() - epoch_start
|
||||
final_loss = epoch_loss / num_batches
|
||||
final_acc = epoch_acc / num_batches * 100
|
||||
|
||||
train_losses.append(final_loss)
|
||||
train_accuracies.append(final_acc)
|
||||
|
||||
print(f"✅ Epoch {epoch + 1} 完成 | "
|
||||
f"耗时: {epoch_time:.2f}s | "
|
||||
f"平均损失: {final_loss:.4f} | "
|
||||
f"平均准确率: {final_acc:.2f}%")
|
||||
|
||||
return trainer, train_losses, train_accuracies
|
||||
|
||||
|
||||
def test_with_amp(trainer):
|
||||
"""使用AMP进行测试"""
|
||||
print("\n🧪 开始AMP测试...")
|
||||
|
||||
device = xm.xla_device()
|
||||
_, test_loader = get_mnist_loaders(batch_size=64)
|
||||
test_device_loader = pl.MpDeviceLoader(test_loader, device)
|
||||
|
||||
total_loss = 0.0
|
||||
total_acc = 0.0
|
||||
num_batches = 0
|
||||
max_test_batches = 100
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for batch_idx, (data, target) in enumerate(test_device_loader):
|
||||
if batch_idx >= max_test_batches:
|
||||
break
|
||||
|
||||
loss, accuracy = trainer.evaluate_step(data, target)
|
||||
|
||||
total_loss += loss
|
||||
total_acc += accuracy
|
||||
num_batches += 1
|
||||
|
||||
if batch_idx % 20 == 0:
|
||||
xm.mark_step()
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
test_time = time.time() - start_time
|
||||
avg_loss = total_loss / num_batches
|
||||
avg_acc = total_acc / num_batches * 100
|
||||
|
||||
print(f"✅ 测试完成!")
|
||||
print(f"⏱️ 测试时间: {test_time:.2f}秒")
|
||||
print(f"🎯 测试损失: {avg_loss:.4f}")
|
||||
print(f"🎯 测试准确率: {avg_acc:.2f}%")
|
||||
|
||||
return avg_loss, avg_acc
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("=" * 60)
|
||||
print("⚡ AMP TPU训练示例")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# 训练
|
||||
trainer, train_losses, train_accuracies = train_with_amp()
|
||||
|
||||
# 测试
|
||||
test_loss, test_acc = test_with_amp(trainer)
|
||||
|
||||
# 保存模型
|
||||
print("\n💾 保存模型...")
|
||||
model_cpu = trainer.model.cpu()
|
||||
torch.save({
|
||||
'model_state_dict': model_cpu.state_dict(),
|
||||
'train_losses': train_losses,
|
||||
'train_accuracies': train_accuracies,
|
||||
'test_loss': test_loss,
|
||||
'test_accuracy': test_acc
|
||||
}, 'amp_mnist_model.pth')
|
||||
print("✅ 模型已保存到 amp_mnist_model.pth")
|
||||
|
||||
print("\n🎉 AMP训练完成!")
|
||||
print(f"📊 最终训练准确率: {train_accuracies[-1]:.2f}%")
|
||||
print(f"📊 测试准确率: {test_acc:.2f}%")
|
||||
|
||||
if train_accuracies[-1] > 85 and test_acc > 80:
|
||||
print("✅ AMP训练成功! 模型性能优秀")
|
||||
else:
|
||||
print("⚠️ 模型性能一般,但AMP功能正常")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ AMP训练失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n💡 故障排除建议:")
|
||||
print(" 1. 确保PyTorch XLA版本支持AMP")
|
||||
print(" 2. 检查TPU资源是否充足")
|
||||
print(" 3. 尝试减小batch_size")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user