This commit is contained in:
Zchen
2025-10-12 21:47:30 +08:00
parent dfb3f7312c
commit 4dad570eea
2 changed files with 20 additions and 6 deletions

View File

@@ -145,8 +145,8 @@ class BrainToTextDataset(Dataset):
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
continue
# Pad data to form a cohesive batch
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])