diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index 7e81f7f..ca4ceea 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -428,6 +428,63 @@ class BrainToTextDatasetTF: return dataset + def create_individual_dataset(self) -> tf.data.Dataset: + """ + Create tf.data.Dataset that yields individual examples for TPU-optimized batching + + This method creates individual examples instead of pre-batched data, + allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU. + """ + + def individual_example_generator(): + """Generator that yields individual trial examples""" + for batch_idx in range(self.n_batches): + batch_index = self.batch_indices[batch_idx] + + # Process each trial in the batch individually + for day in batch_index.keys(): + for trial in batch_index[day]: + trial_data = self._load_trial_data(day, trial) + + # Yield individual example with all required fields + example = { + 'input_features': trial_data['input_features'].astype(np.float32), + 'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32), + 'n_time_steps': np.int32(trial_data['n_time_steps']), + 'phone_seq_lens': np.int32(trial_data['phone_seq_lens']), + 'day_indices': np.int32(trial_data['day_index']), + 'transcriptions': trial_data['transcription'].astype(np.int32), + 'block_nums': np.int32(trial_data['block_num']), + 'trial_nums': np.int32(trial_data['trial_num']) + } + yield example + + # Define output signature for individual examples + output_signature = { + 'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32), + 'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32), + 'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32), + 'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32), + 'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32), + 'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32) + } + + # Create dataset from individual examples + dataset = tf.data.Dataset.from_generator( + individual_example_generator, + output_signature=output_signature + ) + + # Shuffle individual examples if training (more effective than batch-level shuffle) + if self.split == 'train': + # Use a reasonable shuffle buffer - not too large to avoid memory issues + shuffle_buffer = min(1000, self.n_trials) + dataset = dataset.shuffle(buffer_size=shuffle_buffer) + + return dataset + class DataAugmentationTF: """ @@ -439,7 +496,10 @@ class DataAugmentationTF: smooth_kernel_std: float = 2.0, smooth_kernel_size: int = 100) -> tf.Tensor: """ - Apply Gaussian smoothing along the time axis using TensorFlow operations + Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation. + + This implementation uses depthwise_conv2d for optimal TPU performance, + replacing the inefficient Python for-loop that created 512 separate conv1d operations. Args: inputs: Input tensor [batch_size, time_steps, features] @@ -456,52 +516,38 @@ class DataAugmentationTF: valid_idx = np.argwhere(gauss_kernel > 0.01) gauss_kernel = gauss_kernel[valid_idx].flatten() gauss_kernel = gauss_kernel / np.sum(gauss_kernel) - - # Convert to TensorFlow tensor and reshape for conv1d gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32) + + # ========================= OPTIMIZED SOLUTION ========================= + # Get input dimensions + num_features = tf.shape(inputs)[-1] kernel_size = tf.shape(gauss_kernel)[0] - gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1]) # [kernel_size, in_channels, out_channels] - # Get tensor dimensions - batch_size = tf.shape(inputs)[0] - time_steps = tf.shape(inputs)[1] - num_features = tf.shape(inputs)[2] + # Prepare kernel for depthwise_conv2d + # Shape needed: [height, width, in_channels, channel_multiplier] + # Our case: [kernel_size, 1, num_features, 1] + # This means each input channel (num_features) has its own independent, identical 1D Gaussian kernel + kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1]) + kernel = tf.tile(kernel, [1, 1, num_features, 1]) - # Apply convolution to each feature channel separately - smoothed_features = [] + # Prepare input for conv2d + # Shape needed: [batch, height, width, channels] + # Our case: [batch_size, time_steps, 1, num_features] + # Add a dummy width dimension + reshaped_inputs = tf.expand_dims(inputs, axis=2) - # Convert num_features to Python int for loop - num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1] + # Execute depthwise convolution + # This is a single, efficient operation replacing the original Python for-loop + smoothed = tf.nn.depthwise_conv2d( + reshaped_inputs, + kernel, + strides=[1, 1, 1, 1], + padding='SAME' + ) - if isinstance(num_features_py, tf.Tensor): - # If dynamic, use tf.map_fn for dynamic number of features - def smooth_single_feature(i): - # Extract single feature channel: [batch_size, time_steps, 1] - feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1) - # Apply 1D convolution - return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME') - - # Use tf.map_fn for dynamic features - indices = tf.range(num_features) - smoothed_features_tensor = tf.map_fn( - smooth_single_feature, - indices, - fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32) - ) - # Transpose to get [batch_size, time_steps, features] - smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3]) - smoothed = tf.squeeze(smoothed, axis=-1) - else: - # Static number of features - use loop - for i in range(num_features_py): - # Extract single feature channel: [batch_size, time_steps, 1] - feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1) - # Apply 1D convolution - smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME') - smoothed_features.append(smoothed_channel) - - # Concatenate all smoothed features - smoothed = tf.concat(smoothed_features, axis=-1) # [batch_size, time_steps, features] + # Remove the dummy width dimension to restore original shape + smoothed = tf.squeeze(smoothed, axis=2) + # ================================================================ return smoothed @@ -683,7 +729,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, transform_args: Dict[str, Any], training: bool = True) -> tf.data.Dataset: """ - Create input function for TPU training with data augmentation + Create input function for TPU training with fixed-shape batching and data augmentation Args: dataset_tf: BrainToTextDatasetTF instance @@ -691,30 +737,75 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, training: Whether this is for training (applies augmentations) Returns: - tf.data.Dataset ready for TPU training + tf.data.Dataset ready for TPU training with fixed shapes """ - dataset = dataset_tf.create_dataset() - def apply_transforms(batch): - """Apply data transformations to a batch""" - features = batch['input_features'] - n_time_steps = batch['n_time_steps'] + # Create individual example dataset instead of pre-batched dataset + dataset = dataset_tf.create_individual_dataset() + + def apply_transforms(example): + """Apply data transformations to individual examples""" + features = example['input_features'] + n_time_steps = example['n_time_steps'] # Apply transformations features, n_time_steps = DataAugmentationTF.transform_data( - features, n_time_steps, transform_args, training=training + tf.expand_dims(features, 0), # Add batch dimension for transforms + tf.expand_dims(n_time_steps, 0), + transform_args, + training=training ) - # Update batch with transformed data - batch['input_features'] = features - batch['n_time_steps'] = n_time_steps + # Remove batch dimension + example['input_features'] = tf.squeeze(features, 0) + example['n_time_steps'] = tf.squeeze(n_time_steps, 0) - return batch + return example - # Apply transformations + # Apply transformations to individual examples dataset = dataset.map( apply_transforms, num_parallel_calls=tf.data.AUTOTUNE ) + # Define fixed shapes for TPU compatibility + # These should match the maximum expected sizes in your dataset + max_time_steps = 4096 # Adjust based on your data + max_phone_seq_len = 256 # Adjust based on your data + max_transcription_len = 512 # Adjust based on your data + n_features = 512 # Number of neural features + + padded_shapes = { + 'input_features': [max_time_steps, n_features], + 'seq_class_ids': [max_phone_seq_len], + 'n_time_steps': [], # Scalar + 'phone_seq_lens': [], # Scalar + 'day_indices': [], # Scalar + 'transcriptions': [max_transcription_len], + 'block_nums': [], # Scalar + 'trial_nums': [] # Scalar + } + + padding_values = { + 'input_features': 0.0, + 'seq_class_ids': 0, + 'n_time_steps': 0, + 'phone_seq_lens': 0, + 'day_indices': 0, + 'transcriptions': 0, + 'block_nums': 0, + 'trial_nums': 0 + } + + # Create fixed-shape batches with padding + dataset = dataset.padded_batch( + batch_size=dataset_tf.batch_size, + padded_shapes=padded_shapes, + padding_values=padding_values, + drop_remainder=True # Critical for TPU: ensures all batches have same size + ) + + # Prefetch for optimal performance + dataset = dataset.prefetch(tf.data.AUTOTUNE) + return dataset \ No newline at end of file