This commit is contained in:
Zchen
2025-10-12 23:36:58 +08:00
parent 0d2a0aa8fa
commit 06c4c6c267
2 changed files with 5 additions and 3 deletions

View File

@@ -68,7 +68,8 @@ class NoiseModel(nn.Module):
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
- x = torch.bmm(x, day_weights) + day_biases
+ # Ensure dtype consistency for mixed precision training
+ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
# XLA-friendly conditional dropout
@@ -167,7 +168,8 @@ class CleanSpeechModel(nn.Module):
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
- x = torch.bmm(x, day_weights) + day_biases
+ # Ensure dtype consistency for mixed precision training
+ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
if self.input_dropout > 0: