tpu
This commit is contained in:
113
CLAUDE.md
113
CLAUDE.md
@@ -165,19 +165,24 @@ x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
|
|||||||
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5. Mixed Precision Dtype Consistency
|
#### 5. Mixed Precision Dtype Consistency (Comprehensive Fix)
|
||||||
**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
|
**Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline
|
||||||
**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations
|
**Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`
|
||||||
|
|
||||||
|
**Root Cause Analysis**: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers:
|
||||||
|
1. Initial bmm operations in day-specific transformations
|
||||||
|
2. Adversarial training residual connections between models
|
||||||
|
3. Patch processing operations (unfold, permute, reshape)
|
||||||
|
4. Gradient Reversal Layer (GRL) operations
|
||||||
|
5. Hidden state initialization in adversarial training helper methods
|
||||||
|
|
||||||
|
**Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
|
# Fix 1: Basic bmm operations with dtype consistency
|
||||||
# Fix 1: Add dtype conversions for all bmm operands
|
|
||||||
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
|
||||||
|
|
||||||
# Fix 2: Ensure dtype consistency in adversarial training residual connections
|
# Fix 2: Patch processing with explicit dtype preservation
|
||||||
denoised_input = x_processed - noise_output.to(x_processed.dtype)
|
|
||||||
|
|
||||||
# Fix 3: Preserve dtype through patch processing operations
|
|
||||||
if self.patch_size > 0:
|
if self.patch_size > 0:
|
||||||
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
|
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
@@ -188,8 +193,37 @@ if self.patch_size > 0:
|
|||||||
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
|
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
|
||||||
# Ensure dtype consistency after patch processing operations
|
# Ensure dtype consistency after patch processing operations
|
||||||
x = x.to(original_dtype)
|
x = x.to(original_dtype)
|
||||||
|
|
||||||
|
# Fix 3: Adversarial training residual connections
|
||||||
|
noise_output = noise_output.to(x_processed.dtype)
|
||||||
|
denoised_input = x_processed - noise_output
|
||||||
|
|
||||||
|
# Fix 4: Gradient Reversal Layer dtype handling
|
||||||
|
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
|
||||||
|
# Ensure dtype consistency after GRL (preserves input dtype but explicit check)
|
||||||
|
noisy_input = noisy_input.to(x_processed.dtype)
|
||||||
|
|
||||||
|
# Fix 5: Hidden state dtype consistency in helper methods
|
||||||
|
# In _clean_forward_with_processed_input:
|
||||||
|
if states is None:
|
||||||
|
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
states = states.to(x_processed.dtype)
|
||||||
|
|
||||||
|
# In _noisy_forward_with_processed_input:
|
||||||
|
if states is None:
|
||||||
|
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
states = states.to(x_processed.dtype)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Key Implementation Details**:
|
||||||
|
- **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward)
|
||||||
|
- **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision
|
||||||
|
- **Residual Connections**: All tensor arithmetic operations ensure matching dtypes
|
||||||
|
- **Helper Methods**: Hidden state initialization matches processed input dtype
|
||||||
|
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
|
||||||
|
|
||||||
#### 3. Hidden State Initialization
|
#### 3. Hidden State Initialization
|
||||||
**Problem**: Dynamic batch size allocation causes XLA recompilation
|
**Problem**: Dynamic batch size allocation causes XLA recompilation
|
||||||
**Solution**: Use static shapes and avoid x.shape[0] in tensor creation
|
**Solution**: Use static shapes and avoid x.shape[0] in tensor creation
|
||||||
@@ -223,12 +257,25 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
|
|||||||
|
|
||||||
### Files Modified for XLA Optimization
|
### Files Modified for XLA Optimization
|
||||||
|
|
||||||
- **`model_training_nnn/rnn_model.py`**: All three models optimized
|
- **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
|
||||||
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency
|
- **`GradientReversalFn`**: Added adversarial training gradient reversal layer
|
||||||
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency
|
- **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
|
||||||
- `NoisySpeechModel.forward()`: Hidden state optimization
|
- **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
|
||||||
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix
|
- **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
|
||||||
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency
|
- **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
|
||||||
|
- **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation
|
||||||
|
- **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
|
||||||
|
- **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
|
||||||
|
|
||||||
|
**Specific Dtype Consistency Fixes Applied**:
|
||||||
|
1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions
|
||||||
|
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
|
||||||
|
3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions
|
||||||
|
4. **Gradient Reversal**: Dtype consistency after GRL operations
|
||||||
|
5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions
|
||||||
|
6. **Data Flow**: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline
|
||||||
|
|
||||||
|
**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training
|
||||||
|
|
||||||
### Benefits of XLA Optimizations
|
### Benefits of XLA Optimizations
|
||||||
|
|
||||||
@@ -252,5 +299,41 @@ Created test scripts to verify model consistency:
|
|||||||
- Backward compatibility with existing training scripts is maintained
|
- Backward compatibility with existing training scripts is maintained
|
||||||
- TPU training should now show improved compilation times and memory efficiency
|
- TPU training should now show improved compilation times and memory efficiency
|
||||||
|
|
||||||
|
### Troubleshooting Dtype Issues in Mixed Precision Training
|
||||||
|
|
||||||
|
**Common Error Pattern**:
|
||||||
|
```
|
||||||
|
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Diagnosis Steps**:
|
||||||
|
1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing
|
||||||
|
- `7168 = 512 * 14`: Patch processing operation with patch_size=14
|
||||||
|
- `512`: Basic neural features
|
||||||
|
- Other patterns may indicate different operations
|
||||||
|
|
||||||
|
2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline
|
||||||
|
- Input → NoiseModel → residual connection → CleanSpeechModel
|
||||||
|
- Input → NoiseModel → GRL → NoisySpeechModel
|
||||||
|
|
||||||
|
3. **Verify Dtype Consistency**: Ensure all operations maintain input dtype
|
||||||
|
- Use `.to(x.dtype)` for all operand tensors
|
||||||
|
- Preserve dtype through complex operations (unfold, permute, reshape)
|
||||||
|
- Match hidden state dtype to input tensor dtype
|
||||||
|
|
||||||
|
**Quick Fix Template**:
|
||||||
|
```python
|
||||||
|
# For any tensor operation between tensors a and b:
|
||||||
|
result = operation(a, b.to(a.dtype))
|
||||||
|
|
||||||
|
# For complex operations that might change dtype:
|
||||||
|
original_dtype = tensor.dtype
|
||||||
|
tensor = complex_operation(tensor)
|
||||||
|
tensor = tensor.to(original_dtype)
|
||||||
|
|
||||||
|
# For hidden state initialization:
|
||||||
|
states = states.to(input_tensor.dtype)
|
||||||
|
```
|
||||||
|
|
||||||
## Competition Context
|
## Competition Context
|
||||||
This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.
|
This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.
|
@@ -407,9 +407,11 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
|
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
|
||||||
batch_size = x_processed.size(0)
|
batch_size = x_processed.size(0)
|
||||||
|
|
||||||
# XLA-friendly hidden state initialization
|
# XLA-friendly hidden state initialization with dtype consistency
|
||||||
if states is None:
|
if states is None:
|
||||||
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
|
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
states = states.to(x_processed.dtype)
|
||||||
|
|
||||||
# GRU forward pass (skip preprocessing since input is already processed)
|
# GRU forward pass (skip preprocessing since input is already processed)
|
||||||
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
output, hidden_states = self.clean_speech_model.gru(x_processed, states)
|
||||||
@@ -422,9 +424,11 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
'''Forward pass for NoisySpeechModel with already processed input'''
|
'''Forward pass for NoisySpeechModel with already processed input'''
|
||||||
batch_size = x_processed.size(0)
|
batch_size = x_processed.size(0)
|
||||||
|
|
||||||
# XLA-friendly hidden state initialization
|
# XLA-friendly hidden state initialization with dtype consistency
|
||||||
if states is None:
|
if states is None:
|
||||||
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
|
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
|
||||||
|
# Ensure hidden states match input dtype for mixed precision training
|
||||||
|
states = states.to(x_processed.dtype)
|
||||||
|
|
||||||
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
|
||||||
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
|
||||||
@@ -455,9 +459,11 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
# Apply the same preprocessing that the models use internally
|
# Apply the same preprocessing that the models use internally
|
||||||
x_processed = self._apply_preprocessing(x, day_idx)
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
|
||||||
|
# Ensure dtype consistency between processed input and noise output
|
||||||
|
noise_output = noise_output.to(x_processed.dtype)
|
||||||
|
|
||||||
# 3. Clean speech model processes denoised signal
|
# 3. Clean speech model processes denoised signal
|
||||||
# Ensure dtype consistency for mixed precision training in residual connection
|
denoised_input = x_processed - noise_output # Residual connection in processed space
|
||||||
denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space
|
|
||||||
# Clean speech model will apply its own preprocessing, so we pass the denoised processed data
|
# Clean speech model will apply its own preprocessing, so we pass the denoised processed data
|
||||||
# But we need to reverse the preprocessing first, then let clean model do its own
|
# But we need to reverse the preprocessing first, then let clean model do its own
|
||||||
# Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
|
# Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
|
||||||
@@ -467,6 +473,9 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
# 4. Noisy speech model processes noise signal directly (no day layers needed)
|
||||||
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
|
# Optionally apply Gradient Reversal to enforce adversarial training on noise output
|
||||||
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
|
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
|
||||||
|
# Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility
|
||||||
|
# Use x_processed.dtype as reference since it's the main data flow dtype
|
||||||
|
noisy_input = noisy_input.to(x_processed.dtype)
|
||||||
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
|
noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
|
||||||
states['noisy'] if states else None)
|
states['noisy'] if states else None)
|
||||||
|
|
||||||
@@ -485,9 +494,9 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
# 2. For residual connection, we need x in the same space as noise_output
|
# 2. For residual connection, we need x in the same space as noise_output
|
||||||
x_processed = self._apply_preprocessing(x, day_idx)
|
x_processed = self._apply_preprocessing(x, day_idx)
|
||||||
|
|
||||||
# 3. Process denoised signal
|
# Ensure dtype consistency for mixed precision residual connection
|
||||||
# Ensure dtype consistency for mixed precision training in residual connection
|
noise_output = noise_output.to(x_processed.dtype)
|
||||||
denoised_input = x_processed - noise_output.to(x_processed.dtype)
|
denoised_input = x_processed - noise_output
|
||||||
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||||
states['clean'] if states else None)
|
states['clean'] if states else None)
|
||||||
|
|
||||||
@@ -505,7 +514,10 @@ class TripleGRUDecoder(nn.Module):
|
|||||||
|
|
||||||
clean_grad (tensor) - gradients from clean speech model output layer
|
clean_grad (tensor) - gradients from clean speech model output layer
|
||||||
noisy_grad (tensor) - gradients from noisy speech model output layer
|
noisy_grad (tensor) - gradients from noisy speech model output layer
|
||||||
learning_rate (float) - learning rate for gradient update
|
if grl_lambda and grl_lambda != 0.0:
|
||||||
|
noisy_input = gradient_reverse(noise_output, grl_lambda)
|
||||||
|
else:
|
||||||
|
noisy_input = noise_output
|
||||||
'''
|
'''
|
||||||
# Combine gradients: negative from clean model, positive from noisy model
|
# Combine gradients: negative from clean model, positive from noisy model
|
||||||
combined_grad = -clean_grad + noisy_grad
|
combined_grad = -clean_grad + noisy_grad
|
||||||
|
Reference in New Issue
Block a user