Compare commits

...

2 Commits

Author SHA1 Message Date
Zchen
9288bde126 Merge branch 'dev2' of http://ecs.zchens.cn:3000/zchen/b2txt25 into dev2 2025-10-14 13:31:28 +08:00
Zchen
06c4c6c267 tpu 2025-10-14 13:31:26 +08:00

View File

@@ -68,7 +68,8 @@ class NoiseModel(nn.Module):
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
x = torch.bmm(x, day_weights) + day_biases
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
# XLA-friendly conditional dropout
@@ -167,7 +168,8 @@ class CleanSpeechModel(nn.Module):
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # [batch_size, 1, neural_dim]
# Use bmm (batch matrix multiply) which is highly optimized in XLA
x = torch.bmm(x, day_weights) + day_biases
# Ensure dtype consistency for mixed precision training
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x = self.day_layer_activation(x)
if self.input_dropout > 0: