Files
b2txt25/model_training_nnn_tpu/simple_tpu_model.py
2025-10-15 15:14:01 +08:00

367 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简单TPU模型训练和测试脚本
基于大脑到文本数据的简化版本专门为TPU优化
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, Any, Tuple
# Configure XLA through environment variables. These assignments are made
# ahead of the torch_xla imports that follow.
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_cpu_multi_thread_eigen=true',
    '--xla_cpu_enable_fast_math=true',
    # One host-platform device per CPU reported by the OS.
    '--xla_force_host_platform_device_count={}'.format(os.cpu_count()),
])
os.environ.update({
    # Compilation thread count mirrors the CPU count as well.
    'PYTORCH_XLA_COMPILATION_THREADS': '{}'.format(os.cpu_count()),
    'XLA_USE_BF16': '1',
})
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
class SimpleBrainToTextModel(nn.Module):
    """Compact brain-to-text network: linear projection -> stacked GRU -> linear classifier."""

    def __init__(self, input_features=512, hidden_size=256, num_classes=41, num_layers=3):
        """Build the three stages and initialise their parameters.

        Args:
            input_features: Size of the last dimension of the input tensor.
            hidden_size: Width of the projection and of every GRU layer.
            num_classes: Number of output classes per time step.
            num_layers: Number of stacked GRU layers.
        """
        super().__init__()

        # Front end: project the input features to the recurrent width.
        self.input_proj = nn.Linear(input_features, hidden_size)
        self.input_dropout = nn.Dropout(0.2)

        # Recurrent core. Inter-layer dropout is only requested when the
        # stack actually has more than one layer.
        gru_config = dict(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.3 if num_layers > 1 else 0,
        )
        self.gru = nn.GRU(**gru_config)

        # Per-time-step classifier head.
        self.output_proj = nn.Linear(hidden_size, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise weights (orthogonal for the GRU, Xavier-uniform otherwise) and zero all biases."""
        for param_name, tensor in self.named_parameters():
            if 'weight' not in param_name:
                # Not a weight tensor: only biases receive any treatment.
                if 'bias' in param_name:
                    nn.init.zeros_(tensor)
                continue
            initialise = nn.init.orthogonal_ if 'gru' in param_name else nn.init.xavier_uniform_
            initialise(tensor)

    def forward(self, x):
        """Run the model on a batch of sequences.

        Args:
            x: Tensor of shape (batch_size, seq_len, input_features).

        Returns:
            Logits of shape (batch_size, seq_len, num_classes).
        """
        projected = self.input_dropout(torch.relu(self.input_proj(x)))
        # The final hidden state returned by the GRU is not needed.
        sequence_out, _ = self.gru(projected)
        return self.output_proj(sequence_out)
class SimpleDataGenerator:
    """Produces random batches that stand in for recorded brain-signal data."""

    def __init__(self, batch_size=16, seq_len=100, input_features=512, num_classes=41):
        """Remember the dimensions used for every generated batch."""
        self.batch_size, self.seq_len = batch_size, seq_len
        self.input_features, self.num_classes = input_features, num_classes

    def generate_batch(self, device):
        """Create one synthetic batch directly on ``device``.

        Returns:
            A dict with three tensors:
              * ``features``: float32 normal noise, (batch_size, seq_len, input_features)
              * ``labels``: integers in [0, num_classes), (batch_size, seq_len)
              * ``seq_lengths``: integers in [seq_len // 2, seq_len], (batch_size,)
        """
        rows, steps = self.batch_size, self.seq_len

        # Mock signal, then mock label sequence, then mock valid lengths.
        # The draw order is kept fixed so seeded runs stay reproducible.
        signal = torch.randn(rows, steps, self.input_features, dtype=torch.float32, device=device)
        targets = torch.randint(0, self.num_classes, (rows, steps), device=device)
        lengths = torch.randint(steps // 2, steps + 1, (rows,), device=device)

        return dict(features=signal, labels=targets, seq_lengths=lengths)
class SimpleTpuTrainer:
    """Minimal training harness for a per-time-step sequence classifier on an XLA device.

    The trainer owns an Adam optimizer, a cross-entropy loss that ignores the
    label ``-1``, a synthetic data generator, and two pieces of bookkeeping:
    ``step`` (number of optimizer updates applied so far) and ``best_loss``
    (lowest evaluation loss seen by ``train``).
    """

    def __init__(self, model, device, learning_rate=0.001):
        """Set up optimizer, loss, data source and bookkeeping.

        Args:
            model: Module mapping (batch, seq_len, features) to
                (batch, seq_len, num_classes) logits.
            device: Device on which batches are generated.
            learning_rate: Learning rate handed to Adam.
        """
        self.model = model
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)
        # Synthetic batches stand in for a real dataset.
        self.data_generator = SimpleDataGenerator()
        # Training statistics.
        self.step = 0
        self.best_loss = float('inf')

    def _on_xla_device(self):
        """Return True when the trainer's device is an XLA device."""
        return 'xla' in str(self.device)

    def _sequence_loss(self, logits, labels):
        """Flatten batch and time so CrossEntropyLoss sees (N, C) logits and (N,) targets."""
        num_classes = logits.shape[-1]
        return self.criterion(logits.reshape(-1, num_classes), labels.reshape(-1))

    def train_step(self, batch):
        """Apply one optimizer update and return the loss as a Python float.

        Args:
            batch: Dict with ``'features'`` and ``'labels'`` tensors.

        Returns:
            The training loss of this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(batch['features'])
        loss = self._sequence_loss(logits, batch['labels'])

        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Count the update so that checkpoints record real progress.
        self.step += 1

        # Cut and execute the lazily traced graph once per step, before the
        # loss is read back. Otherwise the pending graph keeps growing
        # across iterations and reading the loss re-executes it.
        if self._on_xla_device():
            xm.mark_step()
        return loss.item()

    def evaluate_step(self, batch):
        """Compute loss and per-token accuracy for one batch without gradients.

        Args:
            batch: Dict with ``'features'`` and ``'labels'`` tensors.

        Returns:
            Tuple ``(loss, accuracy)`` of Python floats.
        """
        self.model.eval()
        with torch.no_grad():
            labels = batch['labels']
            logits = self.model(batch['features'])
            loss = self._sequence_loss(logits, labels)
            # Accuracy: fraction of positions whose arg-max class equals the label.
            predictions = torch.argmax(logits, dim=-1)
            accuracy = (predictions == labels).float().mean()
        return loss.item(), accuracy.item()

    def train(self, num_steps=1000, eval_every=100, save_every=500):
        """Run the training loop on freshly generated batches.

        Args:
            num_steps: Number of training iterations to run.
            eval_every: Evaluate and log every this many iterations.
            save_every: Write a checkpoint every this many iterations
                (never at iteration 0).

        Returns:
            Dict with ``'train_losses'``, ``'eval_losses'``,
            ``'eval_accuracies'`` (lists of floats) and ``'total_time'``
            (seconds).
        """
        print(f"🚀 开始TPU训练 - 设备: {self.device}")
        print(f"📊 模型参数: {sum(p.numel() for p in self.model.parameters()):,}")

        train_losses = []
        eval_losses = []
        eval_accuracies = []
        start_time = time.time()

        for step in range(num_steps):
            # train_step issues the per-step XLA barrier itself.
            train_batch = self.data_generator.generate_batch(self.device)
            train_loss = self.train_step(train_batch)
            train_losses.append(train_loss)

            if step % eval_every == 0:
                eval_batch = self.data_generator.generate_batch(self.device)
                eval_loss, eval_acc = self.evaluate_step(eval_batch)
                eval_losses.append(eval_loss)
                eval_accuracies.append(eval_acc)

                # Drain outstanding device work so the elapsed time is accurate.
                xm.mark_step()
                xm.wait_device_ops()
                elapsed = time.time() - start_time

                print(f"步骤 {step:4d}/{num_steps} | "
                      f"训练损失: {train_loss:.4f} | "
                      f"验证损失: {eval_loss:.4f} | "
                      f"验证准确率: {eval_acc:.4f} | "
                      f"耗时: {elapsed:.1f}s")

                # Track the best evaluation loss (no checkpoint is written here).
                if eval_loss < self.best_loss:
                    self.best_loss = eval_loss
                    print(f"🎯 新的最佳模型! 损失: {eval_loss:.4f}")

            # Periodic checkpoint.
            if step > 0 and step % save_every == 0:
                self.save_checkpoint(f"checkpoint_step_{step}.pt")

        # Final synchronisation before the total time is taken.
        xm.mark_step()
        xm.wait_device_ops()
        total_time = time.time() - start_time

        print(f"\n✅ 训练完成!")
        print(f"⏱️ 总耗时: {total_time:.1f}")
        # With num_steps == 0 there is no loss to report.
        if train_losses:
            print(f"🎯 最终训练损失: {train_losses[-1]:.4f}")
        if eval_losses:
            print(f"🎯 最终验证损失: {eval_losses[-1]:.4f}")
            print(f"🎯 最终验证准确率: {eval_accuracies[-1]:.4f}")

        return {
            'train_losses': train_losses,
            'eval_losses': eval_losses,
            'eval_accuracies': eval_accuracies,
            'total_time': total_time
        }

    def save_checkpoint(self, filename):
        """Write model state, optimizer state, ``step`` and ``best_loss`` to ``filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'step': self.step,
            'best_loss': self.best_loss,
        }
        if self._on_xla_device():
            # xm.save transfers XLA tensors to the CPU before serialising.
            xm.save(checkpoint, filename)
        else:
            torch.save(checkpoint, filename)
        print(f"💾 保存检查点: {filename}")

    def load_checkpoint(self, filename):
        """Restore model state, optimizer state, ``step`` and ``best_loss`` from ``filename``."""
        # Load onto the CPU; load_state_dict copies into the existing parameters.
        checkpoint = torch.load(filename, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.step = checkpoint['step']
        self.best_loss = checkpoint['best_loss']
        print(f"📂 加载检查点: {filename}")
        print(f" 步骤: {self.step}, 最佳损失: {self.best_loss:.4f}")
def test_simple_inference():
    """Smoke-test one forward pass of a default model on the XLA device.

    Builds a fresh ``SimpleBrainToTextModel``, feeds it a random
    (4, 50, 512) batch, waits for the device to finish and prints the
    shapes and the elapsed wall-clock time.

    Returns:
        Always ``True``.
    """
    print("\n🧪 测试简单推理...")

    xla_dev = xm.xla_device()
    net = SimpleBrainToTextModel().to(xla_dev)

    # Random probe batch created directly on the device.
    batch_size, seq_len = 4, 50
    probe = torch.randn(batch_size, seq_len, 512, device=xla_dev)

    net.eval()
    with torch.no_grad():
        began = time.time()
        result = net(probe)
        # Flush the lazy graph and block until the device is idle so the
        # measured interval covers the actual execution.
        xm.mark_step()
        xm.wait_device_ops()
        took = time.time() - began

    report = (
        f"✅ 推理完成!",
        f" 输入形状: {probe.shape}",
        f" 输出形状: {result.shape}",
        f" 推理时间: {took:.4f}",
    )
    for line in report:
        print(line)
    return True
def main():
    """Train the simple model on the XLA device, save it, then run the inference smoke test.

    Any exception raised along the way is reported together with its
    traceback, after which the process exits with status 1 so that shells
    and job schedulers can see that the run failed.

    Raises:
        SystemExit: With code 1 when any step of the run fails.
    """
    print("=" * 60)
    print("🧠 简单TPU大脑到文本模型训练")
    print("=" * 60)
    try:
        # Acquire the XLA device.
        device = xm.xla_device()
        print(f"📱 使用设备: {device}")

        # Build the model on that device.
        model = SimpleBrainToTextModel(
            input_features=512,
            hidden_size=256,
            num_classes=41,
            num_layers=3
        ).to(device)

        # Train, then persist the final weights.
        trainer = SimpleTpuTrainer(model, device, learning_rate=0.001)
        trainer.train(
            num_steps=1000,
            eval_every=100,
            save_every=500
        )
        trainer.save_checkpoint("final_simple_model.pt")

        # Inference smoke test.
        test_simple_inference()
        print("\n🎉 所有测试完成!")
    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import sys
        import traceback
        traceback.print_exc()
        # Report the failure through the exit status instead of returning normally.
        sys.exit(1)


# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()