```python
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases  # original einsum formulation
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
```
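
For reference, the einsum and bmm formulations compute the same batched contraction; here is a small self-contained check with illustrative shapes (not the model's real dimensions):

```python
import torch

b, t, d, k = 2, 6, 4, 3
x = torch.randn(b, t, d)
day_weights = torch.randn(b, d, k)
day_biases = torch.randn(b, 1, k)  # broadcasts over the time dimension

via_einsum = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
via_bmm = torch.bmm(x, day_weights) + day_biases

# Both paths perform the same batched matmul, so the results agree
torch.testing.assert_close(via_bmm, via_einsum)
```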

#### 5. Mixed Precision Dtype Consistency (Comprehensive Fix)

**Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline: in bmm operations, adversarial residual connections, and patch processing.

**Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`

**Root Cause Analysis**: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers:

1. Initial bmm operations in day-specific transformations
2. Adversarial training residual connections between models
3. Patch processing operations (unfold, permute, reshape)
4. Gradient Reversal Layer (GRL) operations
5. Hidden state initialization in adversarial training helper methods

**Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow:

```python
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training

# Fix 1: Basic bmm operations with dtype consistency
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

# Fix 2: Patch processing with explicit dtype preservation
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)
    # ... unfold and permute steps elided in the diff ...
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)

# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output

# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (GRL preserves input dtype, but be explicit)
noisy_input = noisy_input.to(x_processed.dtype)

# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
    states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)

# In _noisy_forward_with_processed_input:
if states is None:
    states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)
```

**Key Implementation Details**:

- **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward); see the sketch after this list
- **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision
- **Residual Connections**: All tensor arithmetic operations ensure matching dtypes
- **Helper Methods**: Hidden state initialization matches processed input dtype
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
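
The gradient reversal layer is the standard construction from domain-adversarial training: identity in the forward pass, negated and scaled gradient in the backward pass. A minimal sketch of `GradientReversalFn` (the class name comes from the list above; the body is an assumed reference implementation, not necessarily the repository's exact code):

```python
import torch

class GradientReversalFn(torch.autograd.Function):
    # Identity in the forward pass, so shape and dtype pass through untouched
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    # Negate and scale the gradient on the way back; lambda_ gets no gradient
    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

def gradient_reverse(x, lambda_=1.0):
    return GradientReversalFn.apply(x, lambda_)
```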

#### 6. Hidden State Initialization

**Problem**: Dynamic batch size allocation causes XLA recompilation

**Solution**: Use static shapes and avoid `x.shape[0]` in tensor creation
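
A short before/after sketch of this pattern (names such as `h0` and `n_units` are illustrative, mirroring the hidden-state code above):

```python
import torch

n_units, batch_size = 64, 32
x = torch.randn(batch_size, 100, 512, dtype=torch.bfloat16)
h0 = torch.nn.Parameter(torch.zeros(3, 1, n_units))

# XLA-unfriendly: reading the batch size from a runtime tensor shape
# forces recompilation whenever the batch dimension changes
# states = torch.zeros(3, x.shape[0], n_units, device=x.device)

# XLA-friendly: expand a learned initial state to a batch size passed
# around as a plain Python int, and match the input dtype up front
states = h0.expand(3, batch_size, n_units).contiguous().to(x.dtype)
```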

The forward pass likewise returns a simple tuple instead of a complex structure:

```python
return clean_logits, noisy_logits, noise_output  # Simple tuple return
```
### Files Modified for XLA Optimization
- **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
  - **`GradientReversalFn`**: Added adversarial training gradient reversal layer
  - **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
  - **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
  - **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
  - **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
  - **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation
  - **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
  - **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision

**Specific Dtype Consistency Fixes Applied**:

1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions
4. **Gradient Reversal**: Dtype consistency after GRL operations
5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions
6. **Data Flow**: End-to-end dtype consistency in the NoiseModel → GRL → NoisySpeechModel pipeline (see the sketch after this list)
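
Put together, the intended data flow looks roughly like the following sketch (hypothetical free-function wiring that reuses the `gradient_reverse` helper sketched earlier; the real logic lives in `TripleGRUDecoder.forward()`):

```python
def adversarial_forward(x_processed, noise_model, clean_model, noisy_model, grl_lambda=0.5):
    # Noise branch: estimate noise and match the input dtype immediately
    noise_output = noise_model(x_processed).to(x_processed.dtype)

    # Clean branch: residual connection subtracts the estimated noise
    denoised_input = x_processed - noise_output
    clean_logits = clean_model(denoised_input)

    # Noisy branch: gradient reversal makes the noise estimate adversarial
    noisy_input = gradient_reverse(noise_output, grl_lambda).to(x_processed.dtype)
    noisy_logits = noisy_model(noisy_input)

    return clean_logits, noisy_logits, noise_output  # simple tuple return for XLA
```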
**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training
### Benefits of XLA Optimizations

Created test scripts to verify model consistency:

- Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency
### Troubleshooting Dtype Issues in Mixed Precision Training

**Common Error Pattern**:

```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
```

**Diagnosis Steps**:

1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing
   - `7168 = 512 * 14`: patch processing with patch_size=14
   - `512`: the base neural feature dimension
   - Other patterns may indicate different operations

2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline
   - Input → NoiseModel → residual connection → CleanSpeechModel
   - Input → NoiseModel → GRL → NoisySpeechModel

3. **Verify Dtype Consistency**: Ensure all operations maintain the input dtype
   - Use `.to(x.dtype)` for all operand tensors
   - Preserve dtype through complex operations (unfold, permute, reshape)
   - Match hidden state dtype to the input tensor dtype
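
A tiny probe helper (an illustration, not part of the codebase) can surface mismatches at the Python level before XLA compilation fails:

```python
import torch

def check_dtypes(**tensors):
    # Print dtype and shape for each named tensor so mismatches surface
    # before XLA fails with an INVALID_ARGUMENT shape/dtype error
    for name, t in tensors.items():
        print(f"{name}: dtype={t.dtype}, shape={tuple(t.shape)}")

# Example: call at suspect points in the pipeline
x = torch.randn(32, 7168, dtype=torch.bfloat16)
w = torch.randn(32, 7168)  # accidentally float32
check_dtypes(x=x, w=w)     # reveals the bf16 vs f32 mismatch early
```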
**Quick Fix Template**:

```python
# For any tensor operation between tensors a and b:
result = operation(a, b.to(a.dtype))

# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)

# For hidden state initialization:
states = states.to(input_tensor.dtype)
```
## Competition Context

This codebase also serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.