tpu
This commit is contained in:
@@ -117,7 +117,9 @@ class NoiseModel(nn.Module):
|
|||||||
if states.dtype != gru_dtype:
|
if states.dtype != gru_dtype:
|
||||||
states = states.to(gru_dtype)
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
# GRU forward pass
|
# Disable autocast for GRU to avoid dtype mismatches on XLA
|
||||||
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
output, hidden_states = self.gru(x, states)
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
return output, hidden_states
|
return output, hidden_states
|
||||||
@@ -225,7 +227,8 @@ class CleanSpeechModel(nn.Module):
|
|||||||
if states.dtype != gru_dtype:
|
if states.dtype != gru_dtype:
|
||||||
states = states.to(gru_dtype)
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
# GRU forward pass
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
output, hidden_states = self.gru(x, states)
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
# Classification
|
# Classification
|
||||||
@@ -309,7 +312,8 @@ class NoisySpeechModel(nn.Module):
|
|||||||
if states.dtype != gru_dtype:
|
if states.dtype != gru_dtype:
|
||||||
states = states.to(gru_dtype)
|
states = states.to(gru_dtype)
|
||||||
|
|
||||||
# GRU forward pass
|
device_type = x.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
output, hidden_states = self.gru(x, states)
|
output, hidden_states = self.gru(x, states)
|
||||||
|
|
||||||
# Classification
|
# Classification
|
||||||
@@ -444,6 +448,8 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
states = states.to(clean_gru_dtype)
|
states = states.to(clean_gru_dtype)
|
||||||
|
|
||||||
# GRU forward pass (skip preprocessing since input is already processed)
|
# GRU forward pass (skip preprocessing since input is already processed)
|
||||||
|
device_type = x_processed.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
||||||
|
|
||||||
# Classification
|
# Classification
|
||||||
@@ -466,6 +472,8 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
states = states.to(noisy_gru_dtype)
|
states = states.to(noisy_gru_dtype)
|
||||||
|
|
||||||
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
||||||
|
device_type = x_processed.device.type
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
||||||
|
|
||||||
# Classification
|
# Classification
|
||||||
|
Reference in New Issue
Block a user