tpu
This commit is contained in:
@@ -145,8 +145,8 @@ class BrainToTextDataset(Dataset):
|
||||
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
|
||||
continue
|
||||
|
||||
# Pad data to form a cohesive batch
|
||||
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
|
||||
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
|
||||
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
|
||||
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
|
||||
|
||||
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
|
||||
|
Reference in New Issue
Block a user