tpu test
This commit is contained in:
194
model_training_nnn_tpu/minimal_tpu_test.py
Normal file
194
model_training_nnn_tpu/minimal_tpu_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
最简单的TPU测试 - 完全避开bf16问题
|
||||
只使用float32,最基本的操作
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Deliberately configure nothing bf16-related: only the most basic XLA
# optimisation flag is set here.
os.environ.update(XLA_FLAGS='--xla_cpu_multi_thread_eigen=true')

# Make sure bf16 is not switched on through the environment.
os.environ.pop('XLA_USE_BF16', None)
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def test_basic_operations():
    """Smoke-test the most basic float32 tensor operations on the XLA device.

    Two 4x4 float32 tensors are created on the device, added together and
    then multiplied with ``torch.mm``; execution is forced after each stage.

    Returns:
        bool: True if every operation ran without raising, False otherwise.
        On failure the error message and the full traceback are printed.
    """
    print("🚀 测试最基本的TPU操作...")

    try:
        device = xm.xla_device()
        print(f"📱 设备: {device}")

        # Test 1: element-wise tensor arithmetic.
        print("🔧 测试基本张量操作...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b

        print(f" a.shape: {a.shape}, dtype: {a.dtype}")
        print(f" b.shape: {b.shape}, dtype: {b.dtype}")
        print(f" c.shape: {c.shape}, dtype: {c.dtype}")

        # Synchronise with the device before reporting success.
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ 基本张量操作成功")

        # Test 2: matrix multiplication.
        print("🔧 测试矩阵乘法...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f" 矩阵乘法结果shape: {d.shape}, dtype: {d.dtype}")
        print("✅ 矩阵乘法成功")

        return True

    except Exception as e:
        print(f"❌ 基本操作失败: {e}")
        # Print the traceback as the other smoke tests in this file do, so a
        # failure here can be diagnosed the same way.
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def test_simple_model():
    """Run one gradient-free forward pass of a tiny MLP on the XLA device.

    Returns:
        bool: True when the forward pass completes; False if anything raised
        (the error and its traceback are printed in that case).
    """
    print("\n🧠 测试最简单的模型...")

    try:
        xla_dev = xm.xla_device()

        # Very small 10 -> 5 -> 2 linear model.
        layers = [nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2)]
        net = nn.Sequential(*layers).to(xla_dev)

        n_params = sum(p.numel() for p in net.parameters())
        print(f"📊 模型参数: {n_params}")

        # Make sure every parameter is float32.
        for weight in net.parameters():
            weight.data = weight.data.to(torch.float32)

        # Input batch, explicitly float32.
        x = torch.randn(8, 10, device=xla_dev, dtype=torch.float32)
        print(f"📥 输入: shape={x.shape}, dtype={x.dtype}")

        # Forward pass without gradient tracking, then synchronise.
        with torch.no_grad():
            output = net(x)
            xm.mark_step()
            xm.wait_device_ops()

        print(f"📤 输出: shape={output.shape}, dtype={output.dtype}")
        print("✅ 简单模型前向传播成功")
    except Exception as e:
        print(f"❌ 简单模型失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True
|
||||
|
||||
|
||||
def test_training_step():
    """Run a single SGD update of a one-layer regressor on the XLA device.

    Returns:
        bool: True when the step (forward, backward, parameter update)
        completes; False if any exception was raised, after printing the
        error and its traceback.
    """
    print("\n🎯 测试最简单的训练步骤...")

    try:
        xla_dev = xm.xla_device()

        # Minimal model: a single linear layer.
        net = nn.Linear(10, 1).to(xla_dev)

        # Make sure the weights are float32.
        for weight in net.parameters():
            weight.data = weight.data.to(torch.float32)

        sgd = torch.optim.SGD(net.parameters(), lr=0.01)
        mse = nn.MSELoss()

        # Random inputs (4x10) and targets (4x1), explicitly float32.
        x, y = (
            torch.randn(4, width, device=xla_dev, dtype=torch.float32)
            for width in (10, 1)
        )

        print(f"📥 输入: {x.shape}, {x.dtype}")
        print(f"📥 标签: {y.shape}, {y.dtype}")

        # One optimisation step.
        sgd.zero_grad()
        loss = mse(net(x), y)
        loss.backward()
        sgd.step()

        # Synchronise with the device.
        xm.mark_step()
        xm.wait_device_ops()

        print(f"🎯 损失: {loss.item():.4f}")
        print("✅ 训练步骤成功")
    except Exception as e:
        print(f"❌ 训练步骤失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True
|
||||
|
||||
|
||||
def main():
    """Run the three TPU smoke tests in order and print a summary.

    Every test is executed even if an earlier one failed; the closing
    message depends on whether all of them returned True.
    """
    rule = "=" * 50
    print(rule)
    print("🔬 最简TPU测试 (仅float32)")
    print(rule)

    # (test function, line printed on success, line printed on failure),
    # listed in execution order.
    checks = (
        (test_basic_operations, "1️⃣ 基本操作 ✅", "1️⃣ 基本操作 ❌"),
        (test_simple_model, "2️⃣ 简单模型 ✅", "2️⃣ 简单模型 ❌"),
        (test_training_step, "3️⃣ 训练步骤 ✅", "3️⃣ 训练步骤 ❌"),
    )

    all_passed = True
    for run_check, ok_line, fail_line in checks:
        if run_check():
            print(ok_line)
        else:
            print(fail_line)
            all_passed = False

    print(rule)

    if not all_passed:
        print("❌ 部分测试失败")
        print("💡 建议:")
        print(" 1. 检查TPU资源是否可用")
        print(" 2. 确认torch_xla安装正确")
        print(" 3. 重启runtime清理状态")
        return

    print("🎉 所有测试通过! TPU工作正常")
    print("💡 现在可以尝试更复杂的模型")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
253
model_training_nnn_tpu/mnist_tpu_simple.py
Normal file
253
model_training_nnn_tpu/mnist_tpu_simple.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
超简单MNIST TPU训练 - 完全避开混合精度问题
|
||||
只使用float32,确保稳定运行
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# Remove every environment variable that could cause bf16 problems.
for key in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    os.environ.pop(key, None)

# Only the most basic XLA optimisation flags are set.
os.environ.update(
    XLA_FLAGS='--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=false'
)
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
|
||||
|
||||
class SimpleMNISTNet(nn.Module):
    """Small fully connected classifier (defaults sized for MNIST).

    The network flattens its input and applies three linear layers with a
    ReLU after the first two. Attribute names (``fc1``, ``fc2``, ``fc3``)
    are kept stable so state dicts remain loadable.

    Args:
        input_size: Number of features after flattening (default 28 * 28).
        hidden_sizes: Widths of the two hidden layers (default ``(128, 64)``).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, input_size=28 * 28, hidden_sizes=(128, 64), num_classes=10):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, first_hidden)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(second_hidden, num_classes)

    def forward(self, x):
        """Return the class logits for a batch ``x``.

        ``x`` is flattened from dimension 1 onwards by ``nn.Flatten`` before
        the linear layers are applied.
        """
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
|
||||
|
||||
|
||||
def get_mnist_data(batch_size=64):
    """Build the MNIST training and test data loaders.

    Both datasets are stored under ``./mnist_data`` (downloaded if needed)
    and share the same tensor conversion and normalisation.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``; only the training loader
        shuffles.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Datasets first (training split, then test split) ...
    train_set, test_set = (
        torchvision.datasets.MNIST(
            root='./mnist_data',
            train=is_train,
            download=True,
            transform=preprocess,
        )
        for is_train in (True, False)
    )

    # ... then one loader per split, loading in the main process.
    def build_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
        )

    train_loader = build_loader(train_set, True)
    test_loader = build_loader(test_set, False)
    return train_loader, test_loader
|
||||
|
||||
|
||||
def train_mnist(max_batches=100, batch_size=64):
    """Train ``SimpleMNISTNet`` on MNIST for a limited number of batches.

    Args:
        max_batches: Upper bound on the number of training batches to run
            (default 100, enough for a quick validation).
        batch_size: Mini-batch size passed to ``get_mnist_data``.

    Returns:
        tuple: ``(model, final_loss, final_acc)``. ``final_loss`` is the mean
        per-batch loss and ``final_acc`` the training accuracy in percent,
        both over the batches actually processed (0.0 if none were).
    """
    print("🚀 开始MNIST TPU训练...")

    # Acquire the device.
    device = xm.xla_device()
    print(f"📱 设备: {device}")

    # Build the model.
    model = SimpleMNISTNet().to(device)

    # Make sure every parameter is float32.
    for param in model.parameters():
        param.data = param.data.to(torch.float32)

    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    # Loss function and optimiser.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Load the data; only the training loader is needed here.
    print("📥 加载MNIST数据...")
    train_loader, _ = get_mnist_data(batch_size=batch_size)

    # Wrap the loader with the XLA parallel device loader.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 开始训练...")

    model.train()
    start_time = time.time()

    total_loss = 0.0
    correct = 0
    total = 0
    # Counted explicitly so the final statistics do not depend on the loop
    # variable, which is undefined when the loader yields nothing.
    batches_run = 0

    for batch_idx, (data, target) in enumerate(train_device_loader):
        if batch_idx >= max_batches:
            break

        # Make sure the dtypes are the expected ones.
        data = data.to(torch.float32)
        target = target.to(torch.long)

        # Forward pass.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Running statistics.
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        batches_run += 1

        # Synchronise and report every 10 batches.
        if batch_idx % 10 == 0:
            xm.mark_step()
            current_acc = 100. * correct / total
            avg_loss = total_loss / batches_run

            print(f'批次 {batch_idx:3d}/{max_batches} | '
                  f'损失: {avg_loss:.4f} | '
                  f'准确率: {current_acc:.2f}%')

    # Final synchronisation.
    xm.mark_step()
    xm.wait_device_ops()

    train_time = time.time() - start_time
    # Guard the averages: nothing to divide by if no batch was processed.
    final_acc = 100. * correct / total if total else 0.0
    final_loss = total_loss / batches_run if batches_run else 0.0

    print(f"\n✅ 训练完成!")
    print(f"⏱️ 训练时间: {train_time:.2f}秒")
    print(f"🎯 最终损失: {final_loss:.4f}")
    print(f"🎯 训练准确率: {final_acc:.2f}%")

    return model, final_loss, final_acc
|
||||
|
||||
|
||||
def test_mnist(model, max_test_batches=50, batch_size=64):
    """Evaluate ``model`` on a limited number of MNIST test batches.

    Args:
        model: The classifier to evaluate; it is switched to eval mode.
        max_test_batches: Upper bound on the number of test batches
            (default 50).
        batch_size: Mini-batch size passed to ``get_mnist_data``.

    Returns:
        float: Test accuracy in percent over the batches evaluated
        (0.0 if no batch was evaluated).
    """
    print("\n🧪 开始测试...")

    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=batch_size)

    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    correct = 0
    total = 0

    start_time = time.time()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_device_loader):
            if batch_idx >= max_test_batches:
                break

            # Make sure the dtypes are the expected ones.
            data = data.to(torch.float32)
            target = target.to(torch.long)

            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

            if batch_idx % 10 == 0:
                xm.mark_step()

    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - start_time
    # Guard the average: nothing to divide by if no batch was evaluated.
    accuracy = 100. * correct / total if total else 0.0

    print(f"✅ 测试完成!")
    print(f"⏱️ 测试时间: {test_time:.2f}秒")
    print(f"🎯 测试准确率: {accuracy:.2f}%")

    return accuracy
