This commit is contained in:
Zchen
2025-10-14 23:11:54 +08:00
parent f67ed2b820
commit 989ba67618
2 changed files with 10 additions and 5 deletions

View File

@@ -166,13 +166,16 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp
```
#### 5. Mixed Precision Dtype Consistency
**Problem**: Mixed precision training causes dtype mismatches in bmm operations
**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
**Solution**: Ensure all operands match input tensor dtype
```python
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
# Fix: Add dtype conversions for all bmm operands
# Fix 1: Add dtype conversions for all bmm operands
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
# Fix 2: Ensure dtype consistency in adversarial training residual connections
denoised_input = x_processed - noise_output.to(x_processed.dtype)
```
#### 3. Hidden State Initialization
@@ -212,7 +215,7 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
- `NoisySpeechModel.forward()`: Hidden state optimization
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
### Benefits of XLA Optimizations