commit d83f990beb
parent eb058fe9d3
Author: Zchen
Date:   2025-10-17 12:20:17 +08:00


@@ -428,6 +428,63 @@ class BrainToTextDatasetTF:
return dataset
def create_individual_dataset(self) -> tf.data.Dataset:
"""
Create a tf.data.Dataset that yields individual examples for TPU-optimized batching.
Instead of emitting pre-batched data, this method yields one trial at a time so that
TensorFlow's padded_batch can assemble fixed-shape batches for the TPU.
"""
def individual_example_generator():
"""Generator that yields individual trial examples"""
for batch_idx in range(self.n_batches):
batch_index = self.batch_indices[batch_idx]
# Process each trial in the batch individually
for day in batch_index.keys():
for trial in batch_index[day]:
trial_data = self._load_trial_data(day, trial)
# Yield individual example with all required fields
example = {
'input_features': trial_data['input_features'].astype(np.float32),
'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
'n_time_steps': np.int32(trial_data['n_time_steps']),
'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
'day_indices': np.int32(trial_data['day_index']),
'transcriptions': trial_data['transcription'].astype(np.int32),
'block_nums': np.int32(trial_data['block_num']),
'trial_nums': np.int32(trial_data['trial_num'])
}
yield example
# Define output signature for individual examples
output_signature = {
'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
}
# Create dataset from individual examples
dataset = tf.data.Dataset.from_generator(
individual_example_generator,
output_signature=output_signature
)
# Shuffle individual examples if training (more effective than batch-level shuffle)
if self.split == 'train':
# Use a bounded shuffle buffer (capped at 1000) to avoid excessive memory use
shuffle_buffer = min(1000, self.n_trials)
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
return dataset
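
Editor's note: the pattern above (a generator of variable-length dict examples fed to from_generator, later padded to fixed shapes) can be exercised in isolation. The following sketch is illustrative only; the toy shapes and helper name are assumptions, not part of the commit:

import numpy as np
import tensorflow as tf

def _toy_examples():
    """Yield variable-length dict examples, mimicking one trial per yield."""
    rng = np.random.default_rng(0)
    for _ in range(8):
        t = int(rng.integers(5, 20))  # variable number of time steps per trial
        yield {
            'input_features': rng.normal(size=(t, 10)).astype(np.float32),
            'n_time_steps': np.int32(t),
        }

toy = tf.data.Dataset.from_generator(
    _toy_examples,
    output_signature={
        'input_features': tf.TensorSpec(shape=(None, 10), dtype=tf.float32),
        'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
    },
)

# padded_batch pads every example to the declared shapes, so each batch is [4, 20, 10]
toy = toy.padded_batch(
    batch_size=4,
    padded_shapes={'input_features': [20, 10], 'n_time_steps': []},
    padding_values={'input_features': 0.0, 'n_time_steps': 0},
    drop_remainder=True,
)

for batch in toy:
    print(batch['input_features'].shape)  # (4, 20, 10) on every iteration
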
class DataAugmentationTF:
"""
@@ -439,7 +496,10 @@ class DataAugmentationTF:
smooth_kernel_std: float = 2.0,
smooth_kernel_size: int = 100) -> tf.Tensor:
"""
Apply Gaussian smoothing along the time axis using a single vectorized TensorFlow operation.
This implementation uses depthwise_conv2d for good TPU performance, replacing the
inefficient Python for-loop that created 512 separate conv1d operations.
Args:
inputs: Input tensor [batch_size, time_steps, features]
@@ -456,52 +516,38 @@ class DataAugmentationTF:
valid_idx = np.argwhere(gauss_kernel > 0.01)
gauss_kernel = gauss_kernel[valid_idx].flatten()
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
# Convert to a TensorFlow tensor
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
kernel_size = tf.shape(gauss_kernel)[0]
# Get tensor dimensions
batch_size = tf.shape(inputs)[0]
time_steps = tf.shape(inputs)[1]
num_features = tf.shape(inputs)[2]
# Prepare the kernel for depthwise_conv2d
# Shape needed: [height, width, in_channels, channel_multiplier]
# Our case: [kernel_size, 1, num_features, 1], i.e. every input channel gets its own
# identical 1D Gaussian kernel
kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
kernel = tf.tile(kernel, [1, 1, num_features, 1])
# Prepare the input for conv2d
# Shape needed: [batch, height, width, channels]
# Our case: [batch_size, time_steps, 1, num_features] (add a dummy width dimension)
reshaped_inputs = tf.expand_dims(inputs, axis=2)
# A single depthwise convolution replaces the original per-feature Python for-loop
smoothed = tf.nn.depthwise_conv2d(
reshaped_inputs,
kernel,
strides=[1, 1, 1, 1],
padding='SAME'
)
# Remove the dummy width dimension to restore [batch_size, time_steps, features]
smoothed = tf.squeeze(smoothed, axis=2)
return smoothed
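
Editor's note: a small self-contained check (not part of the commit; toy sizes and the softmax stand-in kernel are assumptions) showing that the depthwise_conv2d formulation matches the per-channel conv1d loop it replaces:

import numpy as np
import tensorflow as tf

batch, time_steps, feats, ksize = 2, 50, 8, 11  # toy dimensions
x = tf.random.normal([batch, time_steps, feats])
k = tf.nn.softmax(tf.random.normal([ksize]))  # stand-in for the normalized Gaussian window

# Depthwise path: kernel [ksize, 1, feats, 1] applied to input [batch, time, 1, feats]
k_dw = tf.tile(tf.reshape(k, [ksize, 1, 1, 1]), [1, 1, feats, 1])
y_dw = tf.nn.depthwise_conv2d(tf.expand_dims(x, 2), k_dw,
                              strides=[1, 1, 1, 1], padding='SAME')
y_dw = tf.squeeze(y_dw, axis=2)

# Reference path: one conv1d per feature channel, as in the replaced loop
k_1d = tf.reshape(k, [ksize, 1, 1])
y_ref = tf.concat(
    [tf.nn.conv1d(x[:, :, i:i + 1], k_1d, stride=1, padding='SAME')
     for i in range(feats)],
    axis=-1,
)

print(np.allclose(y_dw.numpy(), y_ref.numpy(), atol=1e-5))  # expected: True
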
@@ -683,7 +729,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True) -> tf.data.Dataset:
"""
Create input function for TPU training with fixed-shape batching and data augmentation
Args:
dataset_tf: BrainToTextDatasetTF instance
@@ -691,30 +737,75 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
training: Whether this is for training (applies augmentations)
Returns:
tf.data.Dataset ready for TPU training with fixed shapes
"""
# Create an individual-example dataset instead of a pre-batched dataset
dataset = dataset_tf.create_individual_dataset()
def apply_transforms(example):
"""Apply data transformations to individual examples"""
features = example['input_features']
n_time_steps = example['n_time_steps']
# Apply transformations (transform_data expects batched tensors)
features, n_time_steps = DataAugmentationTF.transform_data(
tf.expand_dims(features, 0),  # Add batch dimension for transforms
tf.expand_dims(n_time_steps, 0),
transform_args,
training=training
)
# Remove the batch dimension and write the transformed values back
example['input_features'] = tf.squeeze(features, 0)
example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
return example
# Apply transformations to individual examples
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
# Define fixed shapes for TPU compatibility
# These should match the maximum expected sizes in your dataset
max_time_steps = 4096 # Adjust based on your data
max_phone_seq_len = 256 # Adjust based on your data
max_transcription_len = 512 # Adjust based on your data
n_features = 512 # Number of neural features
padded_shapes = {
'input_features': [max_time_steps, n_features],
'seq_class_ids': [max_phone_seq_len],
'n_time_steps': [], # Scalar
'phone_seq_lens': [], # Scalar
'day_indices': [], # Scalar
'transcriptions': [max_transcription_len],
'block_nums': [], # Scalar
'trial_nums': [] # Scalar
}
padding_values = {
'input_features': 0.0,
'seq_class_ids': 0,
'n_time_steps': 0,
'phone_seq_lens': 0,
'day_indices': 0,
'transcriptions': 0,
'block_nums': 0,
'trial_nums': 0
}
# Create fixed-shape batches with padding
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
padding_values=padding_values,
drop_remainder=True # Critical for TPU: ensures all batches have same size
)
# Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
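
Editor's note: a minimal sanity-check sketch of the resulting pipeline, assuming `dataset_tf` and `transform_args` have already been constructed as elsewhere in this file; the shapes checked are the fixed sizes defined above:

train_ds = create_input_fn(dataset_tf, transform_args, training=True)

for batch in train_ds.take(2):
    # With drop_remainder=True and the padded_shapes above, every batch has the same
    # static shape, which is what XLA compilation for TPU requires.
    assert batch['input_features'].shape == (dataset_tf.batch_size, 4096, 512)
    assert batch['seq_class_ids'].shape == (dataset_tf.batch_size, 256)
    assert batch['transcriptions'].shape == (dataset_tf.batch_size, 512)
    print({name: tuple(t.shape) for name, t in batch.items()})
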