|
||||
|
||||
|
||||
def main():
    """Train, evaluate and save the simple MNIST model, then print a verdict.

    Any exception raised along the way is caught, reported with its
    traceback, and not re-raised.
    """
    banner = "=" * 60
    print(banner)
    print("🔢 超简单MNIST TPU训练 (仅float32)")
    print(banner)

    try:
        # Train, then evaluate.
        model, _train_loss, train_acc = train_mnist()
        test_acc = test_mnist(model)

        # Save the weights after moving the model to the CPU.
        print("\n💾 保存模型...")
        torch.save(model.cpu().state_dict(), 'mnist_simple_model.pth')
        print("✅ 模型已保存")

        print("\n🎉 全部完成!")
        print(f"📊 训练准确率: {train_acc:.2f}%")
        print(f"📊 测试准确率: {test_acc:.2f}%")

        good_enough = train_acc > 80 and test_acc > 75
        verdict = "✅ 模型训练成功!" if good_enough else "⚠️ 模型性能一般,但TPU功能正常"
        print(verdict)

    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,129 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
快速TPU测试脚本 - 验证简单模型是否可以在TPU上运行
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Set the environment variables.
os.environ.update({
    'XLA_FLAGS': '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true',
    'XLA_USE_BF16': '1',
})
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
def quick_test():
    """Quickly check that a small Linear/GRU model runs on the TPU.

    One forward pass (no gradients) and one full training step (forward,
    loss, backward, Adam update) are executed and timed.

    Returns:
        bool: True if both stages complete and together take less than
        10 seconds; False if they are slower than that or if any exception
        is raised (the error and its traceback are printed).
    """
    print("🚀 开始快速TPU测试...")

    try:
        # Acquire the TPU device.
        device = xm.xla_device()
        print(f"📱 TPU设备: {device}")

        # Build the simple model.
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)

        def run_model(inputs):
            # The GRU returns an (output, hidden) pair, so the Sequential
            # cannot be called end to end: apply the layers one by one and
            # pass only the GRU output on to the final layer.
            projected = model[1](model[0](inputs))  # Linear + ReLU
            gru_out, _ = model[2](projected)        # GRU
            return model[3](gru_out)                # final Linear

        print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

        # Create the test input.
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 输入形状: {x.shape}")

        # Forward-pass test.
        print("🔄 测试前向传播...")
        start_time = time.time()

        with torch.no_grad():
            output = run_model(x)

        # Synchronise TPU operations so the timing covers the execution.
        xm.mark_step()
        xm.wait_device_ops()

        forward_time = time.time() - start_time
        print(f"✅ 前向传播完成! 耗时: {forward_time:.3f}秒")
        print(f"📤 输出形状: {output.shape}")

        # Backward-pass test.
        print("🔄 测试反向传播...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        start_time = time.time()

        # Dummy labels, one class index per time step.
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()

        # Forward pass with gradients.
        output = run_model(x)

        # Loss over all (batch, time) positions.
        loss = criterion(output.view(-1, 41), labels.view(-1))

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronise TPU operations.
        xm.mark_step()
        xm.wait_device_ops()

        backward_time = time.time() - start_time
        print(f"✅ 反向传播完成! 耗时: {backward_time:.3f}秒")
        print(f"🎯 损失值: {loss.item():.4f}")

        # Summary.
        print(f"\n📈 性能总结:")
        print(f" 前向传播: {forward_time:.3f}秒")
        print(f" 反向传播: {backward_time:.3f}秒")
        print(f" 总计: {forward_time + backward_time:.3f}秒")

        if (forward_time + backward_time) < 10:  # finished within 10 seconds
            print("✅ TPU测试通过! 可以进行完整训练")
            return True
        else:
            print("⚠️ TPU性能较慢,可能需要优化")
            return False

    except Exception as e:
        print(f"❌ TPU测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the quick test between two banner lines and
    # print what to do next depending on its result.
    divider = "=" * 50
    print(divider)
    print("⚡ 快速TPU测试")
    print(divider)

    success = quick_test()

    if not success:
        print("\n❌ 测试失败,请检查TPU配置")
    else:
        print("\n🎉 测试成功! 现在可以运行:")
        print(" python simple_tpu_model.py")

    print(divider)
|
@@ -1,367 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单TPU模型训练和测试脚本
|
||||
基于大脑到文本数据的简化版本,专门为TPU优化
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
# Set the XLA environment variables.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 1 so the flag and the thread count never contain the text "None".
_cpu_count = os.cpu_count() or 1
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={_cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(_cpu_count)
os.environ['XLA_USE_BF16'] = '1'
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
|
||||
|
||||
class SimpleBrainToTextModel(nn.Module):
    """Simplified brain-to-text sequence model: Linear -> GRU -> Linear.

    Args:
        input_features: Size of each input feature vector.
        hidden_size: Width of the input projection and of the GRU.
        num_classes: Number of output classes per time step.
        num_layers: Number of stacked GRU layers.
        input_dropout: Dropout probability applied after the input
            projection (default 0.2).
        gru_dropout: Dropout probability between GRU layers (default 0.3);
            only used when ``num_layers > 1``.
    """

    def __init__(self, input_features=512, hidden_size=256, num_classes=41,
                 num_layers=3, input_dropout=0.2, gru_dropout=0.3):
        super().__init__()

        # Input processing layers.
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

        # Recurrent core. Inter-layer dropout is disabled for a single layer.
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=gru_dropout if num_layers > 1 else 0
        )

        # Output layer.
        self.output_proj = nn.Linear(hidden_size, num_classes)

        # Initialise the weights.
        self._init_weights()

    def _init_weights(self):
        """Initialise parameters by name.

        GRU weights are initialised orthogonally, every other weight with
        Xavier-uniform, and all biases with zeros.
        """
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'gru' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        """Compute per-time-step logits.

        Args:
            x: (batch_size, seq_len, input_features)

        Returns:
            logits: (batch_size, seq_len, num_classes)
        """
        # Input projection.
        x = torch.relu(self.input_proj(x))
        x = self.input_dropout(x)

        # GRU; the final hidden state is not used.
        output, _ = self.gru(x)

        # Output projection.
        logits = self.output_proj(output)

        return logits
|
||||
|
||||
|
||||
class SimpleDataGenerator:
    """Generator of random batches that stand in for brain-signal data.

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Number of time steps per sequence.
        input_features: Size of each feature vector.
        num_classes: Number of label classes.
    """

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device):
        """Generate one batch of simulated data on ``device``.

        Returns:
            dict: ``'features'`` is a float32 tensor of shape
            (batch_size, seq_len, input_features); ``'labels'`` is an int64
            tensor of shape (batch_size, seq_len) with values in
            ``[0, num_classes)``; ``'seq_lengths'`` is an integer tensor of
            shape (batch_size,) with values between roughly half of
            ``seq_len`` and ``seq_len`` inclusive.
        """
        # Simulated neural-signal features.
        features = torch.randn(
            self.batch_size, self.seq_len, self.input_features,
            device=device, dtype=torch.float32
        )

        # Simulated labels (phoneme sequence).
        labels = torch.randint(
            0, self.num_classes,
            (self.batch_size, self.seq_len),
            device=device, dtype=torch.long
        )

        # Sequence lengths. The lower bound is at least 1 whenever seq_len
        # allows it, so seq_len == 1 does not yield zero-length sequences.
        min_len = min(self.seq_len, max(1, self.seq_len // 2))
        seq_lengths = torch.randint(
            min_len, self.seq_len + 1,
            (self.batch_size,),
            device=device
        )

        return {
            'features': features,
            'labels': labels,
            'seq_lengths': seq_lengths
        }
|
||||
|
||||
|
||||
class SimpleTpuTrainer:
    """Simple trainer that runs a sequence classifier on an XLA/TPU device.

    Training and evaluation batches come from a ``SimpleDataGenerator``.

    Args:
        model: The module to train; its output is reshaped to
            (-1, num_classes) for the loss.
        device: Device the generated batches are created on.
        learning_rate: Learning rate of the Adam optimiser.
    """

    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

        # Data generator.
        self.data_generator = SimpleDataGenerator()

        # Training statistics. ``step`` counts the optimiser updates done so
        # far and is what checkpoints record.
        self.step = 0
        self.best_loss = float('inf')

    def train_step(self, batch):
        """Run one optimisation step on ``batch`` and return the loss.

        Args:
            batch: Mapping with ``'features'`` and ``'labels'`` tensors.

        Returns:
            float: The training loss of this step.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass.
        features = batch['features']
        labels = batch['labels']

        logits = self.model(features)

        # Flatten all leading dimensions so CrossEntropyLoss sees
        # (N, num_classes) logits and (N,) labels.
        num_classes = logits.shape[-1]
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            labels.reshape(-1)
        )

        # Backward pass.
        loss.backward()

        # Gradient clipping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Parameter update.
        self.optimizer.step()
        # Advance the step counter so checkpoints record real progress.
        self.step += 1

        return loss.item()

    def evaluate_step(self, batch):
        """Evaluate the model on ``batch`` without gradient tracking.

        Returns:
            tuple: ``(loss, accuracy)`` as Python floats; accuracy is the
            fraction of positions where the arg-max prediction equals the
            label.
        """
        self.model.eval()

        with torch.no_grad():
            features = batch['features']
            labels = batch['labels']

            logits = self.model(features)

            # Loss.
            num_classes = logits.shape[-1]
            loss = self.criterion(
                logits.reshape(-1, num_classes),
                labels.reshape(-1)
            )

            # Accuracy.
            predictions = torch.argmax(logits, dim=-1)
            correct = (predictions == labels).float()
            accuracy = correct.mean()

        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Train for ``num_steps`` steps on generated batches.

        Args:
            num_steps: Number of training steps to run.
            eval_every: Evaluate (and report) every this many steps.
            save_every: Save a checkpoint every this many steps (never at
                step 0).

        Returns:
            dict: ``'train_losses'``, ``'eval_losses'``, ``'eval_accuracies'``
            (lists of floats) and ``'total_time'`` (seconds).
        """
        print(f"🚀 开始TPU训练 - 设备: {self.device}")
        print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}")

        train_losses = []
        eval_losses = []
        eval_accuracies = []

        start_time = time.time()

        for step in range(num_steps):
            # Generate the training batch.
            train_batch = self.data_generator.generate_batch(self.device)

            # Training step.
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)

            # XLA synchronisation every 10 steps.
            if step % 10 == 0:
                xm.mark_step()

            # Evaluation.
            if step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)

                # Synchronise XLA operations to get an accurate elapsed time.
                xm.mark_step()
                xm.wait_device_ops()

                current_time = time.time()
                elapsed = current_time - start_time

                print(f"步骤 {step:4d}/{num_steps} | "
                      f"训练损失: {train_loss:.4f} | "
                      f"验证损失: {eval_loss:.4f} | "
                      f"验证准确率: {eval_acc:.4f} | "
                      f"耗时: {elapsed:.1f}s")

                # Track the best evaluation loss seen so far.
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}")

            # Periodic checkpoint.
            if step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")

        # Final synchronisation.
        xm.mark_step()
        xm.wait_device_ops()

        total_time = time.time() - start_time
        print(f"\n✅ 训练完成!")
        print(f"⏱️ 总耗时: {total_time:.1f}秒")
        # With num_steps == 0 there is no training loss to report.
        if train_losses:
            print(f"🎯 最终训练损失: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}")
            print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}")

        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Save model, optimiser, step counter and best loss to ``filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }

        if 'xla' in str(self.device):
            # On an XLA device the tensors must be moved to the CPU before
            # saving; xm.save does that. xm.send_cpu_data_to_device goes the
            # other way (CPU data to a device), so it is not usable here.
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 保存检查点: {filename}")

    def load_checkpoint(self, filename):
        """Restore model, optimiser, step counter and best loss from ``filename``."""
        checkpoint = torch.load(filename, map_location='cpu')

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']

        print(f"📂 加载检查点: {filename}")
        print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}")
