## PyTorch XLA API Updates and Warnings

### Deprecated APIs (as of 2024)

**Important**: Several torch_xla APIs have been deprecated and should be updated in new code:

#### 1. Device API Changes
```python
# ❌ Deprecated (shows DeprecationWarning):
import torch_xla.core.xla_model as xm
device = xm.xla_device()

# ✅ Modern API:
import torch_xla
device = torch_xla.device()
```

#### 2. Synchronization API Changes
```python
# ❌ Deprecated (shows DeprecationWarning):
import torch_xla.core.xla_model as xm
xm.mark_step()

# ✅ Modern API:
import torch_xla
torch_xla.sync()
```
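
Both of these exist because torch_xla executes lazily: operations are recorded into a graph and only compiled and run at a synchronization point. A minimal sketch of that behavior (illustrative, assuming an available XLA device):

```python
import torch
import torch_xla

device = torch_xla.device()
t = torch.randn(2, 2, device=device)
t = t + 1          # recorded in the lazy graph, not executed yet
torch_xla.sync()   # compiles and runs the pending graph
print(t)           # safe to read back after the sync
```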

#### 3. Mixed Precision Environment Variables
```python
# ⚠️ Will be deprecated after PyTorch XLA 2.6:
import os
os.environ['XLA_USE_BF16'] = '1'

# 💡 Recommended: convert the model to bf16 directly in code
import torch
model = model.to(torch.bfloat16)
```
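
If the model is converted this way, its inputs must be cast to the same dtype, otherwise PyTorch raises a dtype-mismatch error. A self-contained sketch (the layer and tensor here are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2).to(torch.bfloat16)
inputs = torch.randn(4, 8).to(torch.bfloat16)  # must match the model's dtype
outputs = model(inputs)                        # runs in bf16 end to end
```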

### TPU Performance Warnings

#### Transparent Hugepages Warning
```
UserWarning: Transparent hugepages are not enabled. TPU runtime startup and
shutdown time should be significantly improved on TPU v5e and newer.
```

**Solution** (for TPU v5e and newer):
```bash
sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled"
```

**Note**: This warning appears in TPU environments and can be safely ignored if you don't have root access (e.g., on Kaggle or Colab).
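
If the warning is noisy in such managed environments, it can be silenced with Python's standard warnings filter. A minimal sketch, matching the message text quoted above:

```python
import warnings

# Suppress only this specific UserWarning; other warnings still surface.
warnings.filterwarnings("ignore", message="Transparent hugepages are not enabled")
```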

### Updated Code Patterns

#### Modern XLA Synchronization Pattern
```python
import torch_xla.core.xla_model as xm  # Still needed for other functions
import torch_xla

# Modern pattern:
def train_step(step, sync_frequency):
    # ... training code ...

    # Synchronize every N steps
    if step % sync_frequency == 0:
        torch_xla.sync()  # Instead of xm.mark_step()

# Legacy pattern (still works but deprecated):
def train_step_legacy(step, sync_frequency):
    # ... training code ...

    # Old way (shows deprecation warning)
    if step % sync_frequency == 0:
        xm.mark_step()
        xm.wait_device_ops()  # This is still current
```
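
For context, a hypothetical driver loop for the modern pattern (`num_steps` and `sync_frequency` are illustrative values, not part of any API):

```python
num_steps = 100
sync_frequency = 10

for step in range(num_steps):
    train_step(step, sync_frequency)

xm.wait_device_ops()  # block until all queued device work has finished
```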

#### Device Detection Pattern
```python
import torch

# Modern approach:
try:
    import torch_xla  # import inside try so a missing install also falls back
    device = torch_xla.device()
    print(f"Using XLA device: {device}")
except Exception:
    device = torch.device('cpu')
    print("Falling back to CPU")

# Legacy approach (shows warnings):
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()  # DeprecationWarning
    print(f"Using XLA device: {device}")
except Exception:
    device = torch.device('cpu')
    print("Falling back to CPU")
```

### Migration Guidelines

When updating existing code (a combined before/after sketch follows the list):

1. **Replace `xm.xla_device()`** with `torch_xla.device()`
2. **Replace `xm.mark_step()`** with `torch_xla.sync()`
3. **Keep `xm.wait_device_ops()`** (still current API)
4. **Update imports** to include `torch_xla` directly
5. **Consider explicit bf16 conversion** instead of environment variables
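
Taken together, a minimal before/after sketch of these steps (the tensors are illustrative; the legacy calls are shown as comments):

```python
import torch
import torch_xla

# Before: device = xm.xla_device()   # DeprecationWarning
device = torch_xla.device()

x = torch.randn(4, 4, device=device)
y = (x @ x).sum()

# Before: xm.mark_step()             # DeprecationWarning
torch_xla.sync()
```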

### Backward Compatibility

The deprecated APIs still work but generate warnings. For production code:
- Update to modern APIs to avoid warnings
- Test thoroughly, as synchronization behavior may differ slightly
- Legacy code will continue to function until removed in future versions (a compatibility shim is sketched below)
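
For code that must run on both older and newer torch_xla releases, a minimal compatibility shim is one option (an assumption-level sketch, not an official API):

```python
import torch_xla

# Prefer the modern entry points; fall back to the legacy ones if absent.
if hasattr(torch_xla, "sync"):
    xla_sync, xla_device = torch_xla.sync, torch_xla.device
else:
    import torch_xla.core.xla_model as xm
    xla_sync, xla_device = xm.mark_step, xm.xla_device
```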

## Competition Context
This codebase also serves as a baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.