f
This commit is contained in:
@@ -428,6 +428,63 @@ class BrainToTextDatasetTF:
|
||||
|
||||
return dataset
|
||||
|
||||
def create_individual_dataset(self) -> tf.data.Dataset:
    """
    Build an unbatched tf.data.Dataset that emits one trial per element.

    Trials are visited in the order given by ``self.batch_indices`` (batch by
    batch, then day by day, then trial by trial) and loaded through
    ``self._load_trial_data``. No batching is performed here, so a downstream
    ``padded_batch`` call can pad every element to a fixed shape for TPU use.

    Each element is a dict with the keys ``input_features`` (float32, rank 2),
    ``seq_class_ids`` and ``transcriptions`` (int32, rank 1), and the int32
    scalars ``n_time_steps``, ``phone_seq_lens``, ``day_indices``,
    ``block_nums`` and ``trial_nums``.

    Returns:
        The per-trial dataset. When ``self.split == 'train'`` it is shuffled
        with a buffer of ``min(1000, self.n_trials)`` elements.
    """

    def _iter_trials():
        """Yield one example dict per (day, trial) pair of every batch."""
        for idx in range(self.n_batches):
            for day, trials in self.batch_indices[idx].items():
                for trial in trials:
                    record = self._load_trial_data(day, trial)
                    # Cast to the dtypes declared in the output signature below.
                    yield {
                        'input_features': record['input_features'].astype(np.float32),
                        'seq_class_ids': record['seq_class_ids'].astype(np.int32),
                        'n_time_steps': np.int32(record['n_time_steps']),
                        'phone_seq_lens': np.int32(record['phone_seq_lens']),
                        'day_indices': np.int32(record['day_index']),
                        'transcriptions': record['transcription'].astype(np.int32),
                        'block_nums': np.int32(record['block_num']),
                        'trial_nums': np.int32(record['trial_num']),
                    }

    # Signature of a single example: variable-length tensors keep None dims.
    scalar_int = tf.TensorSpec(shape=(), dtype=tf.int32)
    vector_int = tf.TensorSpec(shape=(None,), dtype=tf.int32)
    element_spec = {
        'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
        'seq_class_ids': vector_int,
        'n_time_steps': scalar_int,
        'phone_seq_lens': scalar_int,
        'day_indices': scalar_int,
        'transcriptions': vector_int,
        'block_nums': scalar_int,
        'trial_nums': scalar_int,
    }

    per_trial = tf.data.Dataset.from_generator(
        _iter_trials,
        output_signature=element_spec,
    )

    # Only the training split is shuffled; other splits keep generator order.
    if self.split != 'train':
        return per_trial

    # Cap the shuffle buffer at 1000 elements to bound memory use.
    return per_trial.shuffle(buffer_size=min(1000, self.n_trials))
