From 989ba676181c3f9a0498b28c958c5e9aca56fd37 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:11:54 +0800 Subject: [PATCH] tpu --- CLAUDE.md | 9 ++++++--- model_training_nnn/rnn_model.py | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 98728aa..98fa499 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -166,13 +166,16 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp ``` #### 5. Mixed Precision Dtype Consistency -**Problem**: Mixed precision training causes dtype mismatches in bmm operations +**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections **Solution**: Ensure all operands match input tensor dtype ```python # Error: f32[32,7168] vs bf16[32,7168] in mixed precision training -# Fix: Add dtype conversions for all bmm operands +# Fix 1: Add dtype conversions for all bmm operands x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) + +# Fix 2: Ensure dtype consistency in adversarial training residual connections +denoised_input = x_processed - noise_output.to(x_processed.dtype) ``` #### 3. Hidden State Initialization @@ -212,7 +215,7 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return - `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency - `NoisySpeechModel.forward()`: Hidden state optimization - - `TripleGRUDecoder.forward()`: Complex return values → tuple returns + - `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency ### Benefits of XLA Optimizations diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index 625a22d..0138a4f 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -447,7 +447,8 @@ class TripleGRUDecoder(nn.Module): x_processed = self._apply_preprocessing(x, day_idx) # 3. Clean speech model processes denoised signal - denoised_input = x_processed - noise_output # Residual connection in processed space + # Ensure dtype consistency for mixed precision training in residual connection + denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space # Clean speech model will apply its own preprocessing, so we pass the denoised processed data # But we need to reverse the preprocessing first, then let clean model do its own # Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing @@ -476,7 +477,8 @@ class TripleGRUDecoder(nn.Module): x_processed = self._apply_preprocessing(x, day_idx) # 3. Process denoised signal - denoised_input = x_processed - noise_output + # Ensure dtype consistency for mixed precision training in residual connection + denoised_input = x_processed - noise_output.to(x_processed.dtype) clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, states['clean'] if states else None)