tpu
This commit is contained in:
@@ -126,8 +126,8 @@ class BrainToTextDataset(Dataset):
|
||||
try:
|
||||
g = f[f'trial_{t:04d}']
|
||||
|
||||
# Remove features is neccessary
|
||||
input_features = torch.from_numpy(g['input_features'][:]) # neural data
|
||||
# Remove features is neccessary
|
||||
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # neural data - convert to bf16 for TPU compatibility
|
||||
if self.feature_subset:
|
||||
input_features = input_features[:,self.feature_subset]
|
||||
|
||||
|
Reference in New Issue
Block a user