tpu支持
This commit is contained in:
@@ -182,12 +182,15 @@ class BrainToTextDecoder_Trainer:
|
||||
random_seed = self.args['dataset']['seed'],
|
||||
feature_subset = feature_subset
|
||||
)
|
||||
# Use TPU-optimized dataloader settings if TPU is enabled
|
||||
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
|
||||
|
||||
self.train_loader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
shuffle = self.args['dataset']['loader_shuffle'],
|
||||
num_workers = self.args['dataset']['num_dataloader_workers'],
|
||||
pin_memory = True
|
||||
num_workers = num_workers,
|
||||
pin_memory = True
|
||||
)
|
||||
|
||||
# val dataset and dataloader
|
||||
@@ -204,9 +207,9 @@ class BrainToTextDecoder_Trainer:
|
||||
self.val_loader = DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
shuffle = False,
|
||||
num_workers = 0,
|
||||
pin_memory = True
|
||||
shuffle = False,
|
||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||
pin_memory = True
|
||||
)
|
||||
|
||||
self.logger.info("Successfully initialized datasets")
|
||||
@@ -365,47 +368,52 @@ class BrainToTextDecoder_Trainer:
|
||||
return LambdaLR(optim, lr_lambdas, -1)
|
||||
|
||||
def load_model_checkpoint(self, load_path):
|
||||
'''
|
||||
Load a training checkpoint
|
||||
'''
|
||||
checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict
|
||||
Load a training checkpoint for distributed training
|
||||
'''
|
||||
# Load checkpoint on CPU first to avoid OOM issues
|
||||
checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
|
||||
|
||||
# Get unwrapped model for loading state dict
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
|
||||
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
# Send optimizer params back to GPU
|
||||
for state in self.optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(self.device)
|
||||
# Device handling is managed by Accelerator, no need to manually move to device
|
||||
|
||||
self.logger.info("Loaded model from checkpoint: " + load_path)
|
||||
|
||||
def save_model_checkpoint(self, save_path, PER, loss):
|
||||
'''
|
||||
Save a training checkpoint
|
||||
Save a training checkpoint using Accelerator for distributed training
|
||||
'''
|
||||
# Only save on main process to avoid conflicts
|
||||
if self.accelerator.is_main_process:
|
||||
# Unwrap model to get base model for saving
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
checkpoint = {
|
||||
'model_state_dict' : self.model.state_dict(),
|
||||
'optimizer_state_dict' : self.optimizer.state_dict(),
|
||||
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
|
||||
'val_PER' : PER,
|
||||
'val_loss' : loss
|
||||
}
|
||||
|
||||
torch.save(checkpoint, save_path)
|
||||
|
||||
self.logger.info("Saved model to checkpoint: " + save_path)
|
||||
checkpoint = {
|
||||
'model_state_dict' : unwrapped_model.state_dict(),
|
||||
'optimizer_state_dict' : self.optimizer.state_dict(),
|
||||
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
|
||||
'val_PER' : PER,
|
||||
'val_loss' : loss
|
||||
}
|
||||
|
||||
# Save the args file alongside the checkpoint
|
||||
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
|
||||
OmegaConf.save(config=self.args, f=f)
|
||||
torch.save(checkpoint, save_path)
|
||||
|
||||
self.logger.info("Saved model to checkpoint: " + save_path)
|
||||
|
||||
# Save the args file alongside the checkpoint
|
||||
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
|
||||
OmegaConf.save(config=self.args, f=f)
|
||||
|
||||
# Wait for all processes to complete checkpoint saving
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def create_attention_mask(self, sequence_lengths):
|
||||
|
||||
@@ -685,13 +693,14 @@ class BrainToTextDecoder_Trainer:
|
||||
if self.args['dataset']['dataset_probability_val'][d] == 1:
|
||||
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
|
||||
|
||||
for i, batch in enumerate(loader):
|
||||
for i, batch in enumerate(loader):
|
||||
|
||||
features = batch['input_features'].to(self.device)
|
||||
labels = batch['seq_class_ids'].to(self.device)
|
||||
n_time_steps = batch['n_time_steps'].to(self.device)
|
||||
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
|
||||
day_indicies = batch['day_indicies'].to(self.device)
|
||||
# Data is automatically moved to device by Accelerator
|
||||
features = batch['input_features']
|
||||
labels = batch['seq_class_ids']
|
||||
n_time_steps = batch['n_time_steps']
|
||||
phone_seq_lens = batch['phone_seq_lens']
|
||||
day_indicies = batch['day_indicies']
|
||||
|
||||
# Determine if we should perform validation on this batch
|
||||
day = day_indicies[0].item()
|
||||
@@ -702,7 +711,7 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
|
||||
with self.accelerator.autocast():
|
||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||
|
||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
@@ -768,4 +777,44 @@ class BrainToTextDecoder_Trainer:
|
||||
metrics['avg_PER'] = avg_PER.item()
|
||||
metrics['avg_loss'] = np.mean(metrics['losses'])
|
||||
|
||||
return metrics
|
||||
return metrics
|
||||
|
||||
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
|
||||
'''
|
||||
TPU-compatible inference method for generating phoneme logits
|
||||
'''
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
with self.accelerator.autocast():
|
||||
# Apply data transformations (no augmentation for inference)
|
||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||
|
||||
# Get phoneme predictions
|
||||
logits = self.model(features, day_indicies, None, False, mode)
|
||||
|
||||
return logits
|
||||
|
||||
def inference_batch(self, batch, mode='inference'):
|
||||
'''
|
||||
TPU-compatible inference method for processing a full batch
|
||||
'''
|
||||
self.model.eval()
|
||||
|
||||
# Data is automatically moved to device by Accelerator
|
||||
features = batch['input_features']
|
||||
day_indicies = batch['day_indicies']
|
||||
n_time_steps = batch['n_time_steps']
|
||||
|
||||
with torch.no_grad():
|
||||
with self.accelerator.autocast():
|
||||
# Apply data transformations (no augmentation for inference)
|
||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||
|
||||
# Calculate adjusted sequence lengths for CTC
|
||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
|
||||
# Get phoneme predictions
|
||||
logits = self.model(features, day_indicies, None, False, mode)
|
||||
|
||||
return logits, adjusted_lens
|
Reference in New Issue
Block a user