|
||||
|
||||
|
||||
def test_simple_inference():
    """Time one gradient-free forward pass of ``SimpleBrainToTextModel``.

    Returns:
        bool: Always True once the pass has completed; exceptions are not
        caught here.
    """
    print("\n🧪 测试简单推理...")

    xla_dev = xm.xla_device()

    # Model with its default configuration.
    net = SimpleBrainToTextModel().to(xla_dev)

    # Random test input: 4 sequences of 50 steps with 512 features each.
    n_seqs, n_steps = 4, 50
    test_input = torch.randn(n_seqs, n_steps, 512, device=xla_dev)

    # Inference, timed including the device synchronisation.
    net.eval()
    with torch.no_grad():
        t0 = time.time()
        output = net(test_input)
        xm.mark_step()
        xm.wait_device_ops()
        inference_time = time.time() - t0

    print(f"✅ 推理完成!")
    print(f" 输入形状: {test_input.shape}")
    print(f" 输出形状: {output.shape}")
    print(f" 推理时间: {inference_time:.4f}秒")

    return True
|
||||
|
||||
|
||||
def main():
    """Train the simple brain-to-text model, save it and run an inference test.

    Any exception raised along the way is caught, reported with its
    traceback, and not re-raised.
    """
    banner = "=" * 60
    print(banner)
    print("🧠 简单TPU大脑到文本模型训练")
    print(banner)

    try:
        # Check the TPU device.
        device = xm.xla_device()
        print(f"📱 使用设备: {device}")

        # Build the model and its trainer.
        model_config = dict(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3,
        )
        net = SimpleBrainToTextModel(**model_config).to(device)
        trainer = SimpleTpuTrainer(net, device, learning_rate=0.001)

        # Train; the returned statistics are not used here.
        trainer.train(num_steps=1000, eval_every=100, save_every=500)

        # Save the final model.
        trainer.save_checkpoint("final_simple_model.pt")

        # Inference test.
        test_simple_inference()

        print("\n🎉 所有测试完成!")

    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user