Enhance dataset caching logic for training and validation sets with improved messaging

Author: Zchen
Date:   2025-10-19 10:31:31 +08:00
parent 558be0ad98
commit cfd9653da9


@@ -778,20 +778,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     dataset = dataset_tf.create_individual_dataset()
     # ========================= I/O OPTIMIZATION SOLUTION =========================
     # Cache after data loading and before random operations (e.g. data augmentation)
-    if training:
-        # For training, cache to a disk file or to memory
-        if cache_path:
-            dataset = dataset.cache(cache_path)
-            print(f"🗃️ Dataset caching enabled: {cache_path}")
-            print("   First epoch will be slow while building cache, subsequent epochs will be much faster")
-        else:
-            # If memory is large enough, in-memory caching is faster,
-            # but file caching is recommended for large datasets
-            dataset = dataset.cache()
-            print("🗃️ Dataset caching enabled: in-memory cache")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-    # (The validation set usually does not need caching, since it only runs once)
+    # Cache both the training and the validation set, because:
+    # 1. The training set is fully traversed every epoch
+    # 2. The validation set is evaluated every 200 steps plus early-stopping checks, so it is used frequently
+    if cache_path:
+        dataset = dataset.cache(cache_path)
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
+        print(f"⚠️ First access will be slow while building the {split_name} cache, subsequent accesses will be much faster")
+    else:
+        # If no cache path is specified, default to an in-memory cache.
+        # For large datasets, pass an explicit disk cache path at the call site.
+        dataset = dataset.cache()
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
+        print(f"⚠️ First access will be slow while building the {split_name} cache, subsequent accesses will be much faster")
     # ================================================================
     def apply_transforms(example):
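
For reference, a minimal, self-contained sketch of the tf.data caching behavior this hunk relies on. The toy range() pipeline and the /tmp cache path are illustrative stand-ins, not part of the repository; only dataset.cache() itself mirrors the committed code.

import tensorflow as tf

# Toy pipeline standing in for dataset_tf.create_individual_dataset().
dataset = tf.data.Dataset.range(5).map(lambda x: x * 2)

# File-backed cache: TF materializes elements into <path>.index / <path>.data-*
# files on the first full pass, then replays them on later passes.
# cached = dataset.cache("/tmp/b2t_cache")  # hypothetical path

# In-memory cache (no argument): the first full iteration runs the upstream
# map() and fills the cache; subsequent iterations read cached elements and
# skip the upstream work entirely.
cached = dataset.cache()

for epoch in range(2):
    # Epoch 0 builds the cache (slow); epoch 1 replays it (fast).
    print(f"epoch {epoch}:", list(cached.as_numpy_iterator()))

Placing .cache() before augmentation, as the hunk above does, matters: anything cached is frozen after the first pass, so random transforms applied downstream of the cache remain random on every epoch.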