Zchen
2025-10-21 01:07:57 +08:00
parent ab12d0b7ee
commit a031972ba6


@@ -77,8 +77,7 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn,
-    analyze_dataset_shapes
+    create_input_fn
 )
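
The dropped import, analyze_dataset_shapes, backed a one-time full pass over every trial to find the padding maxima that the old trainer code in the next hunk reads (max_time_steps, max_phone_seq_len, max_transcription_len, n_features). A minimal sketch of that kind of pre-scan, assuming the dataset object is iterable and exposes per-trial tensors under hypothetical keys; it is not the repository's implementation:

    def scan_max_shapes(dataset_tf):
        """Sketch only: one full pass to collect the maxima used for fixed padding."""
        maxima = {'max_time_steps': 0, 'max_phone_seq_len': 0,
                  'max_transcription_len': 0, 'n_features': 0}
        for trial in dataset_tf:                       # full pass -- the cost this commit removes
            feats = trial['neural_features']           # hypothetical key: [time, features]
            maxima['max_time_steps'] = max(maxima['max_time_steps'], int(feats.shape[0]))
            maxima['n_features'] = int(feats.shape[1])
            maxima['max_phone_seq_len'] = max(maxima['max_phone_seq_len'],
                                              int(trial['phone_seq'].shape[0]))      # hypothetical key
            maxima['max_transcription_len'] = max(maxima['max_transcription_len'],
                                                  int(trial['transcription'].shape[0]))  # hypothetical key
        return maxima

The sample_size=-1 calls deleted in the next hunk show why this was expensive: both the training and validation sets had to be scanned in full before a single training step could run.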
@@ -728,63 +727,36 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
-        # ========================= DATASET SHAPE ANALYSIS =========================
-        # Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
-        self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
-        self.logger.info("   This is CRITICAL for resolving both CTC compatibility and graph structure issues")
-        # Analyze training dataset (all data for accurate max shapes)
-        train_analysis_start = time.time()
-        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
-        train_analysis_time = time.time() - train_analysis_start
-        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
-        # Analyze validation dataset (all data for accurate max shapes)
-        val_analysis_start = time.time()
-        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
-        val_analysis_time = time.time() - val_analysis_start
-        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
-        # Use maximum shapes across both datasets for consistent padding
-        final_max_shapes = {
-            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
-            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
-            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
-            'n_features': train_max_shapes['n_features']
-        }
-        self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
-        self.logger.info(f"   Time steps: {final_max_shapes['max_time_steps']}")
-        self.logger.info(f"   Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
-        self.logger.info(f"   Transcription length: {final_max_shapes['max_transcription_len']}")
-        self.logger.info(f"   Features: {final_max_shapes['n_features']}")
-        # =====================================================================
-        # Create datasets using modern distribution API with FIXED shapes
-        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
-            """Create distributed dataset function for modern TPU strategy with FIXED shapes"""
+        # ========================= FINAL SOLUTION: BATCH-FIRST =========================
+        # Use the proven "batch-first, augment-after" approach, which removes the time paradox between data augmentation and shape analysis
+        self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach")
+        self.logger.info("   This eliminates the time paradox between data augmentation and shape analysis")
+        # Simplified dataset-creation helper; max_shapes is no longer needed
+        def create_dist_dataset_fn(input_dataset_tf, training):
+            """Create distributed dataset function for the final 'batch-first' approach."""
             def dataset_fn(input_context):
-                # create_input_fn now requires max_shapes parameter for FIXED shapes
+                # Call the new create_input_fn, which no longer takes max_shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
-                    max_shapes=max_shapes,  # Pass pre-analyzed FIXED shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
-        # Distribute datasets using modern API with FIXED shapes
-        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        # Create the distributed datasets with the new, simplified signature
+        self.logger.info("🔄 Distributing training dataset (batch-first approach)...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
-        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        self.logger.info("🔄 Distributing validation dataset (batch-first approach)...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
-        # =====================================================================
         self.logger.info("Created distributed training and validation datasets")
         # Training metrics
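
For context on the "batch-first, augment-after" order this commit switches to: the comments describe a "time paradox" between data augmentation and shape analysis, which arises if augmentation can change trial lengths, since shapes measured before augmentation then no longer bound the padded tensors. Batching first fixes the padded shape of each batch before any augmentation runs. A minimal sketch of such an input_fn under tf.distribute, assuming dict-shaped examples with a hypothetical 'features' key; this is not the repository's create_input_fn:

    import tensorflow as tf

    def make_batch_first_input_fn(example_dataset, global_batch_size, training):
        """Sketch: pad and batch first, then apply shape-preserving augmentation per batch."""

        def augment_batch(batch):
            # Runs on already-padded batches, so augmentation cannot change sequence
            # lengths after the padding decision has been made.
            if training:
                augmented = dict(batch)
                noise = tf.random.normal(tf.shape(batch['features']), stddev=0.01)  # hypothetical key
                augmented['features'] = batch['features'] + noise
                return augmented
            return batch

        def input_fn(input_context):
            per_replica_bs = input_context.get_per_replica_batch_size(global_batch_size)
            ds = example_dataset.shard(input_context.num_input_pipelines,
                                       input_context.input_pipeline_id)
            if training:
                ds = ds.shuffle(1024)
            # Batch first: padded_batch pads to the longest trial in each batch, so no
            # dataset-wide shape scan is needed before training starts.
            ds = ds.padded_batch(per_replica_bs, drop_remainder=True)
            ds = ds.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
            return ds.prefetch(tf.data.AUTOTUNE)

        return input_fn

The trainer above would hand such a function to self.strategy.distribute_datasets_from_function(...), which is what the new, slimmer create_dist_dataset_fn does with create_input_fn.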