TPU support

commit 530b7c9d3d (parent 3892f13da8)
Author: Zchen
Date: 2025-10-12 15:31:45 +08:00

6 changed files with 472 additions and 42 deletions


@@ -182,12 +182,15 @@ class BrainToTextDecoder_Trainer:
                random_seed = self.args['dataset']['seed'],
                feature_subset = feature_subset
            )
+
+        # Use TPU-optimized dataloader settings if TPU is enabled
+        num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = self.args['dataset']['loader_shuffle'],
-            num_workers = self.args['dataset']['num_dataloader_workers'],
-            pin_memory = True
+            num_workers = num_workers,
+            pin_memory = True
         )
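
For clarity, a minimal sketch of the worker-count selection this hunk introduces. The config keys mirror the diff; the concrete values are illustrative only:

```python
def pick_num_workers(args: dict) -> int:
    # TPU runs read a dedicated key ('dataloader_num_workers'); all other
    # runs keep the original 'num_dataloader_workers' setting.
    if args.get('use_tpu', False):
        return args['dataset']['dataloader_num_workers']
    return args['dataset']['num_dataloader_workers']

# Illustrative config: TPU runs often use 0 workers to avoid multiprocessing
# friction with XLA; the project's real values live in its args.yaml.
args = {'use_tpu': True,
        'dataset': {'dataloader_num_workers': 0, 'num_dataloader_workers': 4}}
print(pick_num_workers(args))  # -> 0
```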
# val dataset and dataloader
@@ -204,9 +207,9 @@ class BrainToTextDecoder_Trainer:
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size = None, # Dataset.__getitem__() already returns batches
-            shuffle = False,
-            num_workers = 0,
-            pin_memory = True
+            shuffle = False,
+            num_workers = 0, # Keep validation dataloader single-threaded for consistency
+            pin_memory = True
        )

        self.logger.info("Successfully initialized datasets")
@@ -365,47 +368,52 @@ class BrainToTextDecoder_Trainer:
        return LambdaLR(optim, lr_lambdas, -1)

    def load_model_checkpoint(self, load_path):
        '''
-        Load a training checkpoint
+        Load a training checkpoint for distributed training
        '''
-        checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict
+        # Load checkpoint on CPU first to avoid OOM issues
+        checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict

-        self.model.load_state_dict(checkpoint['model_state_dict'])
+        # Get unwrapped model for loading state dict
+        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
        self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf

-        self.model.to(self.device)
-
-        # Send optimizer params back to GPU
-        for state in self.optimizer.state.values():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor):
-                    state[k] = v.to(self.device)
+        # Device handling is managed by Accelerator, no need to manually move to device

        self.logger.info("Loaded model from checkpoint: " + load_path)

    def save_model_checkpoint(self, save_path, PER, loss):
        '''
-        Save a training checkpoint
+        Save a training checkpoint using Accelerator for distributed training
        '''
-        checkpoint = {
-            'model_state_dict' : self.model.state_dict(),
-            'optimizer_state_dict' : self.optimizer.state_dict(),
-            'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
-            'val_PER' : PER,
-            'val_loss' : loss
-        }
-
-        torch.save(checkpoint, save_path)
-        self.logger.info("Saved model to checkpoint: " + save_path)
-
-        # Save the args file alongside the checkpoint
-        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
-            OmegaConf.save(config=self.args, f=f)
+        # Only save on main process to avoid conflicts
+        if self.accelerator.is_main_process:
+            # Unwrap model to get base model for saving
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+
+            checkpoint = {
+                'model_state_dict' : unwrapped_model.state_dict(),
+                'optimizer_state_dict' : self.optimizer.state_dict(),
+                'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
+                'val_PER' : PER,
+                'val_loss' : loss
+            }
+
+            torch.save(checkpoint, save_path)
+            self.logger.info("Saved model to checkpoint: " + save_path)
+
+            # Save the args file alongside the checkpoint
+            with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
+                OmegaConf.save(config=self.args, f=f)
+
+        # Wait for all processes to complete checkpoint saving
+        self.accelerator.wait_for_everyone()

    def create_attention_mask(self, sequence_lengths):
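
Taken together, the two methods implement a standard Accelerate checkpoint round trip: load on CPU into the unwrapped module, save from the main process only, then synchronize. A hedged stand-alone sketch; the model and optimizer here are toy stand-ins, not the project's classes:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

def save_ckpt(path):
    if accelerator.is_main_process:  # one writer avoids clobbered files
        torch.save({'model_state_dict':
                        accelerator.unwrap_model(model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, path)
    accelerator.wait_for_everyone()  # barrier: no rank reads a half-written file

def load_ckpt(path):
    # map_location='cpu' keeps the whole state dict out of device memory;
    # load_state_dict then copies tensors over parameter by parameter.
    ckpt = torch.load(path, map_location='cpu', weights_only=False)
    accelerator.unwrap_model(model).load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])

save_ckpt('ckpt.pt')
load_ckpt('ckpt.pt')
```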
@@ -685,13 +693,14 @@ class BrainToTextDecoder_Trainer:
            if self.args['dataset']['dataset_probability_val'][d] == 1:
                day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}

        for i, batch in enumerate(loader):
-            features = batch['input_features'].to(self.device)
-            labels = batch['seq_class_ids'].to(self.device)
-            n_time_steps = batch['n_time_steps'].to(self.device)
-            phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-            day_indicies = batch['day_indicies'].to(self.device)
+            # Data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            labels = batch['seq_class_ids']
+            n_time_steps = batch['n_time_steps']
+            phone_seq_lens = batch['phone_seq_lens']
+            day_indicies = batch['day_indicies']

            # Determine if we should perform validation on this batch
            day = day_indicies[0].item()
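
The dropped `.to(self.device)` calls rely on `accelerator.prepare()` having wrapped the dataloader, which places each yielded batch on the active device (CUDA, XLA/TPU, or CPU). A minimal sketch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

# After prepare(), the loader yields tensors already on accelerator.device,
# so per-field batch['...'].to(device) calls become redundant.
loader = accelerator.prepare(loader)

for (features,) in loader:
    print(features.device)  # e.g. cuda:0, xla:0, or cpu
```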
@@ -702,7 +711,7 @@ class BrainToTextDecoder_Trainer:
                with torch.no_grad():
-                    with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
+                    with self.accelerator.autocast():
                        features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
                        adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
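
Replacing the hard-coded `device_type="cuda"` autocast with `self.accelerator.autocast()` is what makes this block portable: there is no "cuda" autocast on TPU, and Accelerate selects the mixed-precision context for whichever backend is active. A sketch, assuming bf16 mixed precision is configured:

```python
import torch
from accelerate import Accelerator

# 'bf16' mirrors the bfloat16 dtype the old CUDA-only context requested.
accelerator = Accelerator(mixed_precision='bf16')
model = accelerator.prepare(torch.nn.Linear(16, 4))
x = torch.randn(2, 16, device=accelerator.device)

with torch.no_grad():
    with accelerator.autocast():  # backend-appropriate: CUDA, XLA/TPU, or CPU
        y = model(x)
print(y.dtype)  # torch.bfloat16 while mixed precision is active
```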
@@ -768,4 +777,44 @@ class BrainToTextDecoder_Trainer:
        metrics['avg_PER'] = avg_PER.item()
        metrics['avg_loss'] = np.mean(metrics['losses'])

        return metrics
+
+    def inference(self, features, day_indicies, n_time_steps, mode='inference'):
+        '''
+        TPU-compatible inference method for generating phoneme logits
+        '''
+        self.model.eval()
+
+        with torch.no_grad():
+            with self.accelerator.autocast():
+                # Apply data transformations (no augmentation for inference)
+                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+                # Get phoneme predictions
+                logits = self.model(features, day_indicies, None, False, mode)
+
+        return logits
+
+    def inference_batch(self, batch, mode='inference'):
+        '''
+        TPU-compatible inference method for processing a full batch
+        '''
+        self.model.eval()
+
+        # Data is automatically moved to device by Accelerator
+        features = batch['input_features']
+        day_indicies = batch['day_indicies']
+        n_time_steps = batch['n_time_steps']
+
+        with torch.no_grad():
+            with self.accelerator.autocast():
+                # Apply data transformations (no augmentation for inference)
+                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
+
+                # Calculate adjusted sequence lengths for CTC
+                adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
+
+                # Get phoneme predictions
+                logits = self.model(features, day_indicies, None, False, mode)
+
+        return logits, adjusted_lens
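
A hedged usage sketch for the new helpers: `inference_batch` returns per-frame logits plus CTC-adjusted lengths (the `(T - patch_size) / patch_stride + 1` patch count), which a greedy CTC decode can consume. The blank index of 0 is an assumption, not confirmed by the diff:

```python
import torch

def greedy_ctc_decode(logits, lengths, blank=0):
    # logits: (batch, time, num_classes); lengths: valid frames per sequence.
    preds = logits.argmax(dim=-1)
    decoded = []
    for seq, T in zip(preds, lengths):
        seq = torch.unique_consecutive(seq[:T])  # collapse CTC repeats
        decoded.append(seq[seq != blank].tolist())  # drop blanks
    return decoded

# Expected wiring (names from the diff):
#   logits, adjusted_lens = trainer.inference_batch(batch)
#   phoneme_ids = greedy_ctc_decode(logits, adjusted_lens)
```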