fixed : tf call cuda
@@ -449,5 +449,145 @@ The deprecated APIs still work but generate warnings. For production code:
- Test thoroughly as synchronization behavior may differ slightly
- Legacy code will continue to function until removed in future versions

## TensorFlow TPU Implementation

The original PyTorch implementation has been converted to TensorFlow for optimal performance on TPU v5e-8 environments, particularly for the Brain-to-Text '25 Competition on Kaggle.

### Key TensorFlow Components (`model_training_nnn_tpu/`)

#### Core Files
- **`rnn_model_tf.py`**: TensorFlow implementation of TripleGRUDecoder architecture
  - `NoiseModel`: 2-layer GRU for noise estimation with day-specific layers
  - `CleanSpeechModel`: 3-layer GRU for clean speech recognition with day-specific layers
  - `NoisySpeechModel`: 2-layer GRU for noisy speech recognition (no day layers)
  - `TripleGRUDecoder`: Main adversarial architecture combining all three models (a hedged wiring sketch follows this list)
  - `CTCLoss`: Custom CTC loss implementation for TPU compatibility
  - `create_tpu_strategy()`: Enhanced TPU connection function with robust environment detection

- **`trainer_tf.py`**: TensorFlow training pipeline with distributed TPU support
- **`dataset_tf.py`**: TensorFlow data loading with augmentation pipeline optimized for TPU
- **`train_model_tf.py`**: Main training script entry point
- **`evaluate_model_tf.py`**: Evaluation pipeline for model performance analysis
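
To make the adversarial combination concrete, here is a minimal sketch of how three GRU branches *could* be wired together: the noise estimate is subtracted out for the clean branch and passed through gradient reversal for the adversarial branch. The layer counts, dimensions, and helper names (`_reverse_grad`, `triple_decode`) are illustrative assumptions; the exact data flow of `TripleGRUDecoder` lives in `rnn_model_tf.py`.

```python
import tensorflow as tf

@tf.custom_gradient
def _reverse_grad(x):
    # Minimal gradient-reversal helper, used only for this illustration
    return tf.identity(x), lambda dy: -dy

feature_dim, n_classes = 512, 41  # placeholder sizes
noise_gru = tf.keras.layers.GRU(feature_dim, return_sequences=True)
clean_gru = tf.keras.layers.GRU(256, return_sequences=True)
noisy_gru = tf.keras.layers.GRU(256, return_sequences=True)
clean_head = tf.keras.layers.Dense(n_classes)
noisy_head = tf.keras.layers.Dense(n_classes)

def triple_decode(x):
    noise = noise_gru(x)                                        # noise estimate, same width as x
    clean_logits = clean_head(clean_gru(x - noise))             # denoised clean-speech branch
    noisy_logits = noisy_head(noisy_gru(_reverse_grad(noise)))  # adversarial noisy branch
    return clean_logits, noisy_logits

x = tf.random.normal([2, 50, feature_dim])
clean_logits, noisy_logits = triple_decode(x)
print(clean_logits.shape, noisy_logits.shape)  # (2, 50, 41) (2, 50, 41)
```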
### TPU v5e-8 Specific Optimizations

#### 1. Enhanced TPU Connection
The `create_tpu_strategy()` function provides robust TPU detection across different environments:

```python
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    # Multi-environment TPU detection
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address

    # Enhanced error handling and debugging output
    # Fallback to default strategy if TPU connection fails
```
**Environment Variables Detected**:
- `COLAB_TPU_ADDR`: Google Colab TPU environment
- `TPU_NAME`: Generic TPU name specification
- `TPU_WORKER_ID`: Kaggle TPU environment indicator

**Troubleshooting TPU Connection Issues**:
- Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
- Solution: The function automatically detects and uses an appropriate TPU address or name based on the environment
- Debugging: All TPU-related environment variables are printed during initialization
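
For reference, here is a hedged sketch of how a detected address or name can be turned into a distribution strategy with a safe fallback. The helper name `build_tpu_strategy` is hypothetical; the actual function in `rnn_model_tf.py` adds GPU-disabling and more detailed logging (see the second hunk at the bottom of this diff).

```python
import tensorflow as tf

def build_tpu_strategy(tpu_target=''):
    """tpu_target: an address like 'grpc://10.0.0.2:8470', a TPU name, or '' for auto-detection."""
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_target)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
        print(f"✅ TPU initialized, replicas: {strategy.num_replicas_in_sync}")
        return strategy
    except Exception as e:  # fall back to the default (CPU/GPU) strategy
        print(f"⚠️ TPU connection failed ({e}); using default strategy")
        return tf.distribute.get_strategy()
```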
#### 2. Mixed Precision Training
Configured for optimal TPU v5e-8 performance:
```python
from tensorflow import keras

def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
```
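
Under this policy, layer computations run in bfloat16 while variables stay in float32. A common companion step, sketched below with placeholder layer sizes, is to force the final logits back to float32 so downstream losses (e.g. CTC) do not hit dtype mismatches; this is also the usual remedy for the "CTC Loss dtype errors" issue listed under Common Issues and Solutions.

```python
from tensorflow import keras

keras.mixed_precision.set_global_policy('mixed_bfloat16')

inputs = keras.Input(shape=(None, 512))                    # [time, features]; sizes are placeholders
x = keras.layers.GRU(256, return_sequences=True)(inputs)   # runs in bfloat16 under the policy
logits = keras.layers.Dense(41, dtype='float32')(x)        # keep the loss inputs in full precision
model = keras.Model(inputs, logits)
print(model.layers[-1].compute_dtype)  # float32
```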
#### 3. XLA-Optimized Operations
- **Static Tensor Operations**: Using `tf.stack()` and `tf.gather()` instead of dynamic indexing
- **Efficient Matrix Operations**: `tf.linalg.matmul()` for batch matrix multiplication
- **TPU-Friendly GRU Layers**: Recurrent dropout disabled for better TPU performance
- **Patch Processing**: TensorFlow equivalent of PyTorch's unfold using `tf.image.extract_patches()` (see the sketch after this list)
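
As a concrete illustration of the last point, here is a small sketch of temporal patch extraction that mirrors PyTorch's `unfold` along the time axis. The function name and sizes are illustrative only, not taken from `dataset_tf.py`.

```python
import tensorflow as tf

def extract_temporal_patches(x, patch_size, stride):
    """x: [batch, time, features] -> [batch, n_patches, patch_size * features]."""
    # tf.image.extract_patches expects a 4-D image tensor, so add a dummy "width" axis
    x4 = tf.expand_dims(x, axis=2)                 # [batch, time, 1, features]
    patches = tf.image.extract_patches(
        images=x4,
        sizes=[1, patch_size, 1, 1],
        strides=[1, stride, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID',
    )                                              # [batch, n_patches, 1, patch_size * features]
    return tf.squeeze(patches, axis=2)

x = tf.random.normal([4, 100, 512])
print(extract_temporal_patches(x, patch_size=14, stride=4).shape)  # (4, 22, 7168)
```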
### Key Architecture Differences from PyTorch

#### 1. Gradient Reversal Layer (GRL)
```python
import tensorflow as tf

@tf.custom_gradient
def gradient_reverse(x, lambd=1.0):
    """Gradient Reversal Layer for TensorFlow"""
    def grad(dy):
        return -lambd * dy  # Only return gradient w.r.t. x
    return tf.identity(x), grad
```
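
A quick eager-mode check of the behavior, continuing from the definition above (illustrative usage, not taken from the repo): the forward pass is an identity, while the backward pass flips the gradient's sign.

```python
x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(gradient_reverse(x))
print(tape.gradient(y, x))  # [-1. -1. -1.]
```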
#### 2. CTC Loss Implementation
Custom sparse tensor conversion for TPU compatibility:
```python
def dense_to_sparse(dense_tensor, sequence_lengths):
    """Convert dense tensor to sparse tensor for CTC"""
    mask = tf.not_equal(dense_tensor, 0)
    indices = tf.where(mask)
    values = tf.gather_nd(dense_tensor, indices)
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
```
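
A hedged usage sketch with toy shapes: the variable names and `blank_index=0` (class 0 reserved for the CTC blank, with 0 also used as label padding) are assumptions for illustration, not taken from the `CTCLoss` implementation.

```python
import tensorflow as tf

batch, max_time, n_classes = 2, 20, 41
logits = tf.random.normal([max_time, batch, n_classes])           # time-major logits
logit_lengths = tf.constant([20, 18], dtype=tf.int32)
padded_labels = tf.constant([[5, 9, 3, 0, 0], [7, 2, 0, 0, 0]])   # 0 = padding

sparse_labels = dense_to_sparse(padded_labels, None)              # lengths argument unused above
loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,
    label_length=None,            # implicit in the SparseTensor
    logit_length=logit_lengths,
    logits_time_major=True,
    blank_index=0,                # assumption: class 0 is the blank
)
print(tf.reduce_mean(loss))
```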
#### 3. Day-Specific Layers
Using `add_weight()` for TPU-compatible variable management:
```python
for i in range(n_days):
    weight = self.add_weight(
        name=f'day_weight_{i}',
        shape=(neural_dim, neural_dim),
        initializer=tf.keras.initializers.Identity(),
        trainable=True
    )
```
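
Downstream, the per-day transform can be applied with static ops (`tf.stack`, `tf.gather`, `tf.linalg.matmul`) so XLA compiles a single graph for all days. A hedged sketch with assumed names and shapes, using plain variables instead of `add_weight()` for brevity:

```python
import tensorflow as tf

n_days, neural_dim = 4, 512  # placeholder sizes
day_weights = [tf.Variable(tf.eye(neural_dim), name=f'day_weight_{i}') for i in range(n_days)]

def apply_day_layer(x, day_idx):
    """x: [batch, time, neural_dim]; day_idx: [batch] integer day index per trial."""
    all_w = tf.stack(day_weights, axis=0)   # [n_days, neural_dim, neural_dim]
    w = tf.gather(all_w, day_idx)           # [batch, neural_dim, neural_dim]
    return tf.linalg.matmul(x, w)           # batched matmul, one transform per trial

x = tf.random.normal([2, 10, neural_dim])
print(apply_day_layer(x, tf.constant([0, 3])).shape)  # (2, 10, 512)
```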
### Training on TPU v5e-8

#### Basic Training Command
```bash
# In a Kaggle TPU v5e-8 environment
python train_model_tf.py
```
#### Expected Output
```
🔍 Detecting TPU environment...
📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
✅ TPU initialized successfully!
🎉 Number of TPU cores: 8
Training on 8 TPU cores  # Should show 8 cores, not 1
```
### Performance Benefits

1. **Multi-Core Utilization**: A properly configured TPU strategy utilizes all 8 TPU v5e-8 cores
2. **Mixed Precision**: bfloat16 precision optimized for TPU matrix units
3. **XLA Compilation**: Static operations enable efficient XLA graph compilation
4. **Memory Efficiency**: Optimized for TPU memory constraints and batch processing

### Common Issues and Solutions

#### Issue: "Training on 1 TPU cores" instead of 8
**Cause**: TPU connection falls back to the default strategy
**Solution**: Enhanced `create_tpu_strategy()` function with environment detection
**Check**: Verify that the TPU environment variables are properly set (a quick check follows)
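
A minimal check, assuming it is run inside the target notebook and that the training code is importable (the import path is an assumption):

```python
import os

for var in ('COLAB_TPU_ADDR', 'TPU_NAME', 'TPU_WORKER_ID'):
    print(f"{var} = {os.environ.get(var)}")

from rnn_model_tf import create_tpu_strategy  # assumed import path
strategy = create_tpu_strategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)  # expect 8 on TPU v5e-8
```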
#### Issue: CTC Loss dtype errors
**Cause**: Mixed precision dtype mismatches
**Solution**: Explicit dtype casting in `CTCLoss.call()`

#### Issue: Gradient Reversal Layer errors
**Cause**: Incorrect gradient return format
**Solution**: Return only gradient w.r.t. input tensor, not lambda parameter
## Competition Context
This codebase serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.
@@ -763,6 +763,14 @@ def create_tpu_strategy():

    print("🔍 Detecting TPU environment...")

    # Disable GPU to avoid CUDA conflicts in TPU environment
    try:
        print("🚫 Disabling GPU to prevent CUDA conflicts...")
        tf.config.set_visible_devices([], 'GPU')
        print("✅ GPU disabled successfully")
    except Exception as e:
        print(f"⚠️ Warning: Could not disable GPU: {e}")

    # Check for various TPU environment variables
    tpu_address = None
    tpu_name = None