ff
@@ -889,35 +889,32 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
+                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with BATCH-FIRST approach
+    Create input function for TPU training with PRE-ANALYZED FIXED shapes
 
-    This function implements the correct TPU data pipeline:
-    1. Load individual samples
-    2. Cache raw samples
-    3. Batch samples with dynamic padding
-    4. Apply data augmentation to entire batches (AFTER batching)
-
-    This approach prevents shape conflicts from augmentation operations
-    like random_cut that would otherwise make tensor shapes dynamic before batching.
+    This function uses pre-computed maximum shapes to create STATIC-size batches,
+    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
+    final resolution of both CTC loss compatibility and graph structure issues.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
+        max_shapes: Pre-computed maximum shapes dictionary with keys:
+            'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
-        tf.data.Dataset ready for TPU training with XLA-compatible operations
+        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
     """
 
     # Step 1: Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
 
     # Step 2: Cache raw samples BEFORE any augmentation
     # ========================= I/O OPTIMIZATION SOLUTION =========================
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
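For orientation, a minimal sketch of how the new signature might be invoked from the trainer. The max_shapes values and the cache path are toy assumptions, and config stands in for the trainer's args dict:

    max_shapes = {
        'max_time_steps': 1024,        # toy values; the real ones come from analyze_dataset_shapes
        'max_phone_seq_len': 128,
        'max_transcription_len': 256,
        'n_features': 512,
    }
    train_ds = create_input_fn(
        dataset_tf,
        transform_args=config['dataset']['data_transforms'],
        max_shapes=max_shapes,
        training=True,
        cache_path='/tmp/train_cache',  # hypothetical path
    )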
@@ -929,19 +926,55 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     # ================================================================
 
-    # Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
-    print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
+    # Step 3: Apply transformations to individual examples BEFORE batching
+    def apply_transforms(example):
+        """Apply data transformations to individual examples"""
+        features = example['input_features']
+        n_time_steps = example['n_time_steps']
 
-    # Define padded shapes with None for dynamic dimensions
+        # Apply transformations
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            tf.expand_dims(features, 0),  # Add batch dimension for transforms
+            tf.expand_dims(n_time_steps, 0),
+            transform_args,
+            training=training
+        )
+
+        # Remove batch dimension
+        example['input_features'] = tf.squeeze(features, 0)
+        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
+
+        return example
+
+    # Apply transforms to cached data
+    dataset = dataset.map(
+        apply_transforms,
+        num_parallel_calls=tf.data.AUTOTUNE
+    )
+
+    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
+    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
+
+    # Extract pre-analyzed shape information
+    max_time_steps = max_shapes['max_time_steps']
+    max_phone_seq_len = max_shapes['max_phone_seq_len']
+    max_transcription_len = max_shapes['max_transcription_len']
+    n_features = max_shapes['n_features']
+
+    print(f"   Fixed time steps: {max_time_steps}")
+    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
+    print(f"   Fixed transcription length: {max_transcription_len}")
+    print(f"   Number of features: {n_features}")
+
+    # Define FIXED padded shapes - NO None values for XLA compatibility
     padded_shapes = {
-        'input_features': tf.TensorShape([None, None]),  # [time_steps, features] - dynamic
-        'seq_class_ids': tf.TensorShape([None]),  # [phone_seq_len] - dynamic
+        'input_features': tf.TensorShape([max_time_steps, n_features]),
+        'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
         'n_time_steps': tf.TensorShape([]),  # scalar
         'phone_seq_lens': tf.TensorShape([]),  # scalar
         'day_indices': tf.TensorShape([]),  # scalar
-        'transcriptions': tf.TensorShape([None]),  # [transcription_len] - dynamic
+        'transcriptions': tf.TensorShape([max_transcription_len]),
         'block_nums': tf.TensorShape([]),  # scalar
         'trial_nums': tf.TensorShape([])  # scalar
     }
@@ -958,7 +991,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create batches with dynamic padding
+    # Create batches with FIXED padding - XLA compiler will be happy!
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -966,36 +999,6 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
 
-    # Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
-    def apply_batch_transforms(batch):
-        """Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
-        features = batch['input_features']
-        n_time_steps = batch['n_time_steps']
-
-        # Apply transformations to the entire batch
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            features,  # Already batched: [batch_size, time_steps, features]
-            n_time_steps,  # Already batched: [batch_size]
-            transform_args,
-            training=training
-        )
-
-        # Update the batch with transformed data
-        batch['input_features'] = features
-        batch['n_time_steps'] = n_time_steps
-
-        return batch
-
-    # Apply batch-level transforms (only if training)
-    if training:
-        print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
-        dataset = dataset.map(
-            apply_batch_transforms,
-            num_parallel_calls=tf.data.AUTOTUNE
-        )
-    else:
-        print(f"✅ Validation mode: no data augmentation applied")
-
     # Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
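To illustrate why the fixed shapes matter, a minimal self-contained sketch (toy dimensions, not the project's real ones) showing that padded_batch with fully static padded_shapes yields an element spec with no None dimensions, which is what XLA compilation needs:

    import tensorflow as tf

    # Toy variable-length examples: [time_steps, n_features] with n_features fixed at 4
    def gen():
        for t in (5, 9, 7):
            yield {'input_features': tf.ones([t, 4]), 'n_time_steps': t}

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature={
            'input_features': tf.TensorSpec([None, 4], tf.float32),
            'n_time_steps': tf.TensorSpec([], tf.int32),
        },
    )

    # Pad every example to a pre-analyzed maximum (10 here) -> static shapes
    ds = ds.padded_batch(
        batch_size=2,
        padded_shapes={'input_features': tf.TensorShape([10, 4]),
                       'n_time_steps': tf.TensorShape([])},
        drop_remainder=True,  # keeps the batch dimension static too
    )
    print(ds.element_spec)  # input_features: TensorSpec(shape=(2, 10, 4), ...) - no None dims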
@@ -1,5 +1,6 @@
 import os
 import tensorflow as tf
+import tensorflow.keras.backend as K
 import numpy as np
 import time
 import json
@@ -27,10 +28,41 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn
+    create_input_fn,
+    analyze_dataset_shapes
 )
 
 
+def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length):
+    """
+    TPU-compatible CTC loss function using Keras backend
+
+    This implementation uses K.ctc_batch_cost, which is often more robust
+    for XLA compilation than tf.nn.ctc_loss, especially in complex model graphs.
+
+    Args:
+        y_true: Dense labels [batch_size, max_label_len]
+        y_pred: Logits [batch_size, time_steps, num_classes]
+        input_length: Logit sequence lengths [batch_size]
+        label_length: True label sequence lengths [batch_size]
+
+    Returns:
+        Scalar CTC loss value
+    """
+    # K.ctc_batch_cost requires logits to be time-major [time_steps, batch_size, num_classes]
+    y_pred_time_major = tf.transpose(y_pred, [1, 0, 2])
+
+    # Ensure correct data types for Keras backend
+    y_true = tf.cast(y_true, tf.float32)  # K.ctc_batch_cost expects float32 labels
+    input_length = tf.cast(input_length, tf.int32)
+    label_length = tf.cast(label_length, tf.int32)
+
+    # Calculate CTC loss using Keras backend (more XLA-friendly)
+    loss = K.ctc_batch_cost(y_true, y_pred_time_major, input_length, label_length)
+
+    return tf.reduce_mean(loss)
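For comparison, a minimal standalone sketch of K.ctc_batch_cost as documented: the TensorFlow docs describe y_pred as batch-major softmax output of shape (samples, time_steps, num_classes) and both length arguments as column vectors of shape (samples, 1). All values below are toy assumptions:

    import tensorflow as tf
    import tensorflow.keras.backend as K

    batch, time_steps, num_classes = 2, 10, 5

    y_pred = tf.nn.softmax(tf.random.normal([batch, time_steps, num_classes]))  # batch-major probabilities
    y_true = tf.constant([[1, 2, 3, 0], [2, 1, 0, 0]], dtype=tf.float32)        # dense labels, padded
    input_length = tf.constant([[10], [8]], dtype=tf.int32)                     # shape (samples, 1)
    label_length = tf.constant([[3], [2]], dtype=tf.int32)                      # shape (samples, 1)

    per_example_loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    print(per_example_loss.shape)  # (2, 1): one CTC loss per example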
 def dense_to_sparse(dense_tensor, sequence_lengths):
     """
     Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
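The body of dense_to_sparse falls outside this hunk. The following is a plausible sketch of the signature above using a sequence mask, illustrative rather than the repository's actual code:

    import tensorflow as tf

    def dense_to_sparse_sketch(dense_tensor, sequence_lengths):
        # Keep only positions within each sequence's true length
        mask = tf.sequence_mask(sequence_lengths, maxlen=tf.shape(dense_tensor)[1])
        indices = tf.where(mask)                       # int64 [nnz, 2] coordinates
        values = tf.gather_nd(dense_tensor, indices)   # values at those coordinates
        dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
        return tf.SparseTensor(indices, values, dense_shape)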
@@ -592,34 +624,23 @@ class BrainToTextDecoderTrainerTF:
             features, day_indices, None, False, 'inference', training=True
         )
 
-        # Calculate losses
+        # Calculate losses using TPU-compatible CTC implementation
         if use_full:
-            # Convert dense labels to sparse for dynamic shapes
-            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-
-            # Clean CTC loss - will auto-fallback to CPU with soft device placement
-            clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-            clean_loss = tf.nn.ctc_loss(
-                labels=sparse_labels,
-                logits=clean_logits_time_major,
-                label_length=None,  # Not needed with sparse labels
-                logit_length=tf.cast(adjusted_lens, tf.int32),
-                blank_index=0,
-                logits_time_major=True
+            # Clean CTC loss - using Keras backend for XLA compatibility
+            clean_loss = ctc_loss_for_tpu(
+                y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                y_pred=clean_logits,
+                input_length=adjusted_lens,
+                label_length=phone_seq_lens
             )
-            clean_loss = tf.reduce_mean(clean_loss)
 
-            # Noisy CTC loss - will auto-fallback to CPU with soft device placement
-            noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
-            noisy_loss = tf.nn.ctc_loss(
-                labels=sparse_labels,  # Reuse same sparse labels
-                logits=noisy_logits_time_major,
-                label_length=None,  # Not needed with sparse labels
-                logit_length=tf.cast(adjusted_lens, tf.int32),
-                blank_index=0,
-                logits_time_major=True
+            # Noisy CTC loss - using Keras backend for XLA compatibility
+            noisy_loss = ctc_loss_for_tpu(
+                y_true=tf.cast(labels, tf.float32),  # Reuse same dense labels
+                y_pred=noisy_logits,
+                input_length=adjusted_lens,
+                label_length=phone_seq_lens
             )
-            noisy_loss = tf.reduce_mean(noisy_loss)
 
             # Optional noise L2 regularization
             noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -628,20 +649,13 @@ class BrainToTextDecoderTrainerTF:
 
             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            # Convert dense labels to sparse for dynamic shapes
-            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-
-            # Standard CTC loss - will auto-fallback to CPU with soft device placement
-            logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-            loss = tf.nn.ctc_loss(
-                labels=sparse_labels,
-                logits=logits_time_major,
-                label_length=None,  # Not needed with sparse labels
-                logit_length=tf.cast(adjusted_lens, tf.int32),
-                blank_index=0,
-                logits_time_major=True
+            # Standard CTC loss - using Keras backend for XLA compatibility
+            loss = ctc_loss_for_tpu(
+                y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                y_pred=clean_logits,
+                input_length=adjusted_lens,
+                label_length=phone_seq_lens
             )
-            loss = tf.reduce_mean(loss)
 
         # AdamW handles weight decay automatically - no manual L2 regularization needed
         # TensorFlow mixed precision: no manual scaling needed, the Keras policy handles it automatically
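The mixed-precision comment above refers to Keras' global policy mechanism. A minimal sketch of enabling it for TPU, assuming the usual bfloat16 choice (the project's actual policy setting is not shown in this diff):

    import tensorflow as tf

    # Enable Keras mixed precision globally; on TPU, bfloat16 needs no loss scaling,
    # which is why the training step above does no manual scaling.
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    print(tf.keras.mixed_precision.global_policy())  # bfloat16 compute, float32 variables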
@@ -696,20 +710,13 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Convert dense labels to sparse for dynamic shapes
-        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-
-        # Calculate loss - will auto-fallback to CPU with soft device placement
-        logits_time_major = tf.transpose(logits, [1, 0, 2])
-        loss = tf.nn.ctc_loss(
-            labels=sparse_labels,
-            logits=logits_time_major,
-            label_length=None,  # Not needed with sparse labels
-            logit_length=tf.cast(adjusted_lens, tf.int32),
-            blank_index=0,
-            logits_time_major=True
+        # Calculate loss using TPU-compatible CTC implementation
+        loss = ctc_loss_for_tpu(
+            y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+            y_pred=logits,
+            input_length=adjusted_lens,
+            label_length=phone_seq_lens
         )
-        loss = tf.reduce_mean(loss)
 
         # Greedy decoding for PER calculation
         predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
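Greedy decoding via argmax, as in the last line above, still needs repeats and blanks collapsed before a phoneme error rate can be computed. A minimal sketch of that collapse for a single sequence (the helper name is illustrative; blank id 0 matches the blank_index=0 of the removed tf.nn.ctc_loss calls):

    import tensorflow as tf

    def collapse_ctc_sequence(ids, blank_id=0):
        # Drop repeated symbols (CTC merges consecutive duplicates) ...
        sentinel = tf.constant([blank_id - 1], dtype=ids.dtype)  # differs from any real id
        shifted = tf.concat([sentinel, ids[:-1]], axis=0)
        deduped = tf.boolean_mask(ids, tf.not_equal(ids, shifted))
        # ... then drop blanks
        return tf.boolean_mask(deduped, tf.not_equal(deduped, blank_id))

    ids = tf.constant([0, 3, 3, 0, 0, 5, 5, 5, 0, 2], dtype=tf.int32)
    print(collapse_ctc_sequence(ids).numpy())  # [3 5 2]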
@@ -725,28 +732,61 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
 
-        # Create datasets using modern distribution API with dynamic padding
-        def create_dist_dataset_fn(input_dataset_tf, training):
-            """Create distributed dataset function for modern TPU strategy with batch-first approach"""
+        # ========================= DATASET SHAPE ANALYSIS =========================
+        # Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
+        self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
+        self.logger.info("   This is CRITICAL for resolving both CTC compatibility and graph structure issues")
+
+        # Analyze training dataset (all data for accurate max shapes)
+        train_analysis_start = time.time()
+        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
+        train_analysis_time = time.time() - train_analysis_start
+        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
+
+        # Analyze validation dataset (all data for accurate max shapes)
+        val_analysis_start = time.time()
+        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
+        val_analysis_time = time.time() - val_analysis_start
+        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
+
+        # Use maximum shapes across both datasets for consistent padding
+        final_max_shapes = {
+            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
+            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
+            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
+            'n_features': train_max_shapes['n_features']
+        }
+
+        self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
+        self.logger.info(f"   Time steps: {final_max_shapes['max_time_steps']}")
+        self.logger.info(f"   Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
+        self.logger.info(f"   Transcription length: {final_max_shapes['max_transcription_len']}")
+        self.logger.info(f"   Features: {final_max_shapes['n_features']}")
+        # =====================================================================
+
+        # Create datasets using modern distribution API with FIXED shapes
+        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
+            """Create distributed dataset function for modern TPU strategy with FIXED shapes"""
             def dataset_fn(input_context):
-                # create_input_fn now uses batch-first approach with dynamic padding
+                # create_input_fn now requires max_shapes parameter for FIXED shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
+                    max_shapes=max_shapes,  # Pass pre-analyzed FIXED shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
 
-        # Distribute datasets using modern API with batch-first approach
+        # Distribute datasets using modern API with FIXED shapes
         self.logger.info("🔄 Distributing training dataset across TPU cores...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
         self.logger.info("🔄 Distributing validation dataset across TPU cores...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
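strategy.distribute_datasets_from_function, used above, hands each input pipeline a tf.distribute.InputContext. A minimal standalone sketch of the pattern, with a toy dataset standing in for create_input_fn:

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # e.g. a TPUStrategy in the real trainer

    GLOBAL_BATCH = 32

    def dataset_fn(input_context: tf.distribute.InputContext):
        # Each input pipeline builds its own shard with its per-replica batch size
        per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH)
        ds = tf.data.Dataset.range(1000)
        ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
        return ds.batch(per_replica_batch, drop_remainder=True)

    dist_ds = strategy.distribute_datasets_from_function(dataset_fn)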