Enhance dataset caching logic for training and validation sets with improved messaging
@@ -778,20 +778,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     dataset = dataset_tf.create_individual_dataset()

     # ========================= I/O OPTIMIZATION SOLUTION =========================
-    # Cache after the data is loaded and before any random ops (such as augmentation)
-    if training:
-        # For training, cache to a file on disk or in memory
-        if cache_path:
-            dataset = dataset.cache(cache_path)
-            print(f"🗃️ Dataset caching enabled: {cache_path}")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-        else:
-            # With enough memory, caching in memory is faster,
-            # but file-based caching is recommended for large datasets
-            dataset = dataset.cache()
-            print("🗃️ Dataset caching enabled: in-memory cache")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-    # (the validation set usually does not need caching, since it only runs once)
+    # Cache both the training and validation sets, because:
+    # 1. Training set: fully traversed on every epoch
+    # 2. Validation set: validated every 200 steps plus early-stopping checks, so it is used frequently
+    if cache_path:
+        dataset = dataset.cache(cache_path)
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
+    else:
+        # If no cache path is given, fall back to in-memory caching.
+        # For large datasets, pass an explicit on-disk cache path instead.
+        dataset = dataset.cache()
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     # ================================================================

     def apply_transforms(example):