import os
import tensorflow as tf
import h5py
import numpy as np
import math
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d


class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8.

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data.

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as values
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility
            must_include_days: Days that must be included in every batch
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads
        """

        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError(
                    f'len(must_include_days) ({len(must_include_days)}) must be '
                    f'<= days_per_batch ({days_per_batch})'
                )

            # Map negative day indices to their positive equivalents
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling"""
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

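    # Illustrative note (added for clarity, not part of the original logic): each
    # entry of the returned mapping pairs a batch index with a {day: [trial, ...]}
    # dict. For example, with batch_size=4 and days_per_batch=2, a hypothetical
    # batch_indices[0] might look like {0: [12, 87], 3: [5, 41]}. The day keys are
    # reused by the batch generator to look up 'session_path' and to emit the
    # per-trial day indices consumed by day-specific model layers.
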
    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring all trials are seen once"""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)

                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data from an HDF5 file"""
        try:
            session_path = self.trial_indices[day]['session_path']

            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Keep features as float32 here; casting to bfloat16 for TPU
                # efficiency is left to the downstream TensorFlow pipeline.
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

            return trial_data

        except Exception as e:
            print(f'Error loading trial {trial} from day {day}: {e}')
            # Return dummy data to maintain batch structure
            return {
                'input_features': np.zeros((100, 512), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

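    # Expected HDF5 layout, inferred from the reads above and documented here for
    # convenience (names come straight from this method): each session file holds
    # groups named 'trial_0000', 'trial_0001', ... containing the datasets
    # 'input_features' [time_steps, channels], 'seq_class_ids', 'transcription',
    # plus scalar attrs 'n_time_steps', 'seq_len', 'block_num', and 'trial_num'.
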
    def _create_batch_generator(self):
        """Generator function that yields individual batches"""
        for batch_idx in range(self.n_batches):
            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Load data for each day in the batch
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trial_data = self._load_trial_data(day, trial)

                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            # Pad sequences to create uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch

    def create_dataset(self) -> tf.data.Dataset:
        """Create optimized tf.data.Dataset for TPU training"""

        # Define output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # Apply TPU-optimized transformations
        if self.split == 'train':
            # For training, add shuffling
            dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset

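# Usage sketch (added for illustration; the batch counts and sizes below are
# hypothetical placeholders, and a real training script may wire this up
# differently). It only uses the constructor arguments documented above.
def _example_build_train_batches(trial_indices: Dict[int, Dict[str, Any]]) -> tf.data.Dataset:
    """Build a small training dataset from precomputed trial indices and peek at one batch."""
    dataset_tf = BrainToTextDatasetTF(
        trial_indices=trial_indices,
        n_batches=10,        # number of randomly sampled training batches
        split='train',
        batch_size=32,
        days_per_batch=2,
        random_seed=0,
    )
    dataset = dataset_tf.create_dataset()
    for batch in dataset.take(1):
        # Each batch is a dict of dense tensors, e.g. [batch, time, channels] features
        print(batch['input_features'].shape, batch['n_time_steps'].numpy())
    return dataset

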
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using TensorFlow operations.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of the Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with the same shape as the input
        """
        # Create the Gaussian kernel with numpy, trimming near-zero taps
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)

        # Convert to a TensorFlow tensor and reshape for conv1d
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
        kernel_size = tf.shape(gauss_kernel)[0]
        gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1])  # [kernel_size, in_channels, out_channels]

        # Get tensor dimensions
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]
        num_features = tf.shape(inputs)[2]

        # Apply the convolution to each feature channel separately
        smoothed_features = []

        # Use the static feature count when it is known at graph-construction time
        num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]

        if isinstance(num_features_py, tf.Tensor):
            # Dynamic number of features: map over channel indices
            def smooth_single_feature(i):
                # Extract single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply 1D convolution
                return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')

            indices = tf.range(num_features)
            smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
            # Transpose to get [batch_size, time_steps, features]
            smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
            smoothed = tf.squeeze(smoothed, axis=-1)
        else:
            # Static number of features: use a Python loop
            for i in range(num_features_py):
                # Extract single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply 1D convolution
                smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
                smoothed_features.append(smoothed_channel)

            # Concatenate all smoothed features
            smoothed = tf.concat(smoothed_features, axis=-1)  # [batch_size, time_steps, features]

        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']
                # Create identity matrices for each batch
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                # Add noise to create warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                # Apply transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - apply to all samples in batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps

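
# Illustrative augmentation config (added as an example; the specific values are
# placeholders, not settings from any particular training run). The keys match
# exactly what DataAugmentationTF.transform_data reads via transform_args.get(...).
_EXAMPLE_TRANSFORM_ARGS: Dict[str, Any] = {
    'static_gain_std': 0.0,
    'white_noise_std': 0.1,
    'constant_offset_std': 0.05,
    'random_walk_std': 0.0,
    'random_walk_axis': 1,
    'random_cut': 0,
    'smooth_data': True,
    'smooth_kernel_std': 2.0,
    'smooth_kernel_size': 100,
}


def _example_augment_batch() -> tf.Tensor:
    """Apply the example config to a small random batch and return the augmented features."""
    features = tf.random.normal([4, 200, 16])   # [batch, time, channels]
    n_time_steps = tf.fill([4], 200)            # every sample fully valid
    features, n_time_steps = DataAugmentationTF.transform_data(
        features, n_time_steps, _EXAMPLE_TRANSFORM_ARGS, training=True
    )
    return features

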
def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split the trials in file_paths into train and test sets.

    Args:
        file_paths: List of HDF5 session file paths
        test_percentage: Fraction of trials (0 to 1) to reserve for testing
        seed: Random seed for reproducibility (-1 disables seeding)
        bad_trials_dict: Dictionary of trials to exclude, keyed by session and block

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Gather the usable trials for each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'

                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check if the trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split each day's trials into train and test
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate the number of test trials (at least one)
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices are used for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials

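
# Usage sketch (added for illustration; the session paths below are hypothetical
# placeholders that follow the t15.20xx naming this function looks for):
def _example_split_sessions() -> Tuple[Dict, Dict]:
    """Split two hypothetical session files 90/10 and report the per-day counts."""
    file_paths = [
        '/data/hdf5/t15.2023.08.11/data_train.hdf5',
        '/data/hdf5/t15.2023.08.13/data_train.hdf5',
    ]
    train_trials, test_trials = train_test_split_indices(
        file_paths, test_percentage=0.1, seed=0
    )
    for day in train_trials:
        print(day, len(train_trials[day]['trials']), len(test_trials[day]['trials']))
    return train_trials, test_trials

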
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with data augmentation

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)

    Returns:
        tf.data.Dataset ready for TPU training
    """
    dataset = dataset_tf.create_dataset()

    def apply_transforms(batch):
        """Apply data transformations to a batch"""
        features = batch['input_features']
        n_time_steps = batch['n_time_steps']

        # Apply transformations
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, transform_args, training=training
        )

        # Update batch with transformed data
        batch['input_features'] = features
        batch['n_time_steps'] = n_time_steps

        return batch

    # Apply transformations
    dataset = dataset.map(
        apply_transforms,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return dataset
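

# End-to-end sketch (added for illustration): wire the pieces above together.
# The batch counts and augmentation settings are placeholder assumptions, not
# values taken from a real training configuration.
def _example_build_pipeline(file_paths: List[str]) -> tf.data.Dataset:
    """Split sessions, build the training dataset, and attach augmentation."""
    train_trials, _ = train_test_split_indices(file_paths, test_percentage=0.1, seed=0)

    dataset_tf = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=100,
        split='train',
        batch_size=32,
        days_per_batch=2,
        random_seed=0,
    )

    transform_args = {'white_noise_std': 0.1, 'smooth_data': True}
    train_ds = create_input_fn(dataset_tf, transform_args, training=True)

    # On TPU, this dataset would typically be consumed through tf.distribute,
    # e.g. strategy.experimental_distribute_dataset(train_ds).
    return train_ds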