This commit is contained in:
Zchen
2025-10-12 22:59:45 +08:00
parent 5c941d9efa
commit 6cfc568f9a
2 changed files with 56 additions and 0 deletions

View File

@@ -350,6 +350,42 @@ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / sel
**Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations.
## New Issue: Features Tensor DType Mismatch (2025-10-12 17:00)
### Error Description
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168].
```
### Root Cause Analysis
After fixing the `adjusted_lens` dtype issue, a new mismatch emerged in the `features` tensor dimensions `[32, 7168]` representing (batch_size=32, neural_dim×patch_size=512×14=7168). Under `accelerator.autocast()` with mixed precision `bf16`, input tensors are automatically converted to bfloat16, but model parameters remained in float32 after removing hardcoded dtype specifications, creating a mismatch at the model input level.
### Problem Code
```python
# Inside accelerator.autocast() context:
# features becomes bf16 automatically by autocast
logits = self.model(features, day_indicies, None, False, 'inference')
# Model expects f32 parameters but receives bf16 input → mismatch
```
### Solution
Add explicit dtype conversion before all model calls to ensure consistency:
```python
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
```
### Fixed Locations
- `rnn_trainer.py:582-584` - Training loop model call
- `rnn_trainer.py:760-763` - Validation loop model call
- `rnn_trainer.py:839-842` - Inference method model call
- `rnn_trainer.py:863-866` - Inference batch method model call
**Key Insight**: Mixed precision autocast converts inputs but not necessarily model parameters. When removing hardcoded dtypes, explicit conversion ensures compatibility between autocast inputs and model parameters.
## Lessons Learned
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype