Zchen
2025-10-14 23:35:42 +08:00
parent cd52ba51ba
commit 4b6d680283
2 changed files with 118 additions and 23 deletions

CLAUDE.md

@@ -165,19 +165,24 @@ x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
```
#### 5. Mixed Precision Dtype Consistency
**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations
#### 5. Mixed Precision Dtype Consistency (Comprehensive Fix)
**Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline
**Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`
**Root Cause Analysis**: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers:
1. Initial bmm operations in day-specific transformations
2. Adversarial training residual connections between models
3. Patch processing operations (unfold, permute, reshape)
4. Gradient Reversal Layer (GRL) operations
5. Hidden state initialization in adversarial training helper methods
**Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow:
```python
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
# Fix 1: Add dtype conversions for all bmm operands
# Fix 1: Basic bmm operations with dtype consistency
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
# Fix 2: Ensure dtype consistency in adversarial training residual connections
denoised_input = x_processed - noise_output.to(x_processed.dtype)
# Fix 3: Preserve dtype through patch processing operations
# Fix 2: Patch processing with explicit dtype preservation
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)
@@ -188,8 +193,37 @@ if self.patch_size > 0:
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)
# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output
# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (the GRL preserves input dtype, but convert explicitly to be safe)
noisy_input = noisy_input.to(x_processed.dtype)
# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
    states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# In _noisy_forward_with_processed_input:
if states is None:
    states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
```
**Key Implementation Details**:
- **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward); see the sketch after this list
- **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision
- **Residual Connections**: All tensor arithmetic operations ensure matching dtypes
- **Helper Methods**: Hidden state initialization matches processed input dtype
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
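For reference, below is a minimal sketch of a gradient reversal layer of the kind described above (identity in the forward pass, negated and scaled gradient in the backward pass). The repository's actual `GradientReversalFn` may differ in details such as how `grl_lambda` is passed:
```python
import torch

class GradientReversalFn(torch.autograd.Function):
    """Identity in the forward pass; reverses and scales gradients in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        # view_as is a no-op that keeps shape and dtype, so bf16 inputs stay bf16
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the incoming gradient; lambda_ itself receives no gradient
        return -ctx.lambda_ * grad_output, None

def gradient_reverse(x, lambda_=1.0):
    return GradientReversalFn.apply(x, lambda_)
```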
#### 3. Hidden State Initialization
**Problem**: Dynamic batch size allocation causes XLA recompilation
**Solution**: Use static shapes and avoid x.shape[0] in tensor creation
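A minimal sketch of the static-shape pattern, assuming the batch size is fixed at configuration time rather than read from `x.shape[0]`; the module and parameter names (`n_units`, `n_layers`, `batch_size`) are illustrative, not the repository's exact API:
```python
import torch
import torch.nn as nn

class GRUWithStaticHidden(nn.Module):
    """Illustrative module: learned h0 with a static shape, expanded to a fixed batch size."""

    def __init__(self, input_size, n_units, n_layers, batch_size):
        super().__init__()
        self.gru = nn.GRU(input_size, n_units, n_layers, batch_first=True)
        self.batch_size = batch_size  # fixed at configuration time, never read from x.shape[0]
        self.h0 = nn.Parameter(torch.zeros(n_layers, 1, n_units))

    def forward(self, x, states=None):
        if states is None:
            # Static expand: the shape depends only on constants, so XLA compiles it once
            states = self.h0.expand(-1, self.batch_size, -1).contiguous()
            states = states.to(x.dtype)  # keep hidden state dtype in sync under mixed precision
        out, h = self.gru(x, states)
        return out, h
```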
@@ -223,12 +257,25 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
### Files Modified for XLA Optimization
- **`model_training_nnn/rnn_model.py`**: All three models optimized
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
- `NoisySpeechModel.forward()`: Hidden state optimization
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
- **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
- **`GradientReversalFn`**: Added adversarial training gradient reversal layer
- **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
- **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
- **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
- **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
- **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation
- **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
- **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
**Specific Dtype Consistency Fixes Applied**:
1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions
4. **Gradient Reversal**: Dtype consistency after GRL operations
5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions
6. **Data Flow**: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline
**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training
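As a condensed illustration of item 6 above, the adversarial forward path could be organized as follows. Attribute and helper names mirror the listing above, but this is a sketch rather than the actual `TripleGRUDecoder.forward`, and the helper signatures are simplified (the real methods also take hidden states and additional arguments):
```python
def adversarial_forward(self, x_processed, grl_lambda=0.5):
    """Sketch of the NoiseModel -> GRL -> NoisySpeechModel data flow with dtype guards."""
    # 1. Estimate noise from the preprocessed input
    noise_output = self.noise_model(x_processed)
    noise_output = noise_output.to(x_processed.dtype)

    # 2. Clean branch: residual connection removes the estimated noise
    denoised_input = x_processed - noise_output
    clean_logits = self._clean_forward_with_processed_input(denoised_input)

    # 3. Noisy branch: gradient reversal before the noisy-speech classifier
    noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
    noisy_input = noisy_input.to(x_processed.dtype)
    noisy_logits = self._noisy_forward_with_processed_input(noisy_input)

    return clean_logits, noisy_logits, noise_output  # simple tuple return for XLA
```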
### Benefits of XLA Optimizations
@@ -252,5 +299,41 @@ Created test scripts to verify model consistency:
- Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency
### Troubleshooting Dtype Issues in Mixed Precision Training
**Common Error Pattern**:
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
```
**Diagnosis Steps**:
1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing
- `7168 = 512 * 14`: Patch processing operation with patch_size=14
- `512`: Basic neural features
- Other patterns may indicate different operations
2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline
- Input → NoiseModel → residual connection → CleanSpeechModel
- Input → NoiseModel → GRL → NoisySpeechModel
3. **Verify Dtype Consistency**: Ensure all operations maintain input dtype
- Use `.to(x.dtype)` for all operand tensors
- Preserve dtype through complex operations (unfold, permute, reshape)
- Match hidden state dtype to input tensor dtype
**Quick Fix Template**:
```python
# For any tensor operation between tensors a and b:
result = operation(a, b.to(a.dtype))
# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)
# For hidden state initialization:
states = states.to(input_tensor.dtype)
```
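To locate where a dtype flips (diagnosis step 2 above), a small forward-hook helper like the following can print the output dtype of every submodule during one forward pass. This is a debugging sketch, not part of the repository:
```python
import torch

def attach_dtype_logger(model):
    """Register forward hooks that print the output dtype of every named submodule."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            tensors = output if isinstance(output, tuple) else (output,)
            dtypes = [t.dtype for t in tensors if isinstance(t, torch.Tensor)]
            print(f"{name}: {dtypes}")
        handles.append(module.register_forward_hook(hook))
    return handles
```
Attach it once before a forward pass, then call `.remove()` on each returned handle so the hooks do not add overhead to later training steps.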
## Competition Context
This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.