diff --git a/CLAUDE.md b/CLAUDE.md
index 98fa499..2958916 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -166,8 +166,8 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtyp
 ```
 
 #### 5. Mixed Precision Dtype Consistency
-**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
-**Solution**: Ensure all operands match input tensor dtype
+**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
+**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations
 ```python
 # Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
@@ -176,6 +176,18 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
 # Fix 2: Ensure dtype consistency in adversarial training residual connections
 denoised_input = x_processed - noise_output.to(x_processed.dtype)
+
+# Fix 3: Preserve dtype through patch processing operations
+if self.patch_size > 0:
+    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
+    x = x.unsqueeze(1)
+    x = x.permute(0, 3, 1, 2)
+    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
+    x_unfold = x_unfold.squeeze(2)
+    x_unfold = x_unfold.permute(0, 2, 3, 1)
+    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+    # Ensure dtype consistency after patch processing operations
+    x = x.to(original_dtype)
 ```
 
 #### 3. Hidden State Initialization
diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py
index 0138a4f..2b56b13 100644
--- a/model_training_nnn/rnn_model.py
+++ b/model_training_nnn/rnn_model.py
@@ -94,14 +94,17 @@ class NoiseModel(nn.Module):
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled (keep conditional for now, optimize later)
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x = x.to(original_dtype)
 
         # XLA-friendly hidden state initialization - avoid dynamic allocation
         if states is None:
@@ -193,14 +196,17 @@ class CleanSpeechModel(nn.Module):
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x = x.to(original_dtype)
 
         # XLA-friendly hidden state initialization
         if states is None:
@@ -383,14 +389,17 @@ class TripleGRUDecoder(nn.Module):
         x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
         x_processed = self.clean_speech_model.day_layer_activation(x_processed)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x_processed.dtype  # Preserve original dtype for XLA/TPU compatibility
             x_processed = x_processed.unsqueeze(1)
             x_processed = x_processed.permute(0, 3, 1, 2)
             x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x_processed = x_processed.to(original_dtype)
 
         return x_processed
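For reference, a minimal standalone sketch of the same unfold-based patch extraction and dtype round-trip, runnable outside the repo. The function name `extract_patches` and the shapes and patch settings (batch 32, 100 timesteps, 512 features, `patch_size=14`, `patch_stride=4`) are illustrative assumptions, not values taken from the codebase; in eager PyTorch the final cast is a no-op, but it mirrors the defensive cast the patch adds for XLA/TPU mixed precision.

```python
# Sketch only: replicates the patch-processing path on a standalone bf16 tensor
# to show the dtype round-trip. Names and shapes are assumptions, not repo values.
import torch

def extract_patches(x: torch.Tensor, patch_size: int, patch_stride: int) -> torch.Tensor:
    """Unfold (batch, time, features) into (batch, n_patches, patch_size * features)."""
    original_dtype = x.dtype                          # remember the autocast dtype (e.g. bf16)
    batch_size = x.size(0)
    x = x.unsqueeze(1)                                # (B, 1, T, F)
    x = x.permute(0, 3, 1, 2)                         # (B, F, 1, T)
    x_unfold = x.unfold(3, patch_size, patch_stride)  # (B, F, 1, P, patch_size)
    x_unfold = x_unfold.squeeze(2)                    # (B, F, P, patch_size)
    x_unfold = x_unfold.permute(0, 2, 3, 1)           # (B, P, patch_size, F)
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)  # (B, P, patch_size * F)
    # No-op in eager mode; guards against dtype changes under XLA/autocast per the patch
    return x.to(original_dtype)

x = torch.randn(32, 100, 512, dtype=torch.bfloat16)
patches = extract_patches(x, patch_size=14, patch_stride=4)
assert patches.dtype == x.dtype                       # dtype preserved end to end
print(patches.shape)                                  # torch.Size([32, 22, 7168]) with these assumptions
```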