tpu
This commit is contained in:
@@ -94,14 +94,17 @@ class NoiseModel(nn.Module):
if self.input_dropout > 0:
x = self.day_layer_dropout(x)

# Apply patch processing if enabled (keep conditional for now, optimize later)
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)

# XLA-friendly hidden state initialization - avoid dynamic allocation
if states is None:
@@ -193,14 +196,17 @@ class CleanSpeechModel(nn.Module):
if self.input_dropout > 0:
x = self.day_layer_dropout(x)

# Apply patch processing if enabled
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1)
x = x.permute(0, 3, 1, 2)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x = x.to(original_dtype)

# XLA-friendly hidden state initialization
if states is None:
@@ -383,14 +389,17 @@ class TripleGRUDecoder(nn.Module):
x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
x_processed = self.clean_speech_model.day_layer_activation(x_processed)

# Apply patch processing if enabled
# Apply patch processing if enabled with dtype preservation for mixed precision training
if self.patch_size > 0:
original_dtype = x_processed.dtype  # Preserve original dtype for XLA/TPU compatibility
x_processed = x_processed.unsqueeze(1)
x_processed = x_processed.permute(0, 3, 1, 2)
x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
x_unfold = x_unfold.squeeze(2)
x_unfold = x_unfold.permute(0, 2, 3, 1)
x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations
x_processed = x_processed.to(original_dtype)

return x_processed

Reference in New Issue
Block a user