@@ -428,6 +428,63 @@ class BrainToTextDatasetTF:
         return dataset
 
+    def create_individual_dataset(self) -> tf.data.Dataset:
+        """
+        Create tf.data.Dataset that yields individual examples for TPU-optimized batching
+
+        This method creates individual examples instead of pre-batched data,
+        allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
+        """
+
+        def individual_example_generator():
+            """Generator that yields individual trial examples"""
+            for batch_idx in range(self.n_batches):
+                batch_index = self.batch_indices[batch_idx]
+
+                # Process each trial in the batch individually
+                for day in batch_index.keys():
+                    for trial in batch_index[day]:
+                        trial_data = self._load_trial_data(day, trial)
+
+                        # Yield individual example with all required fields
+                        example = {
+                            'input_features': trial_data['input_features'].astype(np.float32),
+                            'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
+                            'n_time_steps': np.int32(trial_data['n_time_steps']),
+                            'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
+                            'day_indices': np.int32(trial_data['day_index']),
+                            'transcriptions': trial_data['transcription'].astype(np.int32),
+                            'block_nums': np.int32(trial_data['block_num']),
+                            'trial_nums': np.int32(trial_data['trial_num'])
+                        }
+
+                        yield example
+
+        # Define output signature for individual examples
+        output_signature = {
+            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
+            'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
+            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
+            'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
+            'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
+            'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
+            'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
+            'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
+        }
+
+        # Create dataset from individual examples
+        dataset = tf.data.Dataset.from_generator(
+            individual_example_generator,
+            output_signature=output_signature
+        )
+
+        # Shuffle individual examples if training (more effective than batch-level shuffle)
+        if self.split == 'train':
+            # Use a reasonable shuffle buffer - not too large to avoid memory issues
+            shuffle_buffer = min(1000, self.n_trials)
+            dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+
+        return dataset
 
 
 class DataAugmentationTF:
     """
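The added method relies on `tf.data.Dataset.from_generator` with a dict `output_signature`, and on a downstream `padded_batch` to turn ragged per-trial examples into fixed-shape batches. A minimal, self-contained sketch of that pattern (toy generator and toy shapes, not the repository's classes):

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for individual_example_generator: variable-length "trials".
def toy_generator():
    for n_steps in (10, 17, 23):
        yield {
            'input_features': np.random.randn(n_steps, 4).astype(np.float32),
            'n_time_steps': np.int32(n_steps),
        }

toy_signature = {
    'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
    'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
}

ds = tf.data.Dataset.from_generator(toy_generator, output_signature=toy_signature)

# padded_batch pads every ragged field up to the given shape, so each batch
# comes out with the same static dimensions - the property TPUs need.
ds = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [32, 4], 'n_time_steps': []},
    drop_remainder=True,
)

for batch in ds:
    print(batch['input_features'].shape)  # (2, 32, 4) every time
```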
@@ -439,7 +496,10 @@ class DataAugmentationTF:
                       smooth_kernel_std: float = 2.0,
                       smooth_kernel_size: int = 100) -> tf.Tensor:
         """
-        Apply Gaussian smoothing along the time axis using TensorFlow operations
+        Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.
+
+        This implementation uses depthwise_conv2d for optimal TPU performance,
+        replacing the inefficient Python for-loop that created 512 separate conv1d operations.
 
         Args:
             inputs: Input tensor [batch_size, time_steps, features]
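The next hunk starts from an already-built `gauss_kernel`; for reference, a kernel with `smooth_kernel_std=2.0` and `smooth_kernel_size=100` can be built and then truncated exactly as the hunk does. This is a sketch assuming a plain NumPy Gaussian window (the actual kernel construction happens upstream of the hunk and is not shown in this diff):

```python
import numpy as np

smooth_kernel_std, smooth_kernel_size = 2.0, 100

# Symmetric Gaussian window (assumption: the real code may build it differently).
x = np.arange(smooth_kernel_size) - (smooth_kernel_size - 1) / 2.0
gauss_kernel = np.exp(-0.5 * (x / smooth_kernel_std) ** 2)

# Same truncation/normalization as in the hunk below: keep taps > 0.01, renormalize.
valid_idx = np.argwhere(gauss_kernel > 0.01)
gauss_kernel = gauss_kernel[valid_idx].flatten()
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)

print(len(gauss_kernel), gauss_kernel.sum())  # about 12 taps for std=2.0; sums to 1.0
```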
@@ -456,52 +516,38 @@ class DataAugmentationTF:
         valid_idx = np.argwhere(gauss_kernel > 0.01)
         gauss_kernel = gauss_kernel[valid_idx].flatten()
         gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
 
-        # Convert to TensorFlow tensor and reshape for conv1d
         gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
 
+        # ========================= OPTIMIZED SOLUTION =========================
+        # Get input dimensions
+        num_features = tf.shape(inputs)[-1]
         kernel_size = tf.shape(gauss_kernel)[0]
-        gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1])  # [kernel_size, in_channels, out_channels]
 
-        # Get tensor dimensions
-        batch_size = tf.shape(inputs)[0]
-        time_steps = tf.shape(inputs)[1]
-        num_features = tf.shape(inputs)[2]
-
-        # Apply convolution to each feature channel separately
-        smoothed_features = []
-
-        # Convert num_features to Python int for loop
-        num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]
-
-        if isinstance(num_features_py, tf.Tensor):
-            # If dynamic, use tf.map_fn for dynamic number of features
-            def smooth_single_feature(i):
-                # Extract single feature channel: [batch_size, time_steps, 1]
-                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
-                # Apply 1D convolution
-                return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
-
-            # Use tf.map_fn for dynamic features
-            indices = tf.range(num_features)
-            smoothed_features_tensor = tf.map_fn(
-                smooth_single_feature,
-                indices,
-                fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
-            )
-            # Transpose to get [batch_size, time_steps, features]
-            smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
-            smoothed = tf.squeeze(smoothed, axis=-1)
-        else:
-            # Static number of features - use loop
-            for i in range(num_features_py):
-                # Extract single feature channel: [batch_size, time_steps, 1]
-                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
-                # Apply 1D convolution
-                smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
-                smoothed_features.append(smoothed_channel)
-
-            # Concatenate all smoothed features
-            smoothed = tf.concat(smoothed_features, axis=-1)  # [batch_size, time_steps, features]
+        # Prepare kernel for depthwise_conv2d
+        # Shape needed: [height, width, in_channels, channel_multiplier]
+        # Our case: [kernel_size, 1, num_features, 1]
+        # This means each input channel (num_features) has its own independent, identical 1D Gaussian kernel
+        kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
+        kernel = tf.tile(kernel, [1, 1, num_features, 1])
+
+        # Prepare input for conv2d
+        # Shape needed: [batch, height, width, channels]
+        # Our case: [batch_size, time_steps, 1, num_features]
+        # Add a dummy width dimension
+        reshaped_inputs = tf.expand_dims(inputs, axis=2)
+
+        # Execute depthwise convolution
+        # This is a single, efficient operation replacing the original Python for-loop
+        smoothed = tf.nn.depthwise_conv2d(
+            reshaped_inputs,
+            kernel,
+            strides=[1, 1, 1, 1],
+            padding='SAME'
+        )
+
+        # Remove the dummy width dimension to restore original shape
+        smoothed = tf.squeeze(smoothed, axis=2)
+        # ================================================================
 
         return smoothed
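To see why the single `depthwise_conv2d` call is a drop-in replacement for the removed per-channel loop, here is a small self-contained equivalence check (random data and 8 channels instead of 512; a sketch, not part of the commit):

```python
import numpy as np
import tensorflow as tf

batch, time_steps, channels = 2, 50, 8
inputs = tf.constant(np.random.randn(batch, time_steps, channels), dtype=tf.float32)
gauss_kernel = tf.constant([0.1, 0.2, 0.4, 0.2, 0.1], dtype=tf.float32)

# Old approach: one conv1d per channel, concatenated back together.
k1d = tf.reshape(gauss_kernel, [-1, 1, 1])
per_channel = [
    tf.nn.conv1d(inputs[:, :, i:i + 1], k1d, stride=1, padding='SAME')
    for i in range(channels)
]
loop_result = tf.concat(per_channel, axis=-1)

# New approach: single depthwise_conv2d over a dummy width dimension.
k2d = tf.tile(tf.reshape(gauss_kernel, [-1, 1, 1, 1]), [1, 1, channels, 1])
depthwise_result = tf.squeeze(
    tf.nn.depthwise_conv2d(tf.expand_dims(inputs, 2), k2d,
                           strides=[1, 1, 1, 1], padding='SAME'),
    axis=2,
)

print(np.allclose(loop_result.numpy(), depthwise_result.numpy(), atol=1e-5))  # True
```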
@@ -683,7 +729,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                      transform_args: Dict[str, Any],
                      training: bool = True) -> tf.data.Dataset:
     """
-    Create input function for TPU training with data augmentation
+    Create input function for TPU training with fixed-shape batching and data augmentation
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
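For context, a hedged sketch of how this input function might be wired into TPU training. The resolver/strategy setup, the dataset variable names, and the contents of `transform_args` are assumptions for illustration only; none of this is part of the diff:

```python
import tensorflow as tf

# Assumed TPU setup; the actual training script is not shown in this commit.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

transform_args = {}  # augmentation config consumed by DataAugmentationTF.transform_data

# train_dataset_tf / val_dataset_tf are hypothetical BrainToTextDatasetTF instances.
train_ds = create_input_fn(train_dataset_tf, transform_args, training=True)
val_ds = create_input_fn(val_dataset_tf, transform_args, training=False)

# Distribute the already-batched, fixed-shape datasets across TPU replicas.
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
```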
@@ -691,30 +737,75 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         training: Whether this is for training (applies augmentations)
 
     Returns:
-        tf.data.Dataset ready for TPU training
+        tf.data.Dataset ready for TPU training with fixed shapes
     """
-    dataset = dataset_tf.create_dataset()
 
-    def apply_transforms(batch):
-        """Apply data transformations to a batch"""
-        features = batch['input_features']
-        n_time_steps = batch['n_time_steps']
+    # Create individual example dataset instead of pre-batched dataset
+    dataset = dataset_tf.create_individual_dataset()
+
+    def apply_transforms(example):
+        """Apply data transformations to individual examples"""
+        features = example['input_features']
+        n_time_steps = example['n_time_steps']
 
         # Apply transformations
         features, n_time_steps = DataAugmentationTF.transform_data(
-            features, n_time_steps, transform_args, training=training
+            tf.expand_dims(features, 0),  # Add batch dimension for transforms
+            tf.expand_dims(n_time_steps, 0),
+            transform_args,
+            training=training
         )
 
-        # Update batch with transformed data
-        batch['input_features'] = features
-        batch['n_time_steps'] = n_time_steps
+        # Remove batch dimension
+        example['input_features'] = tf.squeeze(features, 0)
+        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
 
-        return batch
+        return example
 
-    # Apply transformations
+    # Apply transformations to individual examples
     dataset = dataset.map(
         apply_transforms,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
+    # Define fixed shapes for TPU compatibility
+    # These should match the maximum expected sizes in your dataset
+    max_time_steps = 4096  # Adjust based on your data
+    max_phone_seq_len = 256  # Adjust based on your data
+    max_transcription_len = 512  # Adjust based on your data
+    n_features = 512  # Number of neural features
+
+    padded_shapes = {
+        'input_features': [max_time_steps, n_features],
+        'seq_class_ids': [max_phone_seq_len],
+        'n_time_steps': [],  # Scalar
+        'phone_seq_lens': [],  # Scalar
+        'day_indices': [],  # Scalar
+        'transcriptions': [max_transcription_len],
+        'block_nums': [],  # Scalar
+        'trial_nums': []  # Scalar
+    }
+
+    padding_values = {
+        'input_features': 0.0,
+        'seq_class_ids': 0,
+        'n_time_steps': 0,
+        'phone_seq_lens': 0,
+        'day_indices': 0,
+        'transcriptions': 0,
+        'block_nums': 0,
+        'trial_nums': 0
+    }
+
+    # Create fixed-shape batches with padding
+    dataset = dataset.padded_batch(
+        batch_size=dataset_tf.batch_size,
+        padded_shapes=padded_shapes,
+        padding_values=padding_values,
+        drop_remainder=True  # Critical for TPU: ensures all batches have same size
+    )
+
+    # Prefetch for optimal performance
+    dataset = dataset.prefetch(tf.data.AUTOTUNE)
+
     return dataset
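One detail worth calling out in the hunk above: with `drop_remainder=True` the batch dimension becomes statically known, which is what allows XLA to compile a single fixed-shape program for every training step. A small sketch showing the difference in `element_spec` (toy shapes, not the values used above):

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.zeros([10, 7], tf.float32))

with_remainder = ds.batch(4)                          # last batch would hold only 2 elements
without_remainder = ds.batch(4, drop_remainder=True)  # every batch is exactly 4

print(with_remainder.element_spec)     # TensorSpec(shape=(None, 7), dtype=tf.float32)
print(without_remainder.element_spec)  # TensorSpec(shape=(4, 7), dtype=tf.float32)
```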