|
||||
|
||||
|
||||
class DataAugmentationTF:
|
||||
"""
|
||||
@@ -439,7 +496,10 @@ class DataAugmentationTF:
|
||||
smooth_kernel_std: float = 2.0,
|
||||
smooth_kernel_size: int = 100) -> tf.Tensor:
|
||||
"""
|
||||
Apply Gaussian smoothing along the time axis using TensorFlow operations
|
||||
Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.
|
||||
|
||||
This implementation uses depthwise_conv2d for optimal TPU performance,
|
||||
replacing the inefficient Python for-loop that created 512 separate conv1d operations.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor [batch_size, time_steps, features]
|
||||
@@ -456,52 +516,38 @@ class DataAugmentationTF:
|
||||
valid_idx = np.argwhere(gauss_kernel > 0.01)
|
||||
gauss_kernel = gauss_kernel[valid_idx].flatten()
|
||||
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
|
||||
|
||||
# Convert to TensorFlow tensor and reshape for conv1d
|
||||
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
|
||||
|
||||
# ========================= OPTIMIZED SOLUTION =========================
|
||||
# Get input dimensions
|
||||
num_features = tf.shape(inputs)[-1]
|
||||
kernel_size = tf.shape(gauss_kernel)[0]
|
||||
gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1]) # [kernel_size, in_channels, out_channels]
|
||||
|
||||
# Get tensor dimensions
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
time_steps = tf.shape(inputs)[1]
|
||||
num_features = tf.shape(inputs)[2]
|
||||
# Prepare kernel for depthwise_conv2d
|
||||
# Shape needed: [height, width, in_channels, channel_multiplier]
|
||||
# Our case: [kernel_size, 1, num_features, 1]
|
||||
# This means each input channel (num_features) has its own independent, identical 1D Gaussian kernel
|
||||
kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
|
||||
kernel = tf.tile(kernel, [1, 1, num_features, 1])
|
||||
|
||||
# Apply convolution to each feature channel separately
|
||||
smoothed_features = []
|
||||
# Prepare input for conv2d
|
||||
# Shape needed: [batch, height, width, channels]
|
||||
# Our case: [batch_size, time_steps, 1, num_features]
|
||||
# Add a dummy width dimension
|
||||
reshaped_inputs = tf.expand_dims(inputs, axis=2)
|
||||
|
||||
# Convert num_features to Python int for loop
|
||||
num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]
|
||||
# Execute depthwise convolution
|
||||
# This is a single, efficient operation replacing the original Python for-loop
|
||||
smoothed = tf.nn.depthwise_conv2d(
|
||||
reshaped_inputs,
|
||||
kernel,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME'
|
||||
)
|
||||
|
||||
if isinstance(num_features_py, tf.Tensor):
|
||||
# If dynamic, use tf.map_fn for dynamic number of features
|
||||
def smooth_single_feature(i):
|
||||
# Extract single feature channel: [batch_size, time_steps, 1]
|
||||
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
|
||||
# Apply 1D convolution
|
||||
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||
|
||||
# Use tf.map_fn for dynamic features
|
||||
indices = tf.range(num_features)
|
||||
smoothed_features_tensor = tf.map_fn(
|
||||
smooth_single_feature,
|
||||
indices,
|
||||
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||
)
|
||||
# Transpose to get [batch_size, time_steps, features]
|
||||
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
||||
smoothed = tf.squeeze(smoothed, axis=-1)
|
||||
else:
|
||||
# Static number of features - use loop
|
||||
for i in range(num_features_py):
|
||||
# Extract single feature channel: [batch_size, time_steps, 1]
|
||||
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
|
||||
# Apply 1D convolution
|
||||
smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||
smoothed_features.append(smoothed_channel)
|
||||
|
||||
# Concatenate all smoothed features
|
||||
smoothed = tf.concat(smoothed_features, axis=-1) # [batch_size, time_steps, features]
|
||||
# Remove the dummy width dimension to restore original shape
|
||||
smoothed = tf.squeeze(smoothed, axis=2)
|
||||
# ================================================================
|
||||
|
||||
return smoothed
|
||||
|
||||
@@ -683,7 +729,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
transform_args: Dict[str, Any],
|
||||
training: bool = True) -> tf.data.Dataset:
|
||||
"""
|
||||
Create input function for TPU training with data augmentation
|
||||
Create input function for TPU training with fixed-shape batching and data augmentation
|
||||
|
||||
Args:
|
||||
dataset_tf: BrainToTextDatasetTF instance
|
||||
@@ -691,30 +737,75 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
training: Whether this is for training (applies augmentations)
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset ready for TPU training
|
||||
tf.data.Dataset ready for TPU training with fixed shapes
|
||||
"""
|
||||
dataset = dataset_tf.create_dataset()
|
||||
|
||||
def apply_transforms(batch):
|
||||
"""Apply data transformations to a batch"""
|
||||
features = batch['input_features']
|
||||
n_time_steps = batch['n_time_steps']
|
||||
# Create individual example dataset instead of pre-batched dataset
|
||||
dataset = dataset_tf.create_individual_dataset()
|
||||
|
||||
def apply_transforms(example):
|
||||
"""Apply data transformations to individual examples"""
|
||||
features = example['input_features']
|
||||
n_time_steps = example['n_time_steps']
|
||||
|
||||
# Apply transformations
|
||||
features, n_time_steps = DataAugmentationTF.transform_data(
|
||||
features, n_time_steps, transform_args, training=training
|
||||
tf.expand_dims(features, 0), # Add batch dimension for transforms
|
||||
tf.expand_dims(n_time_steps, 0),
|
||||
transform_args,
|
||||
training=training
|
||||
)
|
||||
|
||||
# Update batch with transformed data
|
||||
batch['input_features'] = features
|
||||
batch['n_time_steps'] = n_time_steps
|
||||
# Remove batch dimension
|
||||
example['input_features'] = tf.squeeze(features, 0)
|
||||
example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
|
||||
|
||||
return batch
|
||||
return example
|
||||
|
||||
# Apply transformations
|
||||
# Apply transformations to individual examples
|
||||
dataset = dataset.map(
|
||||
apply_transforms,
|
||||
num_parallel_calls=tf.data.AUTOTUNE
|
||||
)
|
||||
|
||||
# Define fixed shapes for TPU compatibility
|
||||
# These should match the maximum expected sizes in your dataset
|
||||
max_time_steps = 4096 # Adjust based on your data
|
||||
max_phone_seq_len = 256 # Adjust based on your data
|
||||
max_transcription_len = 512 # Adjust based on your data
|
||||
n_features = 512 # Number of neural features
|
||||
|
||||
padded_shapes = {
|
||||
'input_features': [max_time_steps, n_features],
|
||||
'seq_class_ids': [max_phone_seq_len],
|
||||
'n_time_steps': [], # Scalar
|
||||
'phone_seq_lens': [], # Scalar
|
||||
'day_indices': [], # Scalar
|
||||
'transcriptions': [max_transcription_len],
|
||||
'block_nums': [], # Scalar
|
||||
'trial_nums': [] # Scalar
|
||||
}
|
||||
|
||||
padding_values = {
|
||||
'input_features': 0.0,
|
||||
'seq_class_ids': 0,
|
||||
'n_time_steps': 0,
|
||||
'phone_seq_lens': 0,
|
||||
'day_indices': 0,
|
||||
'transcriptions': 0,
|
||||
'block_nums': 0,
|
||||
'trial_nums': 0
|
||||
}
|
||||
|
||||
# Create fixed-shape batches with padding
|
||||
dataset = dataset.padded_batch(
|
||||
batch_size=dataset_tf.batch_size,
|
||||
padded_shapes=padded_shapes,
|
||||
padding_values=padding_values,
|
||||
drop_remainder=True # Critical for TPU: ensures all batches have same size
|
||||
)
|
||||
|
||||
# Prefetch for optimal performance
|
||||
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
return dataset
|
Reference in New Issue
Block a user