Enhance dataset caching logic for training and validation sets with improved messaging

Zchen
2025-10-19 10:31:31 +08:00
parent 558be0ad98
commit cfd9653da9


@@ -778,20 +778,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     dataset = dataset_tf.create_individual_dataset()
 
     # ========================= I/O OPTIMIZATION SOLUTION =========================
-    # Cache after data loading and before random ops (e.g. data augmentation)
-    if training:
-        # For training, cache to a disk file or in memory
-        if cache_path:
-            dataset = dataset.cache(cache_path)
-            print(f"🗃️ Dataset caching enabled: {cache_path}")
-            print("   First epoch will be slow while building cache, subsequent epochs will be much faster")
-        else:
-            # With enough RAM, an in-memory cache is faster,
-            # but a file cache is recommended for large datasets
-            dataset = dataset.cache()
-            print("🗃️ Dataset caching enabled: in-memory cache")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-    # (The validation set usually does not need caching, since it only runs once)
+    # Cache both the training and validation sets, because:
+    # 1. The training set is traversed in full every epoch
+    # 2. The validation set is evaluated every 200 steps plus early-stopping checks, so it is used frequently
+    if cache_path:
+        dataset = dataset.cache(cache_path)
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
+    else:
+        # No cache path given: fall back to an in-memory cache.
+        # For large datasets, pass an explicit on-disk cache path instead.
+        dataset = dataset.cache()
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     # ================================================================
 
     def apply_transforms(example):
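
For reference, a minimal self-contained sketch of the tf.data caching pattern the new code applies. The build_dataset helper and the toy Dataset.range pipeline below are illustrative stand-ins for the repository's create_individual_dataset pipeline, not code from this commit:

import os
import tempfile
from typing import Optional

import tensorflow as tf


def build_dataset(training: bool, cache_path: Optional[str] = None) -> tf.data.Dataset:
    # Stand-in for the repository's real input pipeline.
    dataset = tf.data.Dataset.range(10).map(lambda x: x * 2)

    split_name = "training" if training else "validation"
    if cache_path:
        # On-disk cache: the first full pass writes <cache_path>.index and
        # <cache_path>.data-* files; later passes replay them instead of
        # recomputing the upstream pipeline.
        dataset = dataset.cache(cache_path)
        print(f"{split_name.capitalize()} dataset caching enabled: {cache_path}")
    else:
        # In-memory cache: fast, but only safe when the split fits in RAM.
        dataset = dataset.cache()
        print(f"{split_name.capitalize()} dataset caching enabled: in-memory cache")
    return dataset


if __name__ == "__main__":
    cache_file = os.path.join(tempfile.mkdtemp(), "val_cache")
    val_ds = build_dataset(training=False, cache_path=cache_file)
    # The first iteration is slow while the cache is built; the second replays it.
    for _ in range(2):
        print(list(val_ds.as_numpy_iterator()))

Note that tf.data finalizes a file cache only after the dataset has been iterated through in its entirety at least once, and if cache files already exist it loads from them and ignores the transformations upstream of cache(). Stale <cache_path>.* files therefore need to be deleted by hand whenever the preprocessing changes.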