Refactor input function to implement batch-first approach with dynamic padding and apply data augmentation post-batching for TPU compatibility
@@ -889,34 +889,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with PRE-ANALYZED FIXED shapes
+    Create input function for TPU training with BATCH-FIRST approach
 
-    This function uses pre-computed maximum shapes to create fixed-size batches,
-    ensuring XLA compilation success on TPU hardware.
+    This function implements the correct TPU data pipeline:
+    1. Load individual samples
+    2. Cache raw samples
+    3. Batch samples with dynamic padding
+    4. Apply data augmentation to entire batches (AFTER batching)
+
+    This approach prevents shape conflicts from augmentation operations
+    like random_cut that would otherwise make tensor shapes dynamic before batching.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
-        max_shapes: Pre-computed maximum shapes dictionary with keys:
-            'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
-        tf.data.Dataset ready for TPU training with fixed shapes
+        tf.data.Dataset ready for TPU training with XLA-compatible operations
     """
 
-    # Create individual example dataset with file-grouping I/O optimization
+    # Step 1: Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
 
+    # Step 2: Cache raw samples BEFORE any augmentation
     # ========================= I/O OPTIMIZATION SOLUTION =========================
-    # Cache both the training and validation sets, because:
-    # 1. Training set: traversed in full every epoch
-    # 2. Validation set: validated every 200 steps plus early-stopping checks, so it is used frequently
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
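Taken together with the hunks that follow, the rewritten create_input_fn boils down to this skeleton (a condensed sketch of the post-change flow using the names from the diff, not a drop-in replacement):

    dataset = dataset_tf.create_individual_dataset()        # Step 1: load samples
    dataset = dataset.cache(cache_path) if cache_path else dataset.cache()  # Step 2
    dataset = dataset.padded_batch(                         # Step 3: dynamic padding
        batch_size=dataset_tf.batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=True)
    if training:                                            # Step 4: augment whole batches
        dataset = dataset.map(apply_batch_transforms,
                              num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)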
@@ -924,63 +925,25 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     else:
         # If no cache path is specified, default to an in-memory cache
-        # For large datasets, pass an explicit disk cache path at the call site
         dataset = dataset.cache()
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     # ================================================================
 
-    def apply_transforms(example):
-        """Apply data transformations to individual examples"""
-        features = example['input_features']
-        n_time_steps = example['n_time_steps']
-
-        # Apply transformations
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            tf.expand_dims(features, 0),  # Add batch dimension for transforms
-            tf.expand_dims(n_time_steps, 0),
-            transform_args,
-            training=training
-        )
-
-        # Remove batch dimension
-        example['input_features'] = tf.squeeze(features, 0)
-        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
-
-        return example
-
-    # Apply random augmentation after caching so every epoch sees different augmentations
-    dataset = dataset.map(
-        apply_transforms,
-        num_parallel_calls=tf.data.AUTOTUNE
-    )
-
-    # ========================= FIXED SHAPES SOLUTION =========================
-    # Use pre-analyzed fixed shapes to guarantee XLA compilation success
-    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
-
-    # Read the shape information from the passed-in argument
-    max_time_steps = max_shapes['max_time_steps']
-    max_phone_seq_len = max_shapes['max_phone_seq_len']
-    max_transcription_len = max_shapes['max_transcription_len']
-    n_features = max_shapes['n_features']
-
-    print(f" Fixed time steps: {max_time_steps}")
-    print(f" Fixed phone sequence length: {max_phone_seq_len}")
-    print(f" Fixed transcription length: {max_transcription_len}")
-    print(f" Number of features: {n_features}")
-
-    # Define fixed padded shapes - NO None values for XLA compatibility
+    # Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
+    print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
+
+    # Define padded shapes with None for dynamic dimensions
     padded_shapes = {
-        'input_features': tf.TensorShape([max_time_steps, n_features]),
-        'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
+        'input_features': tf.TensorShape([None, None]),  # [time_steps, features] - dynamic
+        'seq_class_ids': tf.TensorShape([None]),  # [phone_seq_len] - dynamic
         'n_time_steps': tf.TensorShape([]),  # scalar
         'phone_seq_lens': tf.TensorShape([]),  # scalar
         'day_indices': tf.TensorShape([]),  # scalar
-        'transcriptions': tf.TensorShape([max_transcription_len]),
+        'transcriptions': tf.TensorShape([None]),  # [transcription_len] - dynamic
         'block_nums': tf.TensorShape([]),  # scalar
         'trial_nums': tf.TensorShape([])  # scalar
     }
 
     # Define padding values for each field
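To see what the None dimensions buy here: padded_batch pads each field only to the longest element in the current batch, so shapes vary across batches but stay rectangular within one. A self-contained toy demonstration (standard tf.data, not project code):

    import tensorflow as tf

    # Two ragged-length examples; padded shape [None, 2] pads per batch.
    ds = tf.data.Dataset.from_generator(
        lambda: iter([tf.ones([3, 2]), tf.ones([5, 2])]),
        output_signature=tf.TensorSpec(shape=[None, 2], dtype=tf.float32))
    ds = ds.padded_batch(2, padded_shapes=tf.TensorShape([None, 2]))
    for batch in ds:
        print(batch.shape)  # (2, 5, 2): padded to the longest element in the batch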
@@ -995,7 +958,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create batches with FIXED padding - XLA compiler will be happy!
+    # Create batches with dynamic padding
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -1003,6 +966,36 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
 
+    # Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
+    def apply_batch_transforms(batch):
+        """Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
+        features = batch['input_features']
+        n_time_steps = batch['n_time_steps']
+
+        # Apply transformations to the entire batch
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features,  # Already batched: [batch_size, time_steps, features]
+            n_time_steps,  # Already batched: [batch_size]
+            transform_args,
+            training=training
+        )
+
+        # Update the batch with transformed data
+        batch['input_features'] = features
+        batch['n_time_steps'] = n_time_steps
+
+        return batch
+
+    # Apply batch-level transforms (only if training)
+    if training:
+        print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
+        dataset = dataset.map(
+            apply_batch_transforms,
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+    else:
+        print(f"✅ Validation mode: no data augmentation applied")
+
     # Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
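Why Step 4 must come after Step 3: a per-example cut changes each element's length independently, so elements no longer agree on a shape when padded_batch tries to assemble them; the same cut applied to an already-padded batch is a single tensor op that keeps the batch rectangular. A hypothetical batch-level cut to illustrate the idea (not the repo's actual random_cut):

    import tensorflow as tf

    # Hypothetical: drop up to max_cut leading frames from the whole batch at once.
    def random_cut_batch(features, n_time_steps, max_cut=3):
        cut = tf.random.uniform([], 0, max_cut + 1, dtype=tf.int32)
        features = features[:, cut:, :]               # one slice for all examples
        n_time_steps = tf.maximum(n_time_steps - cut, 0)
        return features, n_time_steps

    feats = tf.ones([4, 10, 2])                       # [batch, time, features]
    lens = tf.constant([10, 8, 9, 7])
    out, out_lens = random_cut_batch(feats, lens)
    print(out.shape)                                  # (4, 10-cut, 2): still rectangular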
@@ -27,11 +27,53 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn,
-    analyze_dataset_shapes
+    create_input_fn
 )
 
 
+def dense_to_sparse(dense_tensor, sequence_lengths):
+    """
+    Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
+
+    Args:
+        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
+        sequence_lengths: Actual sequence lengths [batch_size]
+
+    Returns:
+        SparseTensor suitable for tf.nn.ctc_loss
+    """
+    # Create mask for valid (non-zero) elements within sequence lengths
+    batch_size = tf.shape(dense_tensor)[0]
+    max_seq_len = tf.shape(dense_tensor)[1]
+
+    # Create range indices
+    batch_indices = tf.range(batch_size)
+    seq_indices = tf.range(max_seq_len)
+
+    # Create meshgrid for batch and sequence dimensions
+    batch_mesh, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
+
+    # Create mask based on sequence lengths and non-zero values
+    length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
+    value_mask = tf.not_equal(dense_tensor, 0)
+    combined_mask = tf.logical_and(length_mask, value_mask)
+
+    # Get indices of valid elements
+    indices = tf.where(combined_mask)
+
+    # Get values at valid indices
+    values = tf.gather_nd(dense_tensor, indices)
+
+    # Create sparse tensor
+    dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
+
+    return tf.SparseTensor(
+        indices=tf.cast(indices, tf.int64),
+        values=tf.cast(values, tf.int32),
+        dense_shape=dense_shape
+    )
+
+
 class BrainToTextDecoderTrainerTF:
     """
     TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8
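A quick worked example of what dense_to_sparse produces (values assume the function as written above; note it relies on 0 being the blank/padding id, so a genuine label of 0 inside a sequence would be dropped):

    labels = tf.constant([[5, 2, 0, 0],
                          [7, 0, 3, 0]])
    lens = tf.constant([2, 3])
    sp = dense_to_sparse(labels, lens)
    print(sp.indices.numpy())      # [[0 0] [0 1] [1 0] [1 2]] - padding and zeros dropped
    print(sp.values.numpy())       # [5 2 7 3]
    print(sp.dense_shape.numpy())  # [2 4]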
@@ -392,7 +434,8 @@ class BrainToTextDecoderTrainerTF:
         import psutil
         initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
 
-        print("🔄 Initializing training dataset with GPU-style memory management...")
+        print("🔄 Initializing training dataset with TPU-optimized memory management...")
+        print(" 🚀 Preloading all data to RAM for maximum parallel analysis speed...")
         init_start_time = time.time()
         self.train_dataset_tf = BrainToTextDatasetTF(
             trial_indices=train_trials,
@@ -403,8 +446,8 @@ class BrainToTextDecoderTrainerTF:
             random_seed=self.args['dataset']['seed'],
             must_include_days=self.args['dataset'].get('must_include_days'),
             feature_subset=self.args['dataset'].get('feature_subset'),
-            cache_data=True,  # Enable smart caching (same as the GPU version)
-            preload_all_data=False  # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
+            cache_data=True,  # Enable smart caching
+            preload_all_data=True  # 🚀 TPU optimization: preload all data to unlock parallel analysis
         )
 
         # Log training dataset initialization performance
@@ -413,7 +456,8 @@ class BrainToTextDecoderTrainerTF:
         train_memory_used = train_memory_mb - initial_memory_mb
         print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")
 
-        print("🔄 Initializing validation dataset with GPU-style memory management...")
+        print("🔄 Initializing validation dataset with TPU-optimized memory management...")
+        print(" 🚀 Preloading all validation data to RAM for maximum parallel analysis speed...")
         val_init_start_time = time.time()
         self.val_dataset_tf = BrainToTextDatasetTF(
             trial_indices=val_trials,
@@ -423,8 +467,8 @@ class BrainToTextDecoderTrainerTF:
             days_per_batch=1,  # One day per validation batch
             random_seed=self.args['dataset']['seed'],
             feature_subset=self.args['dataset'].get('feature_subset'),
-            cache_data=True,  # Enable smart caching (same as the GPU version)
-            preload_all_data=False  # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
+            cache_data=True,  # Enable smart caching
+            preload_all_data=True  # 🚀 TPU optimization: preload all data to unlock parallel analysis
         )
 
         # Log validation dataset initialization performance
@@ -525,12 +569,7 @@ class BrainToTextDecoderTrainerTF:
         day_indices = batch['day_indices']
 
         with tf.GradientTape() as tape:
-            # Apply data transformations
-            features, n_time_steps = DataAugmentationTF.transform_data(
-                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
-            )
-
-            # Calculate adjusted lengths for CTC
+            # Calculate adjusted lengths for CTC (data augmentation now handled in dataset pipeline)
             adjusted_lens = tf.cast(
                 (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
                 self.args['model']['patch_stride'] + 1,
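For intuition on the adjusted-length formula: with illustrative values patch_size=14 and patch_stride=4 (the real numbers come from the model config), an example with n_time_steps=512 gives (512 - 14) / 4 + 1 = 125.5, which the int32 cast truncates to 125 - the logit_length the CTC loss sees for that example.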
@@ -551,25 +590,28 @@ class BrainToTextDecoderTrainerTF:
 
             # Calculate losses
             if use_full:
-                # Clean CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
+                # Convert dense labels to sparse for dynamic shapes
+                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+
+                # Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
                 # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
                 clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
                 clean_loss = tf.nn.ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
+                    labels=sparse_labels,
                     logits=clean_logits_time_major,
-                    label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
+                    label_length=None,  # Not needed with sparse labels
                     logit_length=tf.cast(adjusted_lens, tf.int32),
                     blank_index=0,
                     logits_time_major=True
                 )
                 clean_loss = tf.reduce_mean(clean_loss)
 
-                # Noisy CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
+                # Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
                 noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
                 noisy_loss = tf.nn.ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
+                    labels=sparse_labels,  # Reuse same sparse labels
                     logits=noisy_logits_time_major,
-                    label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
+                    label_length=None,  # Not needed with sparse labels
                     logit_length=tf.cast(adjusted_lens, tf.int32),
                     blank_index=0,
                     logits_time_major=True
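A minimal end-to-end check of this call pattern on toy shapes (assumes the dense_to_sparse helper added above; class 0 is the CTC blank, so it never appears as a real label):

    import tensorflow as tf

    batch, max_t, num_classes = 2, 6, 8
    logits = tf.random.normal([max_t, batch, num_classes])   # time-major
    labels = tf.constant([[5, 2, 0], [7, 3, 4]])             # 0-padded dense labels
    lens = tf.constant([2, 3])
    loss = tf.nn.ctc_loss(
        labels=dense_to_sparse(labels, lens),  # sparse labels carry their own lengths
        logits=logits,
        label_length=None,
        logit_length=tf.constant([6, 6]),
        blank_index=0,
        logits_time_major=True)
    print(loss.shape)  # (2,): per-example losses, averaged via tf.reduce_mean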
@@ -583,12 +625,15 @@ class BrainToTextDecoderTrainerTF:
 
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                # Standard CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
+                # Convert dense labels to sparse for dynamic shapes
+                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+
+                # Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
                 logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
                 loss = tf.nn.ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
+                    labels=sparse_labels,
                     logits=logits_time_major,
-                    label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
+                    label_length=None,  # Not needed with sparse labels
                     logit_length=tf.cast(adjusted_lens, tf.int32),
                     blank_index=0,
                     logits_time_major=True
@@ -638,12 +683,7 @@ class BrainToTextDecoderTrainerTF:
         phone_seq_lens = batch['phone_seq_lens']
         day_indices = batch['day_indices']
 
-        # Apply data transformations (no augmentation for validation)
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
-        )
-
-        # Calculate adjusted lengths
+        # Calculate adjusted lengths (no augmentation for validation, handled in dataset pipeline)
         adjusted_lens = tf.cast(
             (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
             self.args['model']['patch_stride'] + 1,
@@ -653,13 +693,16 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
+        # Convert dense labels to sparse for dynamic shapes
+        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+
+        # Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
         # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
         logits_time_major = tf.transpose(logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
-            labels=tf.cast(labels, tf.int32),  # Use dense labels with fixed shapes
+            labels=sparse_labels,
             logits=logits_time_major,
-            label_length=tf.cast(phone_seq_lens, tf.int32),  # Re-enable label_length
+            label_length=None,  # Not needed with sparse labels
             logit_length=tf.cast(adjusted_lens, tf.int32),
             blank_index=0,
             logits_time_major=True
@@ -680,60 +723,28 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
 
-        # ========================= DATASET SHAPE ANALYSIS =========================
-        # Perform one-time full dataset analysis for fixed shapes (TPU requirement)
-        self.logger.info("🚀 Performing one-time full dataset analysis for fixed shapes...")
-
-        # Analyze training dataset (all data for accurate max shapes)
-        train_analysis_start = time.time()
-        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
-        train_analysis_time = time.time() - train_analysis_start
-        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
-
-        # Analyze validation dataset (all data for accurate max shapes)
-        val_analysis_start = time.time()
-        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
-        val_analysis_time = time.time() - val_analysis_start
-        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
-
-        # Use maximum shapes across both datasets for consistent padding
-        final_max_shapes = {
-            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
-            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
-            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
-            'n_features': train_max_shapes['n_features']
-        }
-
-        self.logger.info(f"📊 Final fixed shapes for TPU training:")
-        self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}")
-        self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
-        self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}")
-        self.logger.info(f" Features: {final_max_shapes['n_features']}")
-        # =====================================================================
-
-        # Create datasets using modern distribution API with fixed shapes
-        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
-            """Create distributed dataset function for modern TPU strategy with fixed shapes"""
+        # Create datasets using modern distribution API with dynamic padding
+        def create_dist_dataset_fn(input_dataset_tf, training):
+            """Create distributed dataset function for modern TPU strategy with batch-first approach"""
             def dataset_fn(input_context):
-                # create_input_fn now requires max_shapes parameter for fixed shapes
+                # create_input_fn now uses batch-first approach with dynamic padding
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
-                    max_shapes=max_shapes,  # Pass pre-analyzed shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
 
-        # Distribute datasets using modern API with fixed shapes
+        # Distribute datasets using modern API with batch-first approach
         self.logger.info("🔄 Distributing training dataset across TPU cores...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
         self.logger.info("🔄 Distributing validation dataset across TPU cores...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
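For context on this wiring: distribute_datasets_from_function hands each replica the dataset exactly as dataset_fn returns it, so dataset_tf.batch_size acts as the per-replica batch size, and input_context is available for sharding if needed. A minimal sketch of how the pieces connect (assumes an initialized tf.distribute.TPUStrategy and the imports from this commit; train_step stands in for the repo's training step):

    def dataset_fn(input_context):
        return create_input_fn(train_dataset_tf, transform_args, training=True)

    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
    for batch in dist_dataset:
        strategy.run(train_step, args=(batch,))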