tpu
This commit is contained in:
@@ -166,13 +166,16 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp
|
|||||||
```
|
```
|
||||||
|
|
||||||
#### 5. Mixed Precision Dtype Consistency
|
#### 5. Mixed Precision Dtype Consistency
|
||||||
**Problem**: Mixed precision training causes dtype mismatches in bmm operations
|
**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
|
||||||
**Solution**: Ensure all operands match input tensor dtype
|
**Solution**: Ensure all operands match input tensor dtype
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
|
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
|
||||||
# Fix: Add dtype conversions for all bmm operands
|
# Fix 1: Add dtype conversions for all bmm operands
|
||||||
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
|
||||||
|
# Fix 2: Ensure dtype consistency in adversarial training residual connections
|
||||||
|
denoised_input = x_processed - noise_output.to(x_processed.dtype)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 3. Hidden State Initialization
|
#### 3. Hidden State Initialization
|
||||||
@@ -212,7 +215,7 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
|
|||||||
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
|
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
|
||||||
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
|
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
|
||||||
- `NoisySpeechModel.forward()`: Hidden state optimization
|
- `NoisySpeechModel.forward()`: Hidden state optimization
|
||||||
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns
|
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix
|
||||||
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
|
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
|
||||||
|
|
||||||
### Benefits of XLA Optimizations
|
### Benefits of XLA Optimizations
|
||||||
|
@@ -447,7 +447,8 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
x_processed = self._apply_preprocessing(x, day_idx)
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
|
||||||
# 3. Clean speech model processes denoised signal
|
# 3. Clean speech model processes denoised signal
|
||||||
denoised_input = x_processed - noise_output # Residual connection in processed space
|
# Ensure dtype consistency for mixed precision training in residual connection
|
||||||
|
denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space
|
||||||
# The clean speech model normally applies its own preprocessing, but denoised_input is already in processed space
|
# The clean speech model normally applies its own preprocessing, but denoised_input is already in processed space
|
||||||
# One option is to reverse the preprocessing first and then let the clean model redo its own
|
# One option is to reverse the preprocessing first and then let the clean model redo its own
|
||||||
# It is simpler to pass the residual directly to the clean model and bypass its preprocessing, so we do that
|
# It is simpler to pass the residual directly to the clean model and bypass its preprocessing, so we do that
|
||||||
@@ -476,7 +477,8 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
x_processed = self._apply_preprocessing(x, day_idx)
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
|
||||||
# 3. Process denoised signal
|
# 3. Process denoised signal
|
||||||
denoised_input = x_processed - noise_output
|
# Ensure dtype consistency for mixed precision training in residual connection
|
||||||
|
denoised_input = x_processed - noise_output.to(x_processed.dtype)
|
||||||
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||||
states['clean'] if states else None)
|
states['clean'] if states else None)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user