tpu
This commit is contained in:
		
							
								
								
									
										113
									
								
								CLAUDE.md
									
									
									
									
									
								
							
							
						
						
									
										113
									
								
								CLAUDE.md
									
									
									
									
									
								
							| @@ -165,19 +165,24 @@ x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases | ||||
| x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtype consistency | ||||
| ``` | ||||
|  | ||||
| #### 5. Mixed Precision Dtype Consistency | ||||
| **Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations | ||||
| **Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations | ||||
| #### 5. Mixed Precision Dtype Consistency (Comprehensive Fix) | ||||
| **Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline | ||||
| **Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]` | ||||
|  | ||||
| **Root Cause Analysis**: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers: | ||||
| 1. Initial bmm operations in day-specific transformations | ||||
| 2. Adversarial training residual connections between models | ||||
| 3. Patch processing operations (unfold, permute, reshape) | ||||
| 4. Gradient Reversal Layer (GRL) operations | ||||
| 5. Hidden state initialization in adversarial training helper methods | ||||
|  | ||||
| **Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow: | ||||
|  | ||||
| ```python | ||||
| # Error: f32[32,7168] vs bf16[32,7168] in mixed precision training | ||||
| # Fix 1: Add dtype conversions for all bmm operands | ||||
| # Fix 1: Basic bmm operations with dtype consistency | ||||
| x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) | ||||
|  | ||||
| # Fix 2: Ensure dtype consistency in adversarial training residual connections | ||||
| denoised_input = x_processed - noise_output.to(x_processed.dtype) | ||||
|  | ||||
| # Fix 3: Preserve dtype through patch processing operations | ||||
| # Fix 2: Patch processing with explicit dtype preservation | ||||
| if self.patch_size > 0: | ||||
|     original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility | ||||
|     x = x.unsqueeze(1) | ||||
| @@ -188,8 +193,37 @@ if self.patch_size > 0: | ||||
|     x = x_unfold.reshape(batch_size, x_unfold.size(1), -1) | ||||
|     # Ensure dtype consistency after patch processing operations | ||||
|     x = x.to(original_dtype) | ||||
|  | ||||
| # Fix 3: Adversarial training residual connections | ||||
| noise_output = noise_output.to(x_processed.dtype) | ||||
| denoised_input = x_processed - noise_output | ||||
|  | ||||
| # Fix 4: Gradient Reversal Layer dtype handling | ||||
| noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output | ||||
| # Ensure dtype consistency after GRL (GRL preserves the input dtype; this cast is an explicit safeguard) | ||||
| noisy_input = noisy_input.to(x_processed.dtype) | ||||
|  | ||||
| # Fix 5: Hidden state dtype consistency in helper methods | ||||
| # In _clean_forward_with_processed_input: | ||||
| if states is None: | ||||
|     states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous() | ||||
|     # Ensure hidden states match input dtype for mixed precision training | ||||
|     states = states.to(x_processed.dtype) | ||||
|  | ||||
| # In _noisy_forward_with_processed_input: | ||||
| if states is None: | ||||
|     states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous() | ||||
|     # Ensure hidden states match input dtype for mixed precision training | ||||
|     states = states.to(x_processed.dtype) | ||||
| ``` | ||||
|  | ||||
| **Key Implementation Details**: | ||||
| - **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward) | ||||
| - **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision | ||||
| - **Residual Connections**: All tensor arithmetic operations ensure matching dtypes | ||||
| - **Helper Methods**: Hidden state initialization matches processed input dtype | ||||
| - **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout | ||||
|  | ||||
| #### 3. Hidden State Initialization | ||||
| **Problem**: Dynamic batch size allocation causes XLA recompilation | ||||
| **Solution**: Use static shapes and avoid x.shape[0] in tensor creation | ||||
| @@ -223,12 +257,25 @@ return clean_logits, noisy_logits, noise_output  # Simple tuple return | ||||
|  | ||||
| ### Files Modified for XLA Optimization | ||||
|  | ||||
| - **`model_training_nnn/rnn_model.py`**: All three models optimized | ||||
|   - `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency | ||||
|   - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency | ||||
|   - `NoisySpeechModel.forward()`: Hidden state optimization | ||||
|   - `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix | ||||
|   - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency | ||||
| - **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency | ||||
|   - **`GradientReversalFn`**: Added adversarial training gradient reversal layer | ||||
|   - **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation | ||||
|   - **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation | ||||
|   - **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified) | ||||
|   - **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling | ||||
|   - **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation | ||||
|   - **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision | ||||
|   - **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision | ||||
|  | ||||
| **Specific Dtype Consistency Fixes Applied**: | ||||
| 1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions | ||||
| 2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations | ||||
| 3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions | ||||
| 4. **Gradient Reversal**: Dtype consistency after GRL operations | ||||
| 5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions | ||||
| 6. **Data Flow**: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline | ||||
|  | ||||
| **Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training | ||||
|  | ||||
| ### Benefits of XLA Optimizations | ||||
|  | ||||
| @@ -252,5 +299,41 @@ Created test scripts to verify model consistency: | ||||
| - Backward compatibility with existing training scripts is maintained | ||||
| - TPU training should now show improved compilation times and memory efficiency | ||||
|  | ||||
| ### Troubleshooting Dtype Issues in Mixed Precision Training | ||||
|  | ||||
| **Common Error Pattern**: | ||||
| ``` | ||||
| Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y] | ||||
| ``` | ||||
|  | ||||
| **Diagnosis Steps**: | ||||
| 1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing | ||||
|    - `7168 = 512 * 14`: Patch processing operation with patch_size=14 | ||||
|    - `512`: Basic neural features | ||||
|    - Other patterns may indicate different operations | ||||
|  | ||||
| 2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline | ||||
|    - Input → NoiseModel → residual connection → CleanSpeechModel | ||||
|    - Input → NoiseModel → GRL → NoisySpeechModel | ||||
|  | ||||
| 3. **Verify Dtype Consistency**: Ensure all operations maintain input dtype | ||||
|    - Use `.to(x.dtype)` for all operand tensors | ||||
|    - Preserve dtype through complex operations (unfold, permute, reshape) | ||||
|    - Match hidden state dtype to input tensor dtype | ||||
|  | ||||
| **Quick Fix Template**: | ||||
| ```python | ||||
| # For any tensor operation between tensors a and b: | ||||
| result = operation(a, b.to(a.dtype)) | ||||
|  | ||||
| # For complex operations that might change dtype: | ||||
| original_dtype = tensor.dtype | ||||
| tensor = complex_operation(tensor) | ||||
| tensor = tensor.to(original_dtype) | ||||
|  | ||||
| # For hidden state initialization: | ||||
| states = states.to(input_tensor.dtype) | ||||
| ``` | ||||
|  | ||||
| ## Competition Context | ||||
| This codebase also serves as a baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding. | ||||
| @@ -407,9 +407,11 @@ class TripleGRUDecoder(nn.Module): | ||||
|         '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)''' | ||||
|         batch_size = x_processed.size(0) | ||||
|  | ||||
|         # XLA-friendly hidden state initialization | ||||
|         # XLA-friendly hidden state initialization with dtype consistency | ||||
|         if states is None: | ||||
|             states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous() | ||||
|             # Ensure hidden states match input dtype for mixed precision training | ||||
|             states = states.to(x_processed.dtype) | ||||
|  | ||||
|         # GRU forward pass (skip preprocessing since input is already processed) | ||||
|         output, hidden_states = self.clean_speech_model.gru(x_processed, states) | ||||
| @@ -422,9 +424,11 @@ class TripleGRUDecoder(nn.Module): | ||||
|         '''Forward pass for NoisySpeechModel with already processed input''' | ||||
|         batch_size = x_processed.size(0) | ||||
|  | ||||
|         # XLA-friendly hidden state initialization | ||||
|         # XLA-friendly hidden state initialization with dtype consistency | ||||
|         if states is None: | ||||
|             states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous() | ||||
|             # Ensure hidden states match input dtype for mixed precision training | ||||
|             states = states.to(x_processed.dtype) | ||||
|  | ||||
|         # GRU forward pass (NoisySpeechModel doesn't have day layers anyway) | ||||
|         output, hidden_states = self.noisy_speech_model.gru(x_processed, states) | ||||
| @@ -455,9 +459,11 @@ class TripleGRUDecoder(nn.Module): | ||||
|             # Apply the same preprocessing that the models use internally | ||||
|             x_processed = self._apply_preprocessing(x, day_idx) | ||||
|  | ||||
|             # Ensure dtype consistency between processed input and noise output | ||||
|             noise_output = noise_output.to(x_processed.dtype) | ||||
|  | ||||
|             # 3. Clean speech model processes denoised signal | ||||
|             # Ensure dtype consistency for mixed precision training in residual connection | ||||
|             denoised_input = x_processed - noise_output.to(x_processed.dtype)  # Residual connection in processed space | ||||
|             denoised_input = x_processed - noise_output  # Residual connection in processed space | ||||
|             # Clean speech model will apply its own preprocessing, so we pass the denoised processed data | ||||
|             # But we need to reverse the preprocessing first, then let clean model do its own | ||||
|             # Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing | ||||
| @@ -467,6 +473,9 @@ class TripleGRUDecoder(nn.Module): | ||||
|             # 4. Noisy speech model processes noise signal directly (no day layers needed) | ||||
|             # Optionally apply Gradient Reversal to enforce adversarial training on noise output | ||||
|             noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output | ||||
|             # GradientReversalFn should already preserve dtype; cast explicitly as a safeguard for mixed precision | ||||
|             # x_processed.dtype is the reference because it is the dtype of the main data flow | ||||
|             noisy_input = noisy_input.to(x_processed.dtype) | ||||
|             noisy_logits = self._noisy_forward_with_processed_input(noisy_input, | ||||
|                                                                    states['noisy'] if states else None) | ||||
|  | ||||
| @@ -485,9 +494,9 @@ class TripleGRUDecoder(nn.Module): | ||||
|             # 2. For residual connection, we need x in the same space as noise_output | ||||
|             x_processed = self._apply_preprocessing(x, day_idx) | ||||
|  | ||||
|             # 3. Process denoised signal | ||||
|             # Ensure dtype consistency for mixed precision training in residual connection | ||||
|             denoised_input = x_processed - noise_output.to(x_processed.dtype) | ||||
|             # Ensure dtype consistency for mixed precision residual connection | ||||
|             noise_output = noise_output.to(x_processed.dtype) | ||||
|             denoised_input = x_processed - noise_output | ||||
|             clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, | ||||
|                                                                    states['clean'] if states else None) | ||||
|  | ||||
| @@ -505,7 +514,10 @@ class TripleGRUDecoder(nn.Module): | ||||
|  | ||||
|         clean_grad (tensor) - gradients from clean speech model output layer | ||||
|         noisy_grad (tensor) - gradients from noisy speech model output layer | ||||
|         learning_rate (float) - learning rate for gradient update | ||||
|                 if grl_lambda and grl_lambda != 0.0: | ||||
|                     noisy_input = gradient_reverse(noise_output, grl_lambda) | ||||
|                 else: | ||||
|                     noisy_input = noise_output | ||||
|         ''' | ||||
|         # Combine gradients: negative from clean model, positive from noisy model | ||||
|         combined_grad = -clean_grad + noisy_grad | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Zchen
					Zchen