@@ -77,8 +77,7 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn,
-    analyze_dataset_shapes
+    create_input_fn
 )


@@ -728,63 +727,36 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")

-        # ========================= DATASET SHAPE ANALYSIS =========================
-        # Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
-        self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
-        self.logger.info(" This is CRITICAL for resolving both CTC compatibility and graph structure issues")
+        # ========================= FINAL SOLUTION: BATCH-FIRST =========================
+        # Use the proven "batch first, augment after" approach to eliminate the time paradox between data augmentation and shape analysis
+        self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach")
+        self.logger.info(" This eliminates the time paradox between data augmentation and shape analysis")

-        # Analyze training dataset (all data for accurate max shapes)
-        train_analysis_start = time.time()
-        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
-        train_analysis_time = time.time() - train_analysis_start
-        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
-
-        # Analyze validation dataset (all data for accurate max shapes)
-        val_analysis_start = time.time()
-        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
-        val_analysis_time = time.time() - val_analysis_start
-        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
-
-        # Use maximum shapes across both datasets for consistent padding
-        final_max_shapes = {
-            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
-            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
-            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
-            'n_features': train_max_shapes['n_features']
-        }
-
-        self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
-        self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}")
-        self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
-        self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}")
-        self.logger.info(f" Features: {final_max_shapes['n_features']}")
-        # =====================================================================
-
-        # Create datasets using modern distribution API with FIXED shapes
-        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
-            """Create distributed dataset function for modern TPU strategy with FIXED shapes"""
+        # Simplified dataset creation function; max_shapes is no longer needed
+        def create_dist_dataset_fn(input_dataset_tf, training):
+            """Create distributed dataset function for the final 'batch-first' approach."""
             def dataset_fn(input_context):
-                # create_input_fn now requires max_shapes parameter for FIXED shapes
+                # Call the new version of create_input_fn, which does not need max_shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
-                    max_shapes=max_shapes, # Pass pre-analyzed FIXED shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)

-        # Distribute datasets using modern API with FIXED shapes
-        self.logger.info("🔄 Distributing training dataset across TPU cores...")
+        # Create the datasets using the new, simplified function signature
+        self.logger.info("🔄 Distributing training dataset (batch-first approach)...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")

-        self.logger.info("🔄 Distributing validation dataset across TPU cores...")
+        self.logger.info("🔄 Distributing validation dataset (batch-first approach)...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
-
+        # =====================================================================

         self.logger.info("Created distributed training and validation datasets")

         # Training metrics
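The new create_input_fn body is not part of this diff, so the sketch below is only an illustration of the "batch-first, augment-after" idea named in the comments above; the padding sizes, feature keys, and the augment_batch helper are assumptions, not the repository's actual API. The point it shows is why analyze_dataset_shapes can be dropped: pad and batch every example to a fixed size first, then apply augmentation to the already-batched tensors, so shapes stay static for XLA without an up-front analysis pass.

# Sketch only (assumed shapes and element keys), not the actual dataset_tf implementation.
import tensorflow as tf

FIXED_TIME_STEPS = 2048   # assumed fixed padding length
N_FEATURES = 512          # assumed feature dimension
MAX_PHONE_SEQ_LEN = 256   # assumed fixed label length

def batch_first_input_fn(dataset: tf.data.Dataset, training: bool,
                         batch_size: int = 32) -> tf.data.Dataset:
    """Pad/batch to fixed shapes first, then augment the batched tensors."""
    # 1) Batch first: after padded_batch every element has a static shape,
    #    which is what XLA/TPU compilation needs and what the removed
    #    analyze_dataset_shapes step used to compute ahead of time.
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes={
            'neural_features': (FIXED_TIME_STEPS, N_FEATURES),
            'phone_seq': (MAX_PHONE_SEQ_LEN,),
            'n_time_steps': (),
            'phone_seq_len': (),
        },
        drop_remainder=True,  # keep the batch dimension static as well
    )

    # 2) Augment after: the augmentation only ever sees already-padded tensors,
    #    so it cannot change the padded shape, which removes the "time paradox"
    #    between augmentation and shape analysis.
    if training:
        def augment_batch(batch):
            # Hypothetical augmentation: additive Gaussian noise on the features.
            noise = tf.random.normal(tf.shape(batch['neural_features']), stddev=0.1)
            batch['neural_features'] = batch['neural_features'] + noise
            return batch
        dataset = dataset.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.prefetch(tf.data.AUTOTUNE)

A per-replica dataset built along these lines is what create_dist_dataset_fn in the hunk above hands to strategy.distribute_datasets_from_function for each TPU replica.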