HBM
This commit is contained in:
		| @@ -81,12 +81,24 @@ class BrainToTextDecoderTrainerTF: | ||||
|         # Initialize datasets | ||||
|         self._initialize_datasets() | ||||
|  | ||||
|  | ||||
|         # Build model within strategy scope | ||||
|         with self.strategy.scope(): | ||||
|             print("🔨 Building model within TPU strategy scope...") | ||||
|             self.model = self._build_model() | ||||
|             print("✅ Model built successfully") | ||||
|  | ||||
|             print("⚙️ Creating optimizer...") | ||||
|             self.optimizer = self._create_optimizer() | ||||
|             print("✅ Optimizer created") | ||||
|  | ||||
|             print("📅 Setting up learning rate scheduler...") | ||||
|             self.lr_scheduler = self._create_lr_scheduler() | ||||
|             print("✅ LR scheduler ready") | ||||
|  | ||||
|             print("🎯 Initializing CTC loss...") | ||||
|             self.ctc_loss = CTCLoss(blank_index=0, reduction='none') | ||||
|             print("✅ CTC loss initialized") | ||||
|  | ||||
|         # Log model information | ||||
|         self._log_model_info() | ||||
| @@ -154,7 +166,7 @@ class BrainToTextDecoderTrainerTF: | ||||
|         print(f"   This will accelerate data loading and preprocessing while TPU handles training") | ||||
|  | ||||
|     def _get_tpu_status(self) -> str: | ||||
|         """Get current TPU status and utilization info""" | ||||
|         """Get current TPU status and HBM utilization info""" | ||||
|         try: | ||||
|             # Get TPU devices | ||||
|             tpu_devices = tf.config.list_logical_devices('TPU') | ||||
| @@ -165,12 +177,29 @@ class BrainToTextDecoderTrainerTF: | ||||
|             # Get strategy info | ||||
|             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 | ||||
|  | ||||
|             # Get memory usage (simplified) | ||||
|             import psutil | ||||
|             memory = psutil.virtual_memory() | ||||
|             # Try to get TPU memory info (HBM) | ||||
|             try: | ||||
|                 # Attempt to get TPU memory usage for each device | ||||
|                 memory_info = tf.config.experimental.get_memory_info('/TPU:0') | ||||
|                 if memory_info and 'current' in memory_info: | ||||
|                     current_mb = memory_info['current'] // (1024 * 1024) | ||||
|                     peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) | ||||
|                     hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)" | ||||
|                 else: | ||||
|                     hbm_info = "HBM: unknown" | ||||
|             except Exception: | ||||
|                 # Fallback: simple TPU activity check | ||||
|                 try: | ||||
|                     # Test TPU responsiveness | ||||
|                     with tf.device('/TPU:0'): | ||||
|                         test_tensor = tf.constant([1.0, 2.0]) | ||||
|                         _ = tf.reduce_sum(test_tensor) | ||||
|                     hbm_info = "HBM: active" | ||||
|                 except Exception: | ||||
|                     hbm_info = "HBM: inactive" | ||||
|  | ||||
|             return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores " | ||||
|                    f"RAM: {memory.percent:.1f}%") | ||||
|                    f"{hbm_info}") | ||||
|  | ||||
|         except Exception as e: | ||||
|             return f"TPU: status_error({str(e)[:20]})" | ||||
| @@ -188,9 +217,19 @@ class BrainToTextDecoderTrainerTF: | ||||
|             num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 | ||||
|             strategy_type = type(self.strategy).__name__ | ||||
|  | ||||
|             # Get memory info | ||||
|             import psutil | ||||
|             memory = psutil.virtual_memory() | ||||
|             # Get TPU HBM memory info | ||||
|             try: | ||||
|                 memory_info = tf.config.experimental.get_memory_info('/TPU:0') | ||||
|                 if memory_info and 'current' in memory_info: | ||||
|                     current_gb = memory_info['current'] // (1024 * 1024 * 1024) | ||||
|                     peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024) | ||||
|                     # TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB | ||||
|                     estimated_total_gb = 32 * len(tpu_devices) | ||||
|                     hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)" | ||||
|                 else: | ||||
|                     hbm_usage = "HBM: unknown" | ||||
|             except Exception: | ||||
|                 hbm_usage = "HBM: unavailable" | ||||
|  | ||||
|             # Simple TPU test | ||||
|             try: | ||||
| @@ -204,7 +243,7 @@ class BrainToTextDecoderTrainerTF: | ||||
|             return (f"TPU Devices: {len(tpu_devices)} | " | ||||
|                    f"Strategy: {strategy_type} | " | ||||
|                    f"Cores: {num_replicas} | " | ||||
|                    f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | " | ||||
|                    f"{hbm_usage} | " | ||||
|                    f"Test: {tpu_test}") | ||||
|  | ||||
|         except Exception as e: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Zchen
					Zchen