diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index b5f92b1..cc6f7ba 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -77,8 +77,7 @@
 from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn,
-    analyze_dataset_shapes
+    create_input_fn
 )
 
@@ -728,63 +727,36 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
-        # ========================= DATASET SHAPE ANALYSIS =========================
-        # Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
-        self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
-        self.logger.info("   This is CRITICAL for resolving both CTC compatibility and graph structure issues")
+        # ========================= FINAL SOLUTION: BATCH-FIRST =========================
+        # Use the validated "batch-first, augment-after" approach to eliminate the time paradox between data augmentation and shape analysis
+        self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach")
+        self.logger.info("   This eliminates the time paradox between data augmentation and shape analysis")
 
-        # Analyze training dataset (all data for accurate max shapes)
-        train_analysis_start = time.time()
-        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
-        train_analysis_time = time.time() - train_analysis_start
-        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
-
-        # Analyze validation dataset (all data for accurate max shapes)
-        val_analysis_start = time.time()
-        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
-        val_analysis_time = time.time() - val_analysis_start
-        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
-
-        # Use maximum shapes across both datasets for consistent padding
-        final_max_shapes = {
-            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
-            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
-            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
-            'n_features': train_max_shapes['n_features']
-        }
-
-        self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
-        self.logger.info(f"   Time steps: {final_max_shapes['max_time_steps']}")
-        self.logger.info(f"   Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
-        self.logger.info(f"   Transcription length: {final_max_shapes['max_transcription_len']}")
-        self.logger.info(f"   Features: {final_max_shapes['n_features']}")
-        # =====================================================================
-
-        # Create datasets using modern distribution API with FIXED shapes
-        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
-            """Create distributed dataset function for modern TPU strategy with FIXED shapes"""
+        # Simplified dataset creation function; max_shapes is no longer needed
+        def create_dist_dataset_fn(input_dataset_tf, training):
+            """Create distributed dataset function for the final 'batch-first' approach."""
             def dataset_fn(input_context):
-                # create_input_fn now requires max_shapes parameter for FIXED shapes
+                # Call the new create_input_fn, which no longer needs max_shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
-                    max_shapes=max_shapes,  # Pass pre-analyzed FIXED shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
 
-        # Distribute datasets using modern API with FIXED shapes
-        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        # Create the datasets with the new, simplified function signature
+        self.logger.info("🔄 Distributing training dataset (batch-first approach)...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
-        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        self.logger.info("🔄 Distributing validation dataset (batch-first approach)...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
+        # =====================================================================
 
         self.logger.info("Created distributed training and validation datasets")
 
         # Training metrics
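
Note: the simplified create_input_fn itself lives in dataset_tf.py and is not shown in this hunk. As a rough illustration of the "batch-first, augment-after" idea the trainer now relies on, a minimal sketch of such a function might look like the code below; the feature/label key names, the hard-coded batch size, and the white_noise_std option are illustrative assumptions, not the repository's actual schema.

import tensorflow as tf

def create_input_fn(dataset_tf, transform_args, training=True):
    """Sketch: pad and batch raw examples first, then augment whole batches.

    Because padding happens per batch, no up-front pass over the full
    dataset is needed to discover global maximum sequence lengths.
    """
    ds = dataset_tf  # assumed: a tf.data.Dataset yielding dicts of tensors

    # Batch-first: each batch is padded to the longest example it contains.
    # Key names below are hypothetical placeholders for the real schema.
    ds = ds.padded_batch(
        batch_size=32,  # illustrative; the real value would come from config
        padded_shapes={
            'neural_features': [None, None],  # (time_steps, channels)
            'seq_class_ids': [None],          # phoneme label sequence
            'n_time_steps': [],
            'phone_seq_lens': [],
        },
        drop_remainder=True,  # keep the batch dimension static for TPU
    )

    if training:
        # Assumed config key; stands in for whatever augmentation options
        # data_transforms actually carries.
        noise_std = transform_args.get('white_noise_std', 0.0)

        def augment(batch):
            # Augment-after: additive noise on the already-padded batch,
            # so augmentation can never change the padded shapes.
            feats = batch['neural_features']
            feats = feats + tf.random.normal(tf.shape(feats), stddev=noise_std)
            return {**batch, 'neural_features': feats}

        ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    return ds.prefetch(tf.data.AUTOTUNE)

The point of the change is visible in the sketch: because padding is per batch and augmentation runs on already-padded batches, the pipeline no longer needs the full-dataset shape pass (the removed analyze_dataset_shapes call) before it can be built and distributed.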