Zchen
2025-10-15 15:22:13 +08:00
parent 082018cd46
commit b466e97463
4 changed files with 447 additions and 496 deletions


@@ -0,0 +1,194 @@
#!/usr/bin/env python3
"""
Minimal TPU test - sidesteps the bf16 problem entirely.
Uses only the most basic float32 operations.
"""
import os
import torch
import torch.nn as nn

# Do not set any bf16-related environment variables;
# only enable the most basic XLA optimization.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'

# Make sure bf16 is not in use
if 'XLA_USE_BF16' in os.environ:
    del os.environ['XLA_USE_BF16']

import torch_xla.core.xla_model as xm
def test_basic_operations():
    """Test the most basic TPU operations."""
    print("🚀 Testing basic TPU operations...")
    try:
        device = xm.xla_device()
        print(f"📱 Device: {device}")

        # Test 1: basic tensor operations
        print("🔧 Testing basic tensor operations...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b
        print(f"   a.shape: {a.shape}, dtype: {a.dtype}")
        print(f"   b.shape: {b.shape}, dtype: {b.dtype}")
        print(f"   c.shape: {c.shape}, dtype: {c.dtype}")

        # Synchronize: materialize the lazy XLA graph and wait for the device
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ Basic tensor operations succeeded")

        # Test 2: matrix multiplication
        print("🔧 Testing matrix multiplication...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"   matmul result shape: {d.shape}, dtype: {d.dtype}")
        print("✅ Matrix multiplication succeeded")
        return True
    except Exception as e:
        print(f"❌ Basic operations failed: {e}")
        return False
def test_simple_model():
    """Test the simplest possible model."""
    print("\n🧠 Testing the simplest model...")
    try:
        device = xm.xla_device()
        # A tiny linear model
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2)
        ).to(device)
        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters())}")

        # Make sure every parameter is float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        # Create input data - explicitly float32
        x = torch.randn(8, 10, device=device, dtype=torch.float32)
        print(f"📥 Input: shape={x.shape}, dtype={x.dtype}")

        # Forward pass
        with torch.no_grad():
            output = model(x)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"📤 Output: shape={output.shape}, dtype={output.dtype}")
        print("✅ Simple model forward pass succeeded")
        return True
    except Exception as e:
        print(f"❌ Simple model failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_training_step():
    """Test the simplest possible training step."""
    print("\n🎯 Testing a minimal training step...")
    try:
        device = xm.xla_device()
        # A tiny model
        model = nn.Linear(10, 1).to(device)
        # Make sure the weights are float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Create data - explicitly float32
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)
        print(f"📥 Input: {x.shape}, {x.dtype}")
        print(f"📥 Labels: {y.shape}, {y.dtype}")

        # One training step
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Synchronize
        xm.mark_step()
        xm.wait_device_ops()
        print(f"🎯 Loss: {loss.item():.4f}")
        print("✅ Training step succeeded")
        return True
    except Exception as e:
        print(f"❌ Training step failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Entry point."""
    print("=" * 50)
    print("🔬 Minimal TPU test (float32 only)")
    print("=" * 50)

    all_passed = True

    # Test 1: basic operations
    if test_basic_operations():
        print("1. Basic operations ✅")
    else:
        print("1. Basic operations ❌")
        all_passed = False

    # Test 2: simple model
    if test_simple_model():
        print("2. Simple model ✅")
    else:
        print("2. Simple model ❌")
        all_passed = False

    # Test 3: training step
    if test_training_step():
        print("3. Training step ✅")
    else:
        print("3. Training step ❌")
        all_passed = False

    print("=" * 50)
    if all_passed:
        print("🎉 All tests passed! The TPU is working correctly")
        print("💡 You can now try more complex models")
    else:
        print("❌ Some tests failed")
        print("💡 Suggestions:")
        print("   1. Check that TPU resources are available")
        print("   2. Verify that torch_xla is installed correctly")
        print("   3. Restart the runtime to clear state")

if __name__ == "__main__":
    main()
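One way to confirm that the bf16 opt-out actually took effect is to round-trip a value that float32 can represent but bf16 cannot. A minimal sketch, not part of the script above (the 2**-20 offset is an illustrative choice; bf16 keeps only 8 significand bits, so it rounds the value to 1.0):

# Hedged sketch: 1 + 2**-20 survives a float32 round-trip but collapses
# to 1.0 under bf16, so it exposes a silent downcast on the device.
import torch
import torch_xla.core.xla_model as xm

t = torch.tensor([1.0 + 2**-20], dtype=torch.float32).to(xm.xla_device())
xm.mark_step()
assert t.cpu().item() > 1.0, "value was rounded - bf16 is likely still active"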


@@ -0,0 +1,253 @@
#!/usr/bin/env python3
"""
Ultra-simple MNIST TPU training - avoids mixed-precision issues entirely.
Uses only float32 to guarantee stable runs.
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Clear every environment variable that could trigger bf16 problems
for key in ['XLA_USE_BF16', 'XLA_DOWNCAST_BF16']:
    if key in os.environ:
        del os.environ[key]

# Only enable the most basic XLA optimizations
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=false'

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
class SimpleMNISTNet(nn.Module):
    """An ultra-simple MNIST classifier."""
    def __init__(self):
        super(SimpleMNISTNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
def get_mnist_data(batch_size=64):
    """Build the MNIST train/test data loaders."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        root='./mnist_data',
        train=True,
        download=True,
        transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root='./mnist_data',
        train=False,
        download=True,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0
    )
    return train_loader, test_loader
def train_mnist():
    """Train the MNIST model."""
    print("🚀 Starting MNIST TPU training...")

    # Get the device
    device = xm.xla_device()
    print(f"📱 Device: {device}")

    # Build the model
    model = SimpleMNISTNet().to(device)
    # Make sure every parameter is float32
    for param in model.parameters():
        param.data = param.data.to(torch.float32)
    print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Get the data
    print("📥 Loading MNIST data...")
    train_loader, _ = get_mnist_data(batch_size=64)

    # Use the XLA parallel loader to move batches onto the device
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    print("🎯 Starting training...")
    model.train()
    start_time = time.time()
    total_loss = 0.0
    correct = 0
    total = 0
    max_batches = 100  # Train on only 100 batches for a quick sanity check

    for batch_idx, (data, target) in enumerate(train_device_loader):
        if batch_idx >= max_batches:
            break
        # Make sure the dtypes are correct
        data = data.to(torch.float32)
        target = target.to(torch.long)

        # Forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

        # Synchronize every 10 batches
        if batch_idx % 10 == 0:
            xm.mark_step()
            current_acc = 100. * correct / total
            avg_loss = total_loss / (batch_idx + 1)
            print(f'Batch {batch_idx:3d}/{max_batches} | '
                  f'loss: {avg_loss:.4f} | '
                  f'accuracy: {current_acc:.2f}%')

    # Final synchronization
    xm.mark_step()
    xm.wait_device_ops()

    train_time = time.time() - start_time
    final_acc = 100. * correct / total
    final_loss = total_loss / min(batch_idx + 1, max_batches)

    print(f"\n✅ Training finished!")
    print(f"⏱️ Training time: {train_time:.2f}s")
    print(f"🎯 Final loss: {final_loss:.4f}")
    print(f"🎯 Training accuracy: {final_acc:.2f}%")
    return model, final_loss, final_acc
def test_mnist(model):
    """Evaluate the MNIST model."""
    print("\n🧪 Starting evaluation...")
    device = xm.xla_device()
    _, test_loader = get_mnist_data(batch_size=64)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    model.eval()
    correct = 0
    total = 0
    max_test_batches = 50  # Evaluate on only 50 batches
    start_time = time.time()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_device_loader):
            if batch_idx >= max_test_batches:
                break
            # Make sure the dtypes are correct
            data = data.to(torch.float32)
            target = target.to(torch.long)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            if batch_idx % 10 == 0:
                xm.mark_step()

    xm.mark_step()
    xm.wait_device_ops()

    test_time = time.time() - start_time
    accuracy = 100. * correct / total
    print(f"✅ Evaluation finished!")
    print(f"⏱️ Evaluation time: {test_time:.2f}s")
    print(f"🎯 Test accuracy: {accuracy:.2f}%")
    return accuracy
def main():
    """Entry point."""
    print("=" * 60)
    print("🔢 Ultra-simple MNIST TPU training (float32 only)")
    print("=" * 60)
    try:
        # Train
        model, train_loss, train_acc = train_mnist()
        # Evaluate
        test_acc = test_mnist(model)
        # Save the model
        print("\n💾 Saving the model...")
        model_cpu = model.cpu()
        torch.save(model_cpu.state_dict(), 'mnist_simple_model.pth')
        print("✅ Model saved")

        print("\n🎉 All done!")
        print(f"📊 Training accuracy: {train_acc:.2f}%")
        print(f"📊 Test accuracy: {test_acc:.2f}%")
        if train_acc > 80 and test_acc > 75:
            print("✅ Model trained successfully!")
        else:
            print("⚠️ Model performance is mediocre, but the TPU works correctly")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
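Because the script saves a plain CPU state dict, the trained weights can be reloaded anywhere without torch_xla. A minimal sketch (assumes SimpleMNISTNet is importable; the random input is only a shape check, not a real digit):

# Hedged sketch: reload the saved weights on CPU and run one prediction.
import torch

model = SimpleMNISTNet()
model.load_state_dict(torch.load('mnist_simple_model.pth', map_location='cpu'))
model.eval()
with torch.no_grad():
    digit = model(torch.randn(1, 1, 28, 28)).argmax(dim=1)  # shape check only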


@@ -1,129 +0,0 @@
#!/usr/bin/env python3
"""
Quick TPU test script - verifies that a simple model can run on the TPU.
"""
import os
import time
import torch
import torch.nn as nn

# Set environment variables
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm
def quick_test():
    """Quickly check that the TPU works."""
    print("🚀 Starting quick TPU test...")
    try:
        # Get the TPU device
        device = xm.xla_device()
        print(f"📱 TPU device: {device}")

        # Build a simple model. Note: nn.GRU returns an (output, hidden)
        # tuple, so this Sequential cannot be called directly; the layers
        # are applied by hand below.
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)
        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Create test data
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 Input shape: {x.shape}")

        # Test the forward pass
        print("🔄 Testing forward pass...")
        start_time = time.time()
        with torch.no_grad():
            # Apply the layers manually because of the GRU's tuple output
            x_proj = model[1](model[0](x))  # Linear + ReLU
            gru_out, _ = model[2](x_proj)   # GRU
            output = model[3](gru_out)      # Final Linear

        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()
        forward_time = time.time() - start_time
        print(f"✅ Forward pass done! Took: {forward_time:.3f}s")
        print(f"📤 Output shape: {output.shape}")

        # Test the backward pass
        print("🔄 Testing backward pass...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        start_time = time.time()

        # Create dummy labels
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()

        # Forward pass (manual again, because of the GRU)
        x_proj = model[1](model[0](x))
        gru_out, _ = model[2](x_proj)
        output = model[3](gru_out)

        # Compute the loss
        loss = criterion(output.view(-1, 41), labels.view(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()
        backward_time = time.time() - start_time
        print(f"✅ Backward pass done! Took: {backward_time:.3f}s")
        print(f"🎯 Loss value: {loss.item():.4f}")

        # Summary
        print(f"\n📈 Performance summary:")
        print(f"   Forward pass:  {forward_time:.3f}s")
        print(f"   Backward pass: {backward_time:.3f}s")
        print(f"   Total:         {forward_time + backward_time:.3f}s")

        if (forward_time + backward_time) < 10:  # finished within 10 seconds
            print("✅ TPU test passed! Ready for full training")
            return True
        else:
            print("⚠️ TPU is slow; optimization may be needed")
            return False
    except Exception as e:
        print(f"❌ TPU test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    print("=" * 50)
    print("⚡ Quick TPU test")
    print("=" * 50)
    success = quick_test()
    if success:
        print("\n🎉 Test succeeded! You can now run:")
        print("   python simple_tpu_model.py")
    else:
        print("\n❌ Test failed, please check the TPU configuration")
    print("=" * 50)


@@ -1,367 +0,0 @@
#!/usr/bin/env python3
"""
Simple TPU model training and test script.
A simplified version based on the brain-to-text data, tuned for TPU.
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim

# Set XLA environment variables
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(os.cpu_count())
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm
class SimpleBrainToTextModel(nn.Module):
    """A simplified brain-to-text model - TPU-optimized version."""
    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        super().__init__()
        # Input processing layers
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)
        # GRU layers - a smaller hidden size improves TPU efficiency
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0
        )
        # Output layer
        self.output_proj = nn.Linear(hidden_size, num_classes)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize the model weights."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'gru' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x: (batch_size, seq_len, input_features)
        Returns:
            logits: (batch_size, seq_len, num_classes)
        """
        # Input projection
        x = torch.relu(self.input_proj(x))
        x = self.input_dropout(x)
        # GRU processing
        output, _ = self.gru(x)
        # Output projection
        logits = self.output_proj(output)
        return logits
class SimpleDataGenerator:
    """A simple data generator - simulates brain-signal data."""
    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device):
        """Generate one batch of simulated data."""
        # Simulated neural-signal features
        features = torch.randn(
            self.batch_size, self.seq_len, self.input_features,
            device=device, dtype=torch.float32
        )
        # Simulated labels (phoneme sequences)
        labels = torch.randint(
            0, self.num_classes,
            (self.batch_size, self.seq_len),
            device=device
        )
        # Sequence lengths
        seq_lengths = torch.randint(
            self.seq_len // 2, self.seq_len + 1,
            (self.batch_size,),
            device=device
        )
        return {
            'features': features,
            'labels': labels,
            'seq_lengths': seq_lengths
        }
class SimpleTpuTrainer:
    """A simple TPU trainer."""
    def __init__(self, model, device, learning_rate=0.001):
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        # Data generator
        self.data_generator = SimpleDataGenerator()
        # Training statistics
        self.step = 0
        self.best_loss = float('inf')
    def train_step(self, batch):
        """One training step."""
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        features = batch['features']
        labels = batch['labels']
        logits = self.model(features)

        # Compute the loss - reshape to fit CrossEntropyLoss
        batch_size, seq_len, num_classes = logits.shape
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            labels.reshape(-1)
        )

        # Backward pass
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        # Update the parameters
        self.optimizer.step()
        return loss.item()
    def evaluate_step(self, batch):
        """One evaluation step."""
        self.model.eval()
        with torch.no_grad():
            features = batch['features']
            labels = batch['labels']
            logits = self.model(features)

            # Compute the loss
            batch_size, seq_len, num_classes = logits.shape
            loss = self.criterion(
                logits.reshape(-1, num_classes),
                labels.reshape(-1)
            )

            # Compute the accuracy
            predictions = torch.argmax(logits, dim=-1)
            correct = (predictions == labels).float()
            accuracy = correct.mean()
        return loss.item(), accuracy.item()
    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Train the model."""
        print(f"🚀 Starting TPU training - device: {self.device}")
        print(f"📊 Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        train_losses = []
        eval_losses = []
        eval_accuracies = []
        start_time = time.time()

        for step in range(num_steps):
            # Generate training data
            train_batch = self.data_generator.generate_batch(self.device)
            # Training step
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)
            self.step = step  # keep the checkpoint step counter current

            # XLA synchronization
            if step % 10 == 0:  # sync every 10 steps for efficiency
                xm.mark_step()

            # Evaluation
            if step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)

                # Synchronize XLA ops so the timing is accurate
                xm.mark_step()
                xm.wait_device_ops()
                elapsed = time.time() - start_time
                print(f"Step {step:4d}/{num_steps} | "
                      f"train loss: {train_loss:.4f} | "
                      f"eval loss: {eval_loss:.4f} | "
                      f"eval accuracy: {eval_acc:.4f} | "
                      f"elapsed: {elapsed:.1f}s")

                # Track the best model
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 New best model! Loss: {eval_loss:.4f}")

            # Periodic checkpointing
            if step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")

        # Final synchronization
        xm.mark_step()
        xm.wait_device_ops()
        total_time = time.time() - start_time

        print(f"\n✅ Training finished!")
        print(f"⏱️ Total time: {total_time:.1f}s")
        print(f"🎯 Final training loss: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 Final eval loss: {eval_losses[-1]:.4f}")
            print(f"🎯 Final eval accuracy: {eval_accuracies[-1]:.4f}")

        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }
    def save_checkpoint(self, filename):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }
        # On TPU the tensors must be moved to the CPU before saving;
        # xm.save handles that device-to-host transfer internally.
        if 'xla' in str(self.device):
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 Saved checkpoint: {filename}")

    def load_checkpoint(self, filename):
        """Load a checkpoint."""
        checkpoint = torch.load(filename, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']
        print(f"📂 Loaded checkpoint: {filename}")
        print(f"   Step: {self.step}, best loss: {self.best_loss:.4f}")
def test_simple_inference():
    """Test simple inference."""
    print("\n🧪 Testing simple inference...")
    device = xm.xla_device()

    # Build the model
    model = SimpleBrainToTextModel().to(device)

    # Create test data
    batch_size = 4
    seq_len = 50
    test_input = torch.randn(batch_size, seq_len, 512, device=device)

    # Inference
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output = model(test_input)
        xm.mark_step()
        xm.wait_device_ops()
        inference_time = time.time() - start_time

    print(f"✅ Inference done!")
    print(f"   Input shape:    {test_input.shape}")
    print(f"   Output shape:   {output.shape}")
    print(f"   Inference time: {inference_time:.4f}s")
    return True
def main():
    """Entry point."""
    print("=" * 60)
    print("🧠 Simple TPU brain-to-text model training")
    print("=" * 60)
    try:
        # Check the TPU device
        device = xm.xla_device()
        print(f"📱 Using device: {device}")

        # Build the model
        model = SimpleBrainToTextModel(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3
        ).to(device)

        # Build the trainer
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)

        # Start training
        results = trainer.train(
            num_steps=1000,
            eval_every=100,
            save_every=500
        )

        # Save the final model
        trainer.save_checkpoint("final_simple_model.pt")

        # Test inference
        test_simple_inference()

        print("\n🎉 All tests finished!")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
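Resuming from one of the periodic checkpoints is then just a rebuild plus load_checkpoint. A minimal sketch (the checkpoint filename and the extra step count are illustrative, not produced by any particular run):

# Hedged sketch: rebuild the trainer and resume from a saved checkpoint.
device = xm.xla_device()
model = SimpleBrainToTextModel().to(device)
trainer = SimpleTpuTrainer(model, device)
trainer.load_checkpoint("checkpoint_step_500.pt")  # restores weights, optimizer, step, best loss
trainer.train(num_steps=500)                       # continue training from the restored state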