fixed: tf call cuda

Zchen
2025-10-15 23:37:24 +08:00
parent 01024678c1
commit f9d3f47d20
2 changed files with 149 additions and 1 deletions

CLAUDE.md

@@ -449,5 +449,145 @@ The deprecated APIs still work but generate warnings. For production code:
- Test thoroughly as synchronization behavior may differ slightly
- Legacy code will continue to function until removed in future versions
## TensorFlow TPU Implementation
The original PyTorch implementation has been converted to TensorFlow for optimal performance in TPU v5e-8 environments, particularly for the Brain-to-Text '25 Competition on Kaggle.
### Key TensorFlow Components (`model_training_nnn_tpu/`)
#### Core Files
- **`rnn_model_tf.py`**: TensorFlow implementation of TripleGRUDecoder architecture
- `NoiseModel`: 2-layer GRU for noise estimation with day-specific layers
- `CleanSpeechModel`: 3-layer GRU for clean speech recognition with day-specific layers
- `NoisySpeechModel`: 2-layer GRU for noisy speech recognition (no day layers)
- `TripleGRUDecoder`: Main adversarial architecture combining all three models
- `CTCLoss`: Custom CTC loss implementation for TPU compatibility
- `create_tpu_strategy()`: Enhanced TPU connection function with robust environment detection
- **`trainer_tf.py`**: TensorFlow training pipeline with distributed TPU support
- **`dataset_tf.py`**: TensorFlow data loading with augmentation pipeline optimized for TPU
- **`train_model_tf.py`**: Main training script entry point
- **`evaluate_model_tf.py`**: Evaluation pipeline for model performance analysis
### TPU v5e-8 Specific Optimizations
#### 1. Enhanced TPU Connection
The `create_tpu_strategy()` function provides robust TPU detection across different environments:
```python
import os
import tensorflow as tf

def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    # Multi-environment TPU detection
    tpu_address = None
    tpu_name = None
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
    elif 'TPU_WORKER_ID' in os.environ:
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_name or tpu_address)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except Exception as e:
        # Fall back to the default strategy if TPU connection fails
        print(f"Failed to initialize TPU: {e}")
        return tf.distribute.get_strategy()
```
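Typical usage (a sketch; the `TripleGRUDecoder` constructor arguments are elided):
```python
strategy = create_tpu_strategy()
print(f"Training on {strategy.num_replicas_in_sync} TPU cores")  # expect 8 on v5e-8
with strategy.scope():
    # Variables created inside the scope are replicated across all TPU cores
    model = TripleGRUDecoder(...)
```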
**Environment Variables Detected**:
- `COLAB_TPU_ADDR`: Google Colab TPU environment
- `TPU_NAME`: Generic TPU name specification
- `TPU_WORKER_ID`: Kaggle TPU environment indicator
**Troubleshooting TPU Connection Issues**:
- Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
- Solution: The function automatically detects and uses appropriate TPU addresses based on environment
- Debugging: All TPU-related environment variables are printed during initialization (see the sketch below)
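For manual debugging, the same variables can be dumped directly (a minimal sketch):
```python
import os

for key, value in sorted(os.environ.items()):
    if 'TPU' in key:
        print(f"{key}={value}")  # e.g. TPU_NAME, TPU_WORKER_ID, COLAB_TPU_ADDR
```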
#### 2. Mixed Precision Training
Configured for optimal TPU v5e-8 performance:
```python
from tensorflow import keras

def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
```
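A related convention: under a `mixed_bfloat16` policy, it is standard Keras practice to keep the final projection layer in float32 so the loss sees numerically stable logits (a sketch; `n_classes` is illustrative):
```python
from tensorflow import keras

n_classes = 41  # illustrative output size
# Override the global mixed_bfloat16 policy for the final layer only,
# so the downstream loss receives float32 logits
output_layer = keras.layers.Dense(n_classes, dtype='float32')
```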
#### 3. XLA-Optimized Operations
- **Static Tensor Operations**: Using `tf.stack()` and `tf.gather()` instead of dynamic indexing
- **Efficient Matrix Operations**: `tf.linalg.matmul()` for batch matrix multiplication
- **TPU-Friendly GRU Layers**: Disabled recurrent dropout for better TPU performance
- **Patch Processing**: TensorFlow equivalent of PyTorch's unfold using `tf.image.extract_patches()` (see the sketch below)
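A minimal sketch of that unfold equivalent, assuming input of shape `[batch, time, channels]` and illustrative `patch_size`/`patch_stride` parameters:
```python
import tensorflow as tf

def unfold_time(x, patch_size, patch_stride):
    """TF analogue of PyTorch unfold along the time axis:
    [batch, time, channels] -> [batch, num_patches, patch_size * channels]"""
    # extract_patches expects a 4D image tensor, so add a dummy spatial axis
    x4d = tf.expand_dims(x, axis=2)  # [batch, time, 1, channels]
    patches = tf.image.extract_patches(
        images=x4d,
        sizes=[1, patch_size, 1, 1],
        strides=[1, patch_stride, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID',
    )  # [batch, num_patches, 1, patch_size * channels]
    return tf.squeeze(patches, axis=2)
```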
### Key Architecture Differences from PyTorch
#### 1. Gradient Reversal Layer (GRL)
```python
@tf.custom_gradient
def gradient_reverse(x, lambd=1.0):
"""Gradient Reversal Layer for TensorFlow"""
def grad(dy):
return -lambd * dy # Only return gradient w.r.t. x
return tf.identity(x), grad
```
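A quick behavioral check (the forward pass is the identity; the backward pass is negated and scaled by `lambd`):
```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(gradient_reverse(x, lambd=0.5))
print(tape.gradient(y, x).numpy())  # [[-0.5 -0.5]]
```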
#### 2. CTC Loss Implementation
Custom sparse tensor conversion for TPU compatibility:
```python
def dense_to_sparse(dense_tensor, sequence_lengths):
    """Convert dense tensor to sparse tensor for CTC"""
    mask = tf.not_equal(dense_tensor, 0)  # 0 is the padding id
    indices = tf.where(mask)
    values = tf.gather_nd(dense_tensor, indices)
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)  # full [batch, max_len] shape
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
```
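A sketch of how the sparse labels might then feed `tf.nn.ctc_loss` (tensor names are illustrative; with `SparseTensor` labels, `label_length` must be `None`):
```python
sparse_labels = dense_to_sparse(labels, label_lengths)
loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,               # [batch, time, n_classes]
    label_length=None,           # label lengths come from the sparse structure
    logit_length=logit_lengths,  # valid time steps per batch element
    logits_time_major=False,
    blank_index=0,
)
loss = tf.reduce_mean(loss)
```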
#### 3. Day-Specific Layers
Using `add_weight()` for TPU-compatible variable management:
```python
self.day_weights = []
for i in range(n_days):
    # Collect each day's weight so it can be looked up by day index later
    self.day_weights.append(self.add_weight(
        name=f'day_weight_{i}',
        shape=(neural_dim, neural_dim),
        initializer=tf.keras.initializers.Identity(),
        trainable=True,
    ))
```
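A sketch of how such a per-day transform might be applied, tying back to the `tf.stack()`/`tf.gather()` pattern above (`day_idx` is an illustrative `[batch]` tensor of day indices):
```python
day_w = tf.stack(self.day_weights)  # [n_days, neural_dim, neural_dim]
w = tf.gather(day_w, day_idx)       # [batch, neural_dim, neural_dim]
x = tf.linalg.matmul(x, w)          # [batch, time, neural_dim] per-day projection
```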
### Training on TPU v5e-8
#### Basic Training Command
```bash
# In Kaggle TPU v5e-8 environment
python train_model_tf.py
```
#### Expected Output
```
🔍 Detecting TPU environment...
📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
✅ TPU initialized successfully!
🎉 Number of TPU cores: 8
Training on 8 TPU cores # Should show 8 cores, not 1
```
### Performance Benefits
1. **Multi-Core Utilization**: Properly configured TPU strategy utilizes all 8 TPU v5e-8 cores
2. **Mixed Precision**: bfloat16 precision optimized for TPU matrix units
3. **XLA Compilation**: Static operations enable efficient XLA graph compilation
4. **Memory Efficiency**: Optimized for TPU memory constraints and batch processing
### Common Issues and Solutions
#### Issue: "Training on 1 TPU cores" instead of 8
**Cause**: TPU connection fallback to default strategy
**Solution**: Enhanced `create_tpu_strategy()` function with environment detection
**Check**: Verify TPU environment variables are properly set
#### Issue: CTC Loss dtype errors
**Cause**: Mixed precision dtype mismatches
**Solution**: Explicit dtype casting in `CTCLoss.call()`
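A minimal sketch of that cast (assuming the mixed precision policy leaves model outputs in bfloat16):
```python
# Inside CTCLoss.call(): bring tensors to the dtypes CTC expects
logits = tf.cast(logits, tf.float32)  # CTC kernels are not stable in bfloat16
labels = tf.cast(labels, tf.int32)    # labels must be integer ids
```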
#### Issue: Gradient Reversal Layer errors
**Cause**: Incorrect gradient return format
**Solution**: Return only gradient w.r.t. input tensor, not lambda parameter
## Competition Context
This codebase serves as a baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.

rnn_model_tf.py

@@ -763,6 +763,14 @@ def create_tpu_strategy():
print("🔍 Detecting TPU environment...")
# Disable GPU to avoid CUDA conflicts in TPU environment
try:
print("🚫 Disabling GPU to prevent CUDA conflicts...")
tf.config.set_visible_devices([], 'GPU')
print("✅ GPU disabled successfully")
except Exception as e:
print(f"⚠️ Warning: Could not disable GPU: {e}")
# Check for various TPU environment variables
tpu_address = None
tpu_name = None