27 lines
800 B
Bash
27 lines
800 B
Bash
#!/bin/bash
#
# TPU XLA multi-threading environment setup.
#
# Exports the XLA-related environment variables listed below and then starts
# train_model.py. The variables must be in the environment BEFORE Python
# starts, which is why they are set here rather than inside the training code.
#
# Exports:     XLA_FLAGS, PYTORCH_XLA_COMPILATION_THREADS, XLA_USE_BF16,
#              TPU_CORES
# Exit status: that of the final python command (it is the last command run).

# Fail loudly on a mistyped variable name instead of expanding it to "".
set -u

#######################################
# Print the number of available CPU cores.
# Tries `nproc` first, then `getconf _NPROCESSORS_ONLN`; if neither yields a
# positive integer, warns on stderr and prints 1 so that the flags built from
# this value are never empty or malformed.
# Arguments: none
# Outputs:   the core count on stdout
# Returns:   0 always
#######################################
detect_cpu_cores() {
  local count

  count=$(nproc 2>/dev/null) || count=""
  if [[ ! "$count" =~ ^[1-9][0-9]*$ ]]; then
    count=$(getconf _NPROCESSORS_ONLN 2>/dev/null) || count=""
  fi
  if [[ ! "$count" =~ ^[1-9][0-9]*$ ]]; then
    echo "Warning: could not detect CPU core count; falling back to 1" >&2
    count=1
  fi

  printf '%s\n' "$count"
}

echo "Setting up XLA multi-threading environment..."

# Get CPU core count
CPU_CORES=$(detect_cpu_cores)
echo "Detected $CPU_CORES CPU cores"

# Set XLA compilation flags
# NOTE: this replaces any XLA_FLAGS value already present in the environment.
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
# NOTE(review): presumably read by torch_xla or by train_model.py - confirm
# that the installed version actually honours this variable.
export PYTORCH_XLA_COMPILATION_THREADS="$CPU_CORES"

# Additional XLA optimizations
# NOTE(review): newer torch_xla releases are believed to deprecate
# XLA_USE_BF16 - confirm against the installed version.
export XLA_USE_BF16=1
# NOTE(review): hard-coded; presumably matches the TPU topology in use.
export TPU_CORES=8

# Print current settings
echo "XLA_FLAGS: $XLA_FLAGS"
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
echo "XLA_USE_BF16: $XLA_USE_BF16"

# Start training
# Both paths are relative, so the script must be run from the directory that
# contains train_model.py and rnn_args.yaml.
echo "Starting TPU training..."
python train_model.py --config_path rnn_args.yaml