367 lines
10 KiB
Python
367 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
"""
Simple TPU model training and test script.

A simplified version based on brain-to-text data, optimized specifically for TPU.
"""
|
||
|
||
import os
|
||
import time
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import numpy as np
|
||
from typing import Dict, Any, Tuple
|
||
|
||
# Configure XLA-related environment variables. They are assigned ahead of the
# torch_xla import that follows in this file.
# os.cpu_count() may return None when the count cannot be determined; fall
# back to 1 so the literal text "None" never ends up inside a flag value.
_cpu_count = os.cpu_count() or 1

os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={_cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(_cpu_count)
# NOTE(review): newer torch_xla releases are reported to deprecate
# XLA_USE_BF16 -- confirm against the installed version.
os.environ['XLA_USE_BF16'] = '1'
|
||
|
||
import torch_xla.core.xla_model as xm
|
||
import torch_xla.distributed.parallel_loader as pl
|
||
|
||
|
||
class SimpleBrainToTextModel(nn.Module):
    """Simplified brain-to-text model: linear projection -> stacked GRU -> linear classifier."""

    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        super().__init__()

        # Front end: project the raw features to the recurrent width.
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)

        # Recurrent core. Inter-layer dropout is only requested when the GRU
        # is actually stacked (more than one layer).
        inter_layer_dropout = 0.3 if num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        # Per-timestep classifier head.
        self.output_proj = nn.Linear(hidden_size, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise parameters: orthogonal GRU weights, Xavier-uniform other weights, zero biases."""
        for param_name, tensor in self.named_parameters():
            if 'weight' not in param_name:
                # Parameters that are neither weights nor biases are left untouched.
                if 'bias' in param_name:
                    nn.init.zeros_(tensor)
                continue
            initializer = (
                nn.init.orthogonal_ if 'gru' in param_name else nn.init.xavier_uniform_
            )
            initializer(tensor)

    def forward(self, x):
        """
        Run the forward pass.

        Args:
            x: (batch_size, seq_len, input_features)

        Returns:
            logits: (batch_size, seq_len, num_classes)
        """
        projected = self.input_dropout(torch.relu(self.input_proj(x)))
        # The final hidden state returned by the GRU is not needed here.
        sequence_out, _ = self.gru(projected)
        return self.output_proj(sequence_out)
|
||
|
||
|
||
class SimpleDataGenerator:
    """Produces random batches that stand in for neural-signal data."""

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.input_features = input_features
        self.num_classes = num_classes

    def generate_batch(self, device):
        """Create one synthetic batch on ``device``.

        Returns:
            dict with
              'features':    float32 tensor (batch_size, seq_len, input_features), standard normal
              'labels':      integer tensor (batch_size, seq_len), values in [0, num_classes)
              'seq_lengths': integer tensor (batch_size,), values in [seq_len // 2, seq_len]
        """
        rows, steps = self.batch_size, self.seq_len

        # Simulated neural-signal features.
        signal = torch.randn(
            rows, steps, self.input_features, device=device, dtype=torch.float32
        )
        # Simulated label (phoneme) sequence, one class id per timestep.
        targets = torch.randint(0, self.num_classes, (rows, steps), device=device)
        # Simulated per-sample lengths; the upper bound of randint is exclusive,
        # hence steps + 1 to allow a full-length sequence.
        lengths = torch.randint(steps // 2, steps + 1, (rows,), device=device)

        return dict(features=signal, labels=targets, seq_lengths=lengths)
|
||
|
||
|
||
class SimpleTpuTrainer:
    """Small training/evaluation harness for a per-timestep classifier.

    Batches are drawn from ``self.data_generator``; the loss is cross-entropy
    with ``ignore_index=-1``. ``self.step`` counts the optimizer updates
    performed by ``train_step`` and is what checkpoints record.
    """

    def __init__(self, model, device, learning_rate=0.001, data_generator=None):
        """
        Args:
            model: module mapping (batch, seq, features) to (batch, seq, classes) logits.
            device: device batches are generated on; XLA-specific calls are
                only issued when its string form contains 'xla'.
            learning_rate: learning rate for the Adam optimizer.
            data_generator: optional object exposing ``generate_batch(device)``;
                defaults to ``SimpleDataGenerator()``.
        """
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

        # Source of synthetic batches.
        self.data_generator = (
            SimpleDataGenerator() if data_generator is None else data_generator
        )

        # Training statistics.
        self.step = 0
        self.best_loss = float('inf')

    def _is_xla(self):
        """Return True when the trainer's device is an XLA device."""
        return 'xla' in str(self.device)

    def _sync_device(self, wait=False):
        """Issue an XLA step barrier (and optionally wait for the device); no-op off XLA."""
        if not self._is_xla():
            return
        xm.mark_step()
        if wait:
            xm.wait_device_ops()

    def _sequence_loss(self, logits, labels):
        """Flatten (batch, seq, classes) logits and (batch, seq) labels and apply the criterion."""
        num_classes = logits.shape[-1]
        return self.criterion(logits.reshape(-1, num_classes), labels.reshape(-1))

    def train_step(self, batch):
        """Run one optimization step on ``batch`` and return the loss as a float."""
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(batch['features'])
        loss = self._sequence_loss(logits, batch['labels'])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Cut the lazy graph once per step, before the loss is read back, so
        # reading the scalar does not trigger an extra partial execution.
        self._sync_device()
        self.step += 1

        return loss.item()

    def evaluate_step(self, batch):
        """Evaluate ``batch`` without gradients.

        Returns:
            (loss, accuracy) as floats. Accuracy is computed only over
            positions whose label differs from the criterion's ``ignore_index``.
        """
        self.model.eval()

        with torch.no_grad():
            labels = batch['labels']
            logits = self.model(batch['features'])
            loss = self._sequence_loss(logits, labels)

            predictions = torch.argmax(logits, dim=-1)
            ignore_index = getattr(self.criterion, 'ignore_index', -100)
            valid = labels != ignore_index
            # Masked mean written with sums (not boolean indexing) so tensor
            # shapes stay static; clamp avoids dividing by zero when every
            # position is ignored.
            num_correct = ((predictions == labels) & valid).sum().float()
            accuracy = num_correct / valid.sum().clamp(min=1).float()

            self._sync_device()

        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Train for ``num_steps`` steps with periodic evaluation and checkpointing.

        Args:
            num_steps: number of training steps to run.
            eval_every: evaluate every this many steps; 0/None disables evaluation.
            save_every: checkpoint every this many steps (never at step 0);
                0/None disables periodic checkpoints.

        Returns:
            dict with 'train_losses', 'eval_losses', 'eval_accuracies' (lists
            of floats) and 'total_time' (seconds).
        """
        print(f"🚀 开始TPU训练 - 设备: {self.device}")
        print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}")

        train_losses = []
        eval_losses = []
        eval_accuracies = []

        start_time = time.time()

        for step in range(num_steps):
            train_batch = self.data_generator.generate_batch(self.device)
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)

            if eval_every and step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)

                # Wait for outstanding device work so the elapsed time is meaningful.
                self._sync_device(wait=True)
                elapsed = time.time() - start_time

                print(f"步骤 {step:4d}/{num_steps} | "
                      f"训练损失: {train_loss:.4f} | "
                      f"验证损失: {eval_loss:.4f} | "
                      f"验证准确率: {eval_acc:.4f} | "
                      f"耗时: {elapsed:.1f}s")

                # Track the best evaluation loss seen so far.
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}")

            if save_every and step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")

        # Final synchronization before measuring total time.
        self._sync_device(wait=True)

        total_time = time.time() - start_time
        print(f"\n✅ 训练完成!")
        print(f"⏱️ 总耗时: {total_time:.1f}秒")
        # Guard against num_steps == 0, where no loss was recorded.
        if train_losses:
            print(f"🎯 最终训练损失: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}")
            print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}")

        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Write model/optimizer state plus ``step`` and ``best_loss`` to ``filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }

        if self._is_xla():
            # xm.save transfers the data to CPU before serializing it;
            # send_cpu_data_to_device goes the other way (host -> device).
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 保存检查点: {filename}")

    def load_checkpoint(self, filename):
        """Restore model/optimizer state, ``step`` and ``best_loss`` from ``filename``."""
        checkpoint = torch.load(filename, map_location='cpu')

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']

        print(f"📂 加载检查点: {filename}")
        print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}")
|
||
|
||
|
||
def test_simple_inference():
    """Smoke-test a single forward pass of a default model on the XLA device.

    Prints the input/output shapes and the wall-clock time of the pass, then
    returns True.
    """
    print("\n🧪 测试简单推理...")

    xla_device = xm.xla_device()
    net = SimpleBrainToTextModel().to(xla_device)

    # Random input; the feature width 512 matches the model's default input size.
    num_samples, num_frames = 4, 50
    test_input = torch.randn(num_samples, num_frames, 512, device=xla_device)

    net.eval()
    with torch.no_grad():
        started = time.time()
        output = net(test_input)
        # Force execution and wait for the device so the timing covers the
        # actual computation.
        xm.mark_step()
        xm.wait_device_ops()
        inference_time = time.time() - started

    report = (
        f"✅ 推理完成!",
        f" 输入形状: {test_input.shape}",
        f" 输出形状: {output.shape}",
        f" 推理时间: {inference_time:.4f}秒",
    )
    for line in report:
        print(line)

    return True
|
||
|
||
|
||
def main():
    """Script entry point: build the model, train it, save it, then run the inference smoke test.

    Any exception raised along the way is reported with its traceback
    instead of propagating.
    """
    rule = "=" * 60
    print(rule)
    print("🧠 简单TPU大脑到文本模型训练")
    print(rule)

    try:
        # Acquire the XLA device.
        device = xm.xla_device()
        print(f"📱 使用设备: {device}")

        # Build the model and move it to the device.
        model_config = dict(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3,
        )
        model = SimpleBrainToTextModel(**model_config).to(device)

        # Train.
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
        trainer.train(num_steps=1000, eval_every=100, save_every=500)

        # Persist the final model, then exercise inference.
        trainer.save_checkpoint("final_simple_model.pt")
        test_simple_inference()

        print("\n🎉 所有测试完成!")

    except Exception as e:
        import traceback

        print(f"❌ 训练失败: {e}")
        traceback.print_exc()
|
||
|
||
|
||
# Run the training script only when executed directly, not when imported.
if __name__ == "__main__":
    main()