tpu
@@ -319,6 +319,37 @@ if xm.get_xla_supported_devices():
**Expected improvement**: XLA graph compilation time reduced from 5-15 minutes to 2-8 minutes
## New Issue: DType Mismatch in adjusted_lens Calculation (2025-10-12 16:45)
### Error Description
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 1 shape: f32[21504], argument shape: bf16[21504].
```
### Root Cause
The `adjusted_lens` calculation caused dtype mismatches in TPU mixed-precision (bf16) training. When `n_time_steps` is processed under `accelerator.autocast()` it becomes bfloat16, but the surrounding arithmetic produced float32 results, so XLA saw an f32 parameter matched against a bf16 argument.
### Problem Code
```python
# Before (causes f32/bf16 mismatch):
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
```
### Solution
Convert to float explicitly before the integer cast so the intermediate arithmetic has a single, known dtype:
```python
# After (explicit dtype control):
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
```
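A quick sanity check of the resulting dtype (a standalone sketch, not repo code; the `patch_size`/`patch_stride` values here are illustrative, the real ones come from `self.args['model']`):

```python
import torch

# Illustrative hyperparameters, not the repo's configuration.
patch_size, patch_stride = 4, 2

# n_time_steps arrives as bf16 when produced under autocast.
n_time_steps = torch.tensor([128, 96], dtype=torch.bfloat16)

adjusted_lens = ((n_time_steps.float() - patch_size) / patch_stride + 1).to(torch.int32)
print(adjusted_lens.dtype)  # torch.int32; no f32/bf16 mix reaches the XLA graph
print(adjusted_lens)        # tensor([63, 47], dtype=torch.int32)
```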
### Fixed Locations
- `rnn_trainer.py:577` - Training loop
- `rnn_trainer.py:753` - Validation loop
- `rnn_trainer.py:851` - Inference batch function
**Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations.
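One way to surface such mismatches earlier (a hedged sketch, not part of `rnn_trainer.py`): assert the expected dtype of intermediate results on the host instead of waiting for the XLA compile-time error.

```python
import torch

def check_dtype(t: torch.Tensor, expected: torch.dtype, name: str) -> torch.Tensor:
    # Fail fast with a readable message instead of an XLA INVALID_ARGUMENT at compile time.
    if t.dtype != expected:
        raise TypeError(f"{name}: expected {expected}, got {t.dtype}")
    return t

# Hypothetical usage inside the training step:
# adjusted_lens = check_dtype(adjusted_lens, torch.int32, "adjusted_lens")
```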
## Lessons Learned
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to float32, so the dtype must be specified explicitly (see the sketch below)
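A minimal sketch of the factory-function pitfall (assuming a bf16 compute dtype; the tensor shapes and names are illustrative):

```python
import torch

compute_dtype = torch.bfloat16  # the model's mixed-precision dtype

# Defaults to float32 and breaks bf16 graphs:
# eye = torch.eye(256)
# state = torch.zeros(1, 512)

# Explicit dtype keeps every tensor consistent with the model:
eye = torch.eye(256, dtype=compute_dtype)
state = torch.zeros(1, 512, dtype=compute_dtype)

# Or derive dtype/device from an existing parameter (param is hypothetical):
# eye = torch.eye(256, dtype=param.dtype, device=param.device)
```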