Compare commits
2 Commits
eaa327267f
...
9288bde126
Author | SHA1 | Date | |
---|---|---|---|
![]() |
9288bde126 | ||
![]() |
06c4c6c267 |
@@ -68,7 +68,8 @@ class NoiseModel(nn.Module):
|
||||
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
|
||||
|
||||
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||
x = torch.bmm(x, day_weights) + day_biases
|
||||
# Ensure dtype consistency for mixed precision training
|
||||
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
# XLA-friendly conditional dropout
|
||||
@@ -167,7 +168,8 @@ class CleanSpeechModel(nn.Module):
|
||||
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
|
||||
|
||||
# Use bmm (batch matrix multiply) which is highly optimized in XLA
|
||||
x = torch.bmm(x, day_weights) + day_biases
|
||||
# Ensure dtype consistency for mixed precision training
|
||||
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||
x = self.day_layer_activation(x)
|
||||
|
||||
if self.input_dropout > 0:
|
||||
|
Reference in New Issue
Block a user