This commit is contained in:
Zchen
2025-10-14 23:22:59 +08:00
parent 989ba67618
commit cd52ba51ba
2 changed files with 26 additions and 5 deletions

View File

@@ -94,14 +94,17 @@ class NoiseModel(nn.Module):
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled (keep conditional for now, optimize later)
# Apply patch processing if enabled, preserving the input dtype for mixed-precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
# XLA-friendly hidden state initialization - avoid dynamic allocation
if states is None:
@@ -193,14 +196,17 @@ class CleanSpeechModel(nn.Module):
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# Apply patch processing if enabled
# Apply patch processing if enabled, preserving the input dtype for mixed-precision training
if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)
# XLA-friendly hidden state initialization
if states is None:
@@ -383,14 +389,17 @@ class TripleGRUDecoder(nn.Module):
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x_processed = self.clean_speech_model.day_layer_activation(x_processed)
# Apply patch processing if enabled
# Apply patch processing if enabled, preserving the input dtype for mixed-precision training
if self.patch_size > 0:
original_dtype = x_processed.dtype # Preserve original dtype for XLA/TPU compatibility
x_processed = x_processed.unsqueeze(1)
x_processed = x_processed.permute(0, 3, 1, 2)
x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x_processed = x_processed.to(original_dtype)
return x_processed