Fix bug where the B model was not enabled
Changed file: CLAUDE.md (18)
@@ -162,7 +162,17 @@ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
 x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
 
 # After (XLA-optimized):
-x = torch.bmm(x, day_weights) + day_biases  # bmm is highly optimized in XLA
+x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtype consistency
 ```
 
+#### 5. Mixed Precision Dtype Consistency
+
+**Problem**: Mixed precision training causes dtype mismatches in bmm operations.
+**Solution**: Ensure all bmm operands match the input tensor's dtype.
+
+```python
+# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
+# Fix: add dtype conversions for all bmm operands
+x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
+```
 #### 3. Hidden State Initialization
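The einsum-to-bmm rewrite and the dtype fix in this hunk can be sketched end to end. A minimal runnable version, assuming bf16 activations, f32 per-day parameters, and illustrative shapes (`batch`, `time`, `feat_in`, `feat_out`, `n_days` are hypothetical names, not from the repo):

```python
# Sketch of the per-day linear layer: static gather + bmm with dtype casts.
# Shapes and tensor names are illustrative, not taken from rnn_model.py.
import torch

batch, time, feat_in, feat_out, n_days = 4, 10, 8, 16, 3

x = torch.randn(batch, time, feat_in, dtype=torch.bfloat16)  # bf16 activations
all_day_weights = torch.randn(n_days, feat_in, feat_out)     # f32 parameters
all_day_biases = torch.randn(n_days, feat_out)
day_idx = torch.tensor([0, 2, 1, 0])                         # one day id per batch element

# Static gather: index_select is a single fixed-shape op under XLA
day_weights = torch.index_select(all_day_weights, 0, day_idx)             # (batch, feat_in, feat_out)
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # (batch, 1, feat_out)

# Before: einsum in f32.  After: bmm with every operand cast to x.dtype,
# which avoids the f32-vs-bf16 mismatch that torch.bmm rejects.
y_einsum = torch.einsum("btd,bdk->btk", x.float(), day_weights) + day_biases
y_bmm = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

assert y_bmm.dtype == torch.bfloat16
assert y_bmm.shape == (batch, time, feat_out)
assert torch.allclose(y_bmm.float(), y_einsum, rtol=0.1, atol=0.5)  # bf16 precision
```

Without the `.to(x.dtype)` casts, `torch.bmm` raises a dtype-mismatch error as soon as autocast puts `x` in bf16 while the day parameters stay f32.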
@@ -199,11 +209,11 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
 ### Files Modified for XLA Optimization
 
 - **`model_training_nnn/rnn_model.py`**: All three models optimized
-  - `NoiseModel.forward()`: Dynamic indexing → static gather operations
-  - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops
+  - `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
+  - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
   - `NoisySpeechModel.forward()`: Hidden state optimization
   - `TripleGRUDecoder.forward()`: Complex return values → tuple returns
-  - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations
+  - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
 
 ### Benefits of XLA Optimizations