import os
import sys
import tensorflow as tf
import h5py
import numpy as np
import math
import logging
import time
import random
import multiprocessing
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
from concurrent.futures import ThreadPoolExecutor, as_completed

class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8.

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
            self,
            trial_indices: Dict[int, Dict[str, Any]],
            n_batches: Optional[int],
            split: str = 'train',
            batch_size: int = 64,
            days_per_batch: int = 1,
            random_seed: int = -1,
            must_include_days: Optional[List[int]] = None,
            feature_subset: Optional[List[int]] = None,
            prefetch_buffer: int = tf.data.AUTOTUNE,
            num_parallel_calls: int = tf.data.AUTOTUNE,
            cache_data: bool = True,
            preload_all_data: bool = False
    ):
"""
Initialize TensorFlow dataset for brain - to - text data
Args :
trial_indices : Dictionary with day numbers as keys and trial info as values
n_batches : Number of training batches to create ( None for validation )
split : ' train ' or ' test '
batch_size : Number of examples per batch
days_per_batch : Number of unique days per batch ( for day - specific layers )
random_seed : Random seed for reproducibility
must_include_days : Days that must be included in every batch
feature_subset : Subset of neural features to use
prefetch_buffer : Buffer size for prefetching
num_parallel_calls : Parallel processing threads
2025-10-16 17:14:06 +08:00
cache_data : Whether to cache loaded data in memory
preload_all_data : Whether to preload all data at initialization
2025-10-15 16:55:52 +08:00
"""
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        self.cache_data = cache_data
        self.preload_all_data = preload_all_data

        # Initialize data cache
        self.data_cache = {} if cache_data else None

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError('must_include_days must be <= days_per_batch')
            # Map negative indices to actual day indices
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
        self.n_batches = len(self.batch_indices)

        # Preload data if requested (speeds up the first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")

        # ==================== Feature dimension auto-detection (robust version) ====================
        # Explicitly compute and store the feature dimension to avoid shape mismatches in padded_batch.
        if self.feature_subset:
            self.feature_dim = len(self.feature_subset)
            print(f"✅ Using feature subset dimension: {self.feature_dim}")
        else:
            # Robustly infer the actual feature dimension from the data: iterate over
            # days until the first day that contains at least one loadable trial.
            detected_dim = None
            for day in self.trial_indices:
                # Skip days whose trial list is empty
                if self.trial_indices[day]['trials']:
                    first_valid_trial = self.trial_indices[day]['trials'][0]
                    try:
                        # The list is non-empty, so try to load its first trial
                        first_sample = self._load_single_trial_data(day, first_valid_trial)
                        detected_dim = first_sample['input_features'].shape[1]
                        print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
                        break  # Dimension detected; stop searching
                    except Exception as e:
                        # If this trial fails to load, try the next day
                        print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} "
                              f"for dimension check. Error: {e}")
                        continue

            if detected_dim is not None:
                self.feature_dim = detected_dim
            else:
                # No valid trial was found in any day: warn and fall back to a default
                print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. "
                      f"No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
                self.feature_dim = 512  # Last-resort fallback
        # ==================== End of feature dimension detection ====================

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling."""
        batch_indices = {}

        # Precompute the days that are not forced into every batch
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match the exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring every trial is seen exactly once."""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)
                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices

    def _preload_all_data(self):
        """Preload all trial data into the in-memory cache using parallel I/O."""
        # Use CPU cores efficiently for parallel I/O, capped to avoid overwhelming the disk
        max_workers = min(multiprocessing.cpu_count(), 32)

        # Collect all trials to load
        trials_to_load = []
        for day in self.trial_indices:
            for trial in self.trial_indices[day]['trials']:
                trials_to_load.append((day, trial))

        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        # Parallel loading using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks
            future_to_trial = {
                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update the cache
            loaded_count = 0
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    trial_data = future.result()
                    cache_key = f"{day}_{trial}"
                    self.data_cache[cache_key] = trial_data
                    loaded_count += 1

                    # Progress indicator every 100 trials
                    if loaded_count % 100 == 0:
                        print(f"  Loaded {loaded_count}/{len(trials_to_load)} trials...")
                except Exception as e:
                    print(f"Warning: Failed to load trial {day}_{trial}: {e}")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data - optimized version for parallel loading."""
        try:
            session_path = self.trial_indices[day]['session_path']
            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Convert to float32 for TF compatibility
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }
            return trial_data
        except Exception as e:
            # Log the error and return dummy data with a consistent feature dimension
            logging.warning(f"Failed to load trial {day}_{trial}. Error: {e}. Returning dummy data.")
            # Use self.feature_dim when it exists; fall back to 512 if this is called
            # before feature-dimension detection has completed in __init__
            feature_dim = getattr(self, 'feature_dim', 512)
            return {
                'input_features': np.zeros((100, feature_dim), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data from the cache or from the HDF5 file."""
        # Check the cache first if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

        # Load from disk if not in the cache
        trial_data = self._load_single_trial_data(day, trial)

        # Cache the loaded data if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            self.data_cache[cache_key] = trial_data

        return trial_data

    def _create_batch_generator(self):
        """
        Generator that yields individual batches with optimized loading.

        ⚠️ DEPRECATED: Use create_input_fn() instead for better performance and
        TPU compatibility. This method will be removed in a future version.
        """
        import warnings
        warnings.warn(
            "_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
            DeprecationWarning,
            stacklevel=2
        )

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()

            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = []
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trials_to_load.append((day, trial))

            # Use parallel loading if data is not preloaded and the batch has several trials
            if not self.preload_all_data and len(trials_to_load) > 4:
                # Parallel loading for faster I/O
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    future_to_trial = {
                        executor.submit(self._load_trial_data, day, trial): (day, trial)
                        for day, trial in trials_to_load
                    }

                    # Collect results keyed by (day, trial)
                    trial_results = {}
                    for future in future_to_trial:
                        day, trial = future_to_trial[future]
                        trial_results[(day, trial)] = future.result()

                # Append data in the original order
                for day, trial in trials_to_load:
                    trial_data = trial_results[(day, trial)]
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])
            else:
                # Sequential loading (fast when data is cached or the batch is small)
                for day, trial in trials_to_load:
                    trial_data = self._load_trial_data(day, trial)
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            data_loading_time = time.time() - batch_start_time

            # Timing diagnostics for the first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in "
                      f"{data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create a uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad phoneme sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create the final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch

    def create_dataset(self) -> tf.data.Dataset:
        """
        Create an optimized tf.data.Dataset for TPU training.

        ⚠️ DEPRECATED: Use create_input_fn() instead for better performance and
        TPU compatibility. This method will be removed in a future version.
        """
        import warnings
        warnings.warn(
            "create_dataset is deprecated. Use create_input_fn() instead for better performance.",
            DeprecationWarning,
            stacklevel=2
        )

        # Define the output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create the dataset from the batch generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # 🚨 No dataset-level shuffle here: _create_batch_index_train() already performs
        # random sampling per batch, and shuffling whole batches again would explode
        # memory (e.g. 1000 batches × 256 trials = 256,000 trials held in the buffer).
        # if self.split == 'train':
        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset

    def create_individual_dataset(self) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset that yields individual examples with I/O optimization.

        The generator groups trial loading by session file, drastically reducing the
        number of file open/close operations from N_trials to N_sessions, which is
        ideal for slow disk I/O.
        """
        def individual_example_generator():
            """Generator that groups reads by file to minimize disk I/O."""
            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
            #    The iteration order here determines the rough read order;
            #    _create_batch_index_train has already randomized the batches.
            all_trials_to_load = []
            for batch_idx in sorted(self.batch_indices.keys()):
                batch_index = self.batch_indices[batch_idx]
                for day in batch_index.keys():
                    for trial in batch_index[day]:
                        all_trials_to_load.append((day, trial))

            # 2. Group the trial list by 'day' (i.e. by file).
            #    key=lambda x: x[0] groups on the first tuple element (the day).
            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
                session_path = self.trial_indices[day]['session_path']

                # 3. Open the HDF5 file exactly once per group (per file)
                try:
                    with h5py.File(session_path, 'r') as f:
                        # 4. While the file is open, read every trial needed from it
                        for current_day, current_trial in group:
                            try:
                                # Read directly from the open file handle 'f'
                                # instead of calling the per-trial loader
                                g = f[f'trial_{current_trial:04d}']
                                input_features = g['input_features'][:]
                                if self.feature_subset:
                                    input_features = input_features[:, self.feature_subset]

                                example = {
                                    'input_features': input_features.astype(np.float32),
                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
                                    'day_indices': np.int32(current_day),
                                    'transcriptions': g['transcription'][:].astype(np.int32),
                                    'block_nums': np.int32(g.attrs['block_num']),
                                    'trial_nums': np.int32(g.attrs['trial_num'])
                                }
                                yield example
                            except KeyError:
                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
                                continue
                except (IOError, FileNotFoundError) as e:
                    logging.error(f"Could not open or read HDF5 file: {session_path}. "
                                  f"Error: {e}. Skipping all trials for this day.")
                    continue

        # Define the output signature for individual examples
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
        }

        # Create the dataset from individual examples
        dataset = tf.data.Dataset.from_generator(
            individual_example_generator,
            output_signature=output_signature
        )

        # Shuffle individual examples when training (more effective than batch-level shuffling)
        if self.split == 'train':
            # The buffer can be reasonably large now that I/O is grouped per file
            shuffle_buffer = min(2048, self.n_trials)
            dataset = dataset.shuffle(buffer_size=shuffle_buffer)

        return dataset
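

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the training pipeline).
# The trial_indices layout mirrors what train_test_split_indices() in this
# module produces; the session path and trial numbers are hypothetical. The
# helper is defined but never called at import time.
# ---------------------------------------------------------------------------
def _example_iterate_individual_examples():
    """Sketch: build a small training dataset and inspect a few raw examples."""
    trial_indices = {
        0: {'trials': [0, 1, 2, 3], 'session_path': 'data/t15.2023.08.11/data_train.hdf5'},  # hypothetical
    }

    dataset_tf = BrainToTextDatasetTF(
        trial_indices=trial_indices,
        n_batches=4,
        split='train',
        batch_size=2,
        days_per_batch=1,
        random_seed=0
    )

    # Each element is one un-padded trial, so shapes vary from example to example
    for example in dataset_tf.create_individual_dataset().take(2):
        print(example['input_features'].shape, int(example['n_time_steps']))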


class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8.
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.

        This implementation uses depthwise_conv2d for optimal TPU performance,
        replacing an inefficient Python for-loop that created 512 separate conv1d operations.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of the Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with the same shape as the input
        """
        # Create the Gaussian kernel using numpy (computed once per call)
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)

        # ========================= Vectorized smoothing =========================
        # Get input dimensions
        num_features = tf.shape(inputs)[-1]
        kernel_size = tf.shape(gauss_kernel)[0]

        # Prepare the kernel for depthwise_conv2d.
        # Required shape: [height, width, in_channels, channel_multiplier]
        # Here: [kernel_size, 1, num_features, 1], i.e. every input channel gets
        # its own independent copy of the same 1D Gaussian kernel.
        kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
        kernel = tf.tile(kernel, [1, 1, num_features, 1])

        # Prepare the input for conv2d.
        # Required shape: [batch, height, width, channels]
        # Here: [batch_size, time_steps, 1, num_features] (add a dummy width dimension)
        reshaped_inputs = tf.expand_dims(inputs, axis=2)

        # Run the depthwise convolution: one efficient op replacing the per-channel loop
        smoothed = tf.nn.depthwise_conv2d(
            reshaped_inputs,
            kernel,
            strides=[1, 1, 1, 1],
            padding='SAME'
        )

        # Remove the dummy width dimension to restore the original shape
        smoothed = tf.squeeze(smoothed, axis=2)
        # =========================================================================

        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU.

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']
                # Create an identity matrix for each batch element
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                # Add noise to create the warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                # Apply the transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - applied to all samples in the batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
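

# ---------------------------------------------------------------------------
# Illustrative check of the augmentation pipeline (a minimal sketch). The
# transform_args values below are placeholder assumptions for demonstration
# only; real values come from the training config. Defined but never called
# at import time.
# ---------------------------------------------------------------------------
def _example_augmentation_roundtrip():
    """Sketch: run transform_data on random features and confirm output shapes."""
    features = tf.random.normal([4, 200, 512])   # [batch, time, channels]
    n_time_steps = tf.fill([4], 200)

    transform_args = {
        'static_gain_std': 0.05,
        'white_noise_std': 0.1,
        'constant_offset_std': 0.05,
        'random_walk_std': 0.0,
        'random_cut': 3,
        'smooth_data': True,
        'smooth_kernel_std': 2.0,
        'smooth_kernel_size': 100,
    }

    out, out_steps = DataAugmentationTF.transform_data(
        features, n_time_steps, transform_args, training=True)

    # random_cut removes a random number of leading frames, so the time dimension may shrink
    print(out.shape, out_steps.numpy())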


def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split the trials in file_paths into train and test sets.

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Fraction of trials to hold out for testing
        seed: Random seed for reproducibility
        bad_trials_dict: Dictionary of trials to exclude

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set the seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get the trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        try:
            # Handle both Windows and Unix path separators
            path_parts = path.replace('\\', '/').split('/')
            session_candidates = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))]
            if not session_candidates:
                logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
                continue
            session = session_candidates[0]
        except Exception as e:
            logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
            continue

        good_trial_indices = []
        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'
                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check whether this trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split the trials into train and test sets
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate the number of test trials
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # The remaining indices are used for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials
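

# ---------------------------------------------------------------------------
# A minimal sketch of the bad_trials_dict layout that the exclusion check in
# train_test_split_indices() expects: {session_name: {str(block_num): [trial_num, ...]}}.
# The session name, block numbers, and trial numbers below are hypothetical.
# ---------------------------------------------------------------------------
_EXAMPLE_BAD_TRIALS_DICT = {
    't15.2023.08.11': {
        '2': [5, 11],   # exclude trial_num 5 and 11 from block 2
        '7': [3],       # exclude trial_num 3 from block 7
    }
}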


def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
    """
    Analyze dataset shapes in parallel to determine the maximum dimensions for padded_batch,
    using multiple CPU cores and the dataset's caching mechanism.

    Args:
        dataset_tf: Dataset instance to analyze
        sample_size: Number of samples to analyze (set to -1 to analyze all data)

    Returns:
        Dictionary with the maximum dimensions
    """
    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
    start_time = time.time()

    # 1. Collect all (day, trial) pairs to analyze, without duplicates
    all_trials = []
    unique_trials = set()
    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
        batch = dataset_tf.batch_indices[batch_idx]
        for day, trials in batch.items():
            for trial in trials:
                if (day, trial) not in unique_trials:
                    unique_trials.add((day, trial))
                    all_trials.append((day, trial))

    # 2. Subsample the list if requested
    if 0 < sample_size < len(all_trials):
        # Fix the seed so the subsample is reproducible
        random.seed(42)
        trials_to_check = random.sample(all_trials, sample_size)
    else:
        trials_to_check = all_trials

    total_trials_to_analyze = len(trials_to_check)
    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")

    # Helper executed by each worker thread
    def analyze_single_trial(day_trial_pair):
        """Load and analyze a single trial, returning its shapes."""
        day, trial = day_trial_pair
        try:
            # Reuse the dataset's loading and caching logic
            trial_data = dataset_tf._load_trial_data(day, trial)

            # Read the relevant sizes directly from the loaded data
            time_steps = int(trial_data['n_time_steps'])
            phone_seq_len = int(trial_data['phone_seq_lens'])

            # The transcription field may be an array or a scalar
            transcription_data = trial_data['transcription']
            if hasattr(transcription_data, '__len__'):
                transcription_len = len(transcription_data)
            else:
                transcription_len = 1  # Scalar transcription: length 1

            return (time_steps, phone_seq_len, transcription_len)
        except Exception as e:
            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
            return None  # None marks a failed trial

    # 3. Analyze trials in parallel with a ThreadPoolExecutor.
    # Scale the worker count with the CPU count, with a reasonable upper limit.
    cpu_count = os.cpu_count() or 4  # Fall back to 4 if cpu_count() returns None
    max_workers = min(32, cpu_count, len(trials_to_check))

    local_max_shapes = []

    print(f"🔧 Using {max_workers} parallel workers for analysis...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all analysis tasks
        future_to_trial = {
            executor.submit(analyze_single_trial, trial_pair): trial_pair
            for trial_pair in trials_to_check
        }

        count = 0
        progress_interval = max(1, total_trials_to_analyze // 20)  # ~20 progress updates
        for future in as_completed(future_to_trial):
            result = future.result()
            if result:
                local_max_shapes.append(result)

            count += 1
            if count % progress_interval == 0 or count == total_trials_to_analyze:
                elapsed = time.time() - start_time
                rate = count / elapsed if elapsed > 0 else 0
                print(f"📈 Analyzed {count}/{total_trials_to_analyze} trials... "
                      f"({count / total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")

    # 4. Aggregate the per-trial results
    if not local_max_shapes:
        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")

    # Turn [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
    unzipped_shapes = list(zip(*local_max_shapes))

    original_max_shapes = {
        'max_time_steps': int(np.max(unzipped_shapes[0])),
        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
        'max_transcription_len': int(np.max(unzipped_shapes[2])),
        'n_features': dataset_tf.feature_dim
    }

    # 5. Add a safety margin that depends on how much of the data was analyzed
    if sample_size == -1:
        # Full-dataset analysis: only a tiny buffer for possible rounding effects
        safety_margin = 1.02  # 2% buffer
        margin_reason = "minimal buffer for full dataset analysis"
    else:
        # Sampled analysis: a larger buffer to cover extremes that were not sampled
        safety_margin = 1.3  # 30% buffer
        margin_reason = f"larger buffer due to sampling only {sample_size} trials"

    final_max_shapes = {
        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
        'n_features': original_max_shapes['n_features']
    }

    analysis_time = time.time() - start_time
    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100

    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
    print(f"📏 Final max shapes (with {int((safety_margin - 1) * 100)}% safety margin - {margin_reason}):")
    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
    print(f"   Number of features: {final_max_shapes['n_features']}")

    return final_max_shapes
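

# ---------------------------------------------------------------------------
# Minimal sketch of how the analysis result is typically consumed (the sample
# size here is an assumption; create_input_fn below performs the same steps
# automatically when use_static_shapes=True). Defined but never called at import.
# ---------------------------------------------------------------------------
def _example_report_max_shapes(dataset_tf: BrainToTextDatasetTF):
    """Sketch: print the padded dimensions a static-shape pipeline would use."""
    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=200)
    print(f"padded time steps:    {max_shapes['max_time_steps']}")
    print(f"padded phone seq len: {max_shapes['max_phone_seq_len']}")
    print(f"padded transcription: {max_shapes['max_transcription_len']}")
    print(f"feature dimension:    {max_shapes['n_features']}")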


# Utility functions for the TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True,
                    cache_path: Optional[str] = None,
                    use_static_shapes: bool = True) -> tf.data.Dataset:
    """
    Create the input pipeline for TPU training with configurable shape handling.

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
        use_static_shapes: If True, use pre-computed static shapes for XLA compatibility

    Returns:
        tf.data.Dataset ready for TPU training
    """
    # Step 1: Create the individual-example dataset
    dataset = dataset_tf.create_individual_dataset()

    # Step 2: Cache raw examples BEFORE any augmentation or batching
    split_name = "training" if training else "validation"
    if cache_path:
        dataset = dataset.cache(cache_path)
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
        print(f"⚠️ First access will be slow while the {split_name} cache is built; subsequent access will be much faster")
    else:
        dataset = dataset.cache()
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")

    # Step 3: Batch examples with shape handling optimized for TPU
    if use_static_shapes:
        print("🔧 Using STATIC shapes for XLA compatibility")
        # Analyze the dataset to get the maximum shapes
        print("📊 Analyzing dataset for maximum shapes...")
        max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # Analyze ALL data for maximum accuracy

        # Use static shapes based on the analysis
        padded_shapes = {
            'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
            'seq_class_ids': (max_shapes['max_phone_seq_len'],),
            'n_time_steps': (),
            'phone_seq_lens': (),
            'day_indices': (),
            'transcriptions': (max_shapes['max_transcription_len'],),
            'block_nums': (),
            'trial_nums': ()
        }
        print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
              f"phone_len={max_shapes['max_phone_seq_len']}, "
              f"transcription_len={max_shapes['max_transcription_len']}")
    else:
        print("🔧 Using DYNAMIC shapes (may cause XLA compilation issues)")
        # Dynamic shapes: every batch is padded to its own maximum lengths
        padded_shapes = {
            'input_features': (None, dataset_tf.feature_dim),
            'seq_class_ids': (None,),
            'n_time_steps': (),
            'phone_seq_lens': (),
            'day_indices': (),
            'transcriptions': (None,),
            'block_nums': (),
            'trial_nums': ()
        }

    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")

    # Define the padding value for each field
    padding_values = {
        'input_features': 0.0,
        'seq_class_ids': 0,
        'n_time_steps': 0,
        'phone_seq_lens': 0,
        'day_indices': 0,
        'transcriptions': 0,
        'block_nums': 0,
        'trial_nums': 0
    }

    # Create padded batches; the padding lengths come from the shapes chosen above
    dataset = dataset.padded_batch(
        batch_size=dataset_tf.batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=True  # Critical for TPU: ensures all batches have the same size
    )

    # Step 4: Apply data augmentation to whole batches (after batching)
    def apply_batch_transforms(batch):
        """Apply data transformations to an entire batch."""
        features = batch['input_features']
        n_time_steps = batch['n_time_steps']

        # Apply the transformations to the whole batch at once
        features, n_time_steps = DataAugmentationTF.transform_data(
            features,  # Already has a batch dimension
            n_time_steps,
            transform_args,
            training=training
        )

        # Update the batch with the transformed data
        batch['input_features'] = features
        batch['n_time_steps'] = n_time_steps
        return batch

    # Apply batch transforms only during training
    if training:
        dataset = dataset.map(
            apply_batch_transforms,
            num_parallel_calls=tf.data.AUTOTUNE
        )
        print("✅ Batch augmentation enabled for training")
    else:
        print("✅ No augmentation for validation")

    # Step 5: Prefetch for optimal performance
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
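

# ---------------------------------------------------------------------------
# End-to-end usage sketch. Assumptions: the HDF5 path is hypothetical (it must
# exist and contain trial_XXXX groups for this to run), and the transform_args
# values are placeholders; real values come from the training config. Guarded
# so nothing runs on import.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    file_paths = ['data/t15.2023.08.11/data_train.hdf5']  # hypothetical path
    train_trials, val_trials = train_test_split_indices(file_paths, test_percentage=0.1, seed=0)

    train_dataset_tf = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=100,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=0
    )

    transform_args = {
        'static_gain_std': 0.0,
        'white_noise_std': 0.1,
        'constant_offset_std': 0.05,
        'random_walk_std': 0.0,
        'random_cut': 0,
        'smooth_data': True,
        'smooth_kernel_std': 2.0,
        'smooth_kernel_size': 100,
    }

    train_ds = create_input_fn(train_dataset_tf, transform_args, training=True)

    # Inspect the shape of a single padded, augmented batch
    for batch in train_ds.take(1):
        print({k: v.shape for k, v in batch.items()})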