Files
b2txt25/model_training_nnn_tpu/start_tpu_training.sh
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

27 lines
800 B
Bash

#!/bin/bash
# Launch train_model.py with config rnn_args.yaml, exporting XLA-related
# environment variables first.
#
# The variables are exported BEFORE the Python process starts so that they
# are already present when XLA / torch_xla initialize.
#
# Usage: ./start_tpu_training.sh [extra train_model.py args...]
#   Must be run from the directory containing train_model.py (it and
#   rnn_args.yaml are referenced by relative path). Any extra arguments are
#   forwarded to train_model.py.
#
# Optional environment overrides:
#   XLA_FLAGS     caller-supplied flags are kept and appended after the
#                 defaults below (previously they were discarded).
#   TPU_CORES     default 8
#   XLA_USE_BF16  default 1
#   PYTHON        interpreter to run, default "python"
set -euo pipefail

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

echo "Setting up XLA multi-threading environment..."

# nproc reports the processing units available to this process, which can be
# fewer than the machine has installed.
CPU_CORES=$(nproc) || die "nproc failed"
echo "Detected $CPU_CORES CPU cores"

# Default XLA compilation flags.
# NOTE(review): going by their names, all three flags target XLA's CPU (host)
# backend rather than the TPU, and --xla_force_host_platform_device_count sets
# how many logical host devices XLA exposes, not a thread count. Confirm they
# are actually wanted for a TPU run.
default_xla_flags=(
  --xla_cpu_multi_thread_eigen=true
  --xla_cpu_enable_fast_math=true
  "--xla_force_host_platform_device_count=$CPU_CORES"
)
# "${arr[*]}" joins with single spaces (default IFS). Flags already exported
# by the caller are appended instead of being silently overwritten.
# NOTE(review): this assumes XLA lets the later occurrence of a duplicated
# flag win — confirm if a caller needs to override a default.
export XLA_FLAGS="${default_xla_flags[*]}${XLA_FLAGS:+ $XLA_FLAGS}"

# NOTE(review): confirm the installed torch_xla actually reads this variable.
export PYTORCH_XLA_COMPILATION_THREADS="$CPU_CORES"

# Additional XLA settings; values already present in the environment win.
# NOTE(review): XLA_USE_BF16 may be deprecated in newer torch_xla releases —
# check against the installed version.
export XLA_USE_BF16="${XLA_USE_BF16:-1}"
export TPU_CORES="${TPU_CORES:-8}"

# Print current settings
echo "XLA_FLAGS: $XLA_FLAGS"
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
echo "XLA_USE_BF16: $XLA_USE_BF16"
echo "TPU_CORES: $TPU_CORES"

# train_model.py is resolved relative to the current directory; fail with a
# clear message instead of a Python "can't open file" error.
[[ -f train_model.py ]] ||
  die "train_model.py not found in $PWD; run this script from its directory"

echo "Starting TPU training..."
# exec replaces this shell with Python. Signals sent to this script's PID
# (e.g. SIGTERM from a job scheduler) then reach the trainer directly, and
# Python's exit status becomes the script's.
exec "${PYTHON:-python}" train_model.py --config_path rnn_args.yaml "$@"