This commit is contained in:
Zchen
2025-10-15 00:18:05 +08:00
parent 4a3d3f35ec
commit 603bb12220

View File

@@ -117,8 +117,10 @@ class NoiseModel(nn.Module):
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass
output, hidden_states = self.gru(x, states)
# Disable autocast for GRU to avoid dtype mismatches on XLA
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
return output, hidden_states
@@ -225,8 +227,9 @@ class CleanSpeechModel(nn.Module):
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass
output, hidden_states = self.gru(x, states)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
@@ -309,8 +312,9 @@ class NoisySpeechModel(nn.Module):
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass
output, hidden_states = self.gru(x, states)
device_type = x.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.gru(x, states)
# Classification
logits = self.out(output)
@@ -444,7 +448,9 @@ class TripleGRUDecoder(nn.Module):
states = states.to(clean_gru_dtype)
# GRU forward pass (skip preprocessing since input is already processed)
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
# Classification
logits = self.clean_speech_model.out(output)
@@ -466,7 +472,9 @@ class TripleGRUDecoder(nn.Module):
states = states.to(noisy_gru_dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
device_type = x_processed.device.type
with torch.autocast(device_type=device_type, enabled=False):
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
# Classification
logits = self.noisy_speech_model.out(output)