#!/bin/bash
#
# TPU XLA multi-threading environment setup.
#
# Exports the XLA / PyTorch-XLA environment variables below and then launches
# train_model.py with rnn_args.yaml. The variables are exported BEFORE Python
# starts so they are already in the environment of the interpreter process.
#
# Exported variables:
#   XLA_FLAGS                        XLA flags, including the host device count
#   PYTORCH_XLA_COMPILATION_THREADS  set to the detected CPU core count
#   XLA_USE_BF16                     set to 1
#   TPU_CORES                        set to 8
#
# Exit status: the exit status of the training command.

# Treat unset variables and failed pipeline stages as errors. `set -e` is left
# off on purpose: failures are handled explicitly below, and the script's exit
# status is that of the training command, which is the last command run.
set -uo pipefail

#######################################
# Print the number of available CPU cores.
# Tries `nproc`, then `getconf _NPROCESSORS_ONLN`, and falls back to 1 so the
# flags built from the result never contain an empty or non-numeric count.
# Outputs: the core count (a positive integer) on stdout
#######################################
detect_cpu_cores() {
  local cores
  cores=$(nproc 2>/dev/null) \
    || cores=$(getconf _NPROCESSORS_ONLN 2>/dev/null) \
    || cores=1
  if [[ ! "$cores" =~ ^[1-9][0-9]*$ ]]; then
    cores=1
  fi
  printf '%s\n' "$cores"
}

#######################################
# Export the XLA-related environment variables.
# Arguments: $1 - CPU core count used for the device count and thread count
# Globals:   XLA_FLAGS, PYTORCH_XLA_COMPILATION_THREADS, XLA_USE_BF16,
#            TPU_CORES (all written and exported)
#######################################
configure_xla_env() {
  local cores=$1

  # XLA compilation flags.
  export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=${cores}"
  export PYTORCH_XLA_COMPILATION_THREADS="$cores"

  # Additional XLA settings.
  # NOTE(review): XLA_USE_BF16 may be deprecated in newer PyTorch/XLA
  # releases - confirm against the installed version.
  export XLA_USE_BF16=1
  export TPU_CORES=8
}

#######################################
# Print the settings that were exported by configure_xla_env.
# Globals: XLA_FLAGS, PYTORCH_XLA_COMPILATION_THREADS, XLA_USE_BF16,
#          TPU_CORES (read)
#######################################
print_settings() {
  echo "XLA_FLAGS: $XLA_FLAGS"
  echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
  echo "XLA_USE_BF16: $XLA_USE_BF16"
  echo "TPU_CORES: $TPU_CORES"
}

#######################################
# Launch the training script.
# Returns: 127 if `python` is not on PATH, otherwise python's exit status
#######################################
run_training() {
  if ! command -v python >/dev/null 2>&1; then
    echo "error: 'python' not found in PATH" >&2
    return 127
  fi
  python train_model.py --config_path rnn_args.yaml
}

#######################################
# Entry point: detect cores, export the environment, report it, and train.
# Globals: CPU_CORES (written)
# Returns: the exit status of the training command
#######################################
main() {
  echo "Setting up XLA multi-threading environment..."

  CPU_CORES=$(detect_cpu_cores)
  echo "Detected $CPU_CORES CPU cores"

  configure_xla_env "$CPU_CORES"
  print_settings

  echo "Starting TPU training..."
  run_training
}

main