tpu
 CLAUDE.md | 16
@@ -166,8 +166,8 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp
 ```
 
 #### 5. Mixed Precision Dtype Consistency
-**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
-**Solution**: Ensure all operands match input tensor dtype
+**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
+**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations
 
 ```python
 # Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
@@ -176,6 +176,18 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
 
 # Fix 2: Ensure dtype consistency in adversarial training residual connections
 denoised_input = x_processed - noise_output.to(x_processed.dtype)
+
+# Fix 3: Preserve dtype through patch processing operations
+if self.patch_size > 0:
+    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
+    x = x.unsqueeze(1)
+    x = x.permute(0, 3, 1, 2)
+    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
+    x_unfold = x_unfold.squeeze(2)
+    x_unfold = x_unfold.permute(0, 2, 3, 1)
+    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+    # Ensure dtype consistency after patch processing operations
+    x = x.to(original_dtype)
 ```
 
 #### 3. Hidden State Initialization
Reference in New Issue
Block a user