diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index ffd943a..2898764 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -778,20 +778,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     dataset = dataset_tf.create_individual_dataset()

     # ========================= I/O OPTIMIZATION SOLUTION =========================
-    # Cache after data loading and before randomized ops (e.g. data augmentation)
-    if training:
-        # For training, cache to a disk file or to memory
-        if cache_path:
-            dataset = dataset.cache(cache_path)
-            print(f"🗃️ Dataset caching enabled: {cache_path}")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-        else:
-            # With enough RAM, caching in memory is faster,
-            # but file caching is recommended for large datasets
-            dataset = dataset.cache()
-            print("🗃️ Dataset caching enabled: in-memory cache")
-            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
-    # (The validation set usually needs no caching since it runs only once)
+    # Cache both the training and validation sets, because:
+    # 1. Training set: fully traversed on every epoch
+    # 2. Validation set: run every 200 steps plus on early-stopping checks, so it is read frequently
+    if cache_path:
+        dataset = dataset.cache(cache_path)
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
+    else:
+        # If no cache path is given, default to an in-memory cache;
+        # for large datasets, pass an explicit disk cache path at the call site
+        dataset = dataset.cache()
+        split_name = "training" if training else "validation"
+        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
+        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     # ================================================================

     def apply_transforms(example):
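
Note on the placement: the hunk relies on where cache() sits in the tf.data pipeline, since cache() memoizes everything upstream of it and random operations must stay downstream so they re-run each epoch. Below is a minimal, self-contained sketch of that pattern; build_pipeline and augment are hypothetical names for illustration only, not the repository's create_input_fn.

    # Minimal sketch of the caching pattern in this hunk (hypothetical names).
    import tensorflow as tf

    def augment(x):
        # Hypothetical random transform. It must run AFTER cache() so each
        # epoch re-samples the noise instead of replaying a frozen cached copy.
        return x + tf.random.normal(tf.shape(x), stddev=0.1)

    def build_pipeline(dataset: tf.data.Dataset,
                       cache_path: str = "",
                       training: bool = True) -> tf.data.Dataset:
        # cache() memoizes everything upstream of it: the first pass writes
        # the cache (slow), later passes read it back (fast). With a filename
        # it caches to disk; with no argument it caches in host memory.
        dataset = dataset.cache(cache_path) if cache_path else dataset.cache()
        if training:
            # Random ops stay downstream of the cache so they re-run per epoch.
            dataset = dataset.shuffle(buffer_size=1024)
            dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        return dataset.prefetch(tf.data.AUTOTUNE)

    # Both splits are cached, matching the behavior introduced by this diff.
    train_ds = build_pipeline(tf.data.Dataset.from_tensor_slices(tf.random.normal([8, 4])),
                              cache_path="/tmp/train_cache", training=True)
    val_ds = build_pipeline(tf.data.Dataset.from_tensor_slices(tf.random.normal([8, 4])),
                            training=False)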