# Source: b2txt25/model_training_nnn_tpu/rnn_args_simple.yaml
# Commit: 56fa336af0 (Zchen, "tpu"), 2025-10-15 14:26:11 +08:00
# Original file: 94 lines, 2.1 KiB, YAML

# Simplified TPU training configuration - compiles faster
# Model architecture, scaled down from the larger values noted inline.
# NOTE(review): nesting was restored from a copy whose indentation had been
# stripped; confirm against the config loader that `input_network` belongs
# under `model`.
model:
  n_input_features: 512
  n_units: 384  # reduced from 768 to 384
  rnn_dropout: 0.2  # reduced dropout
  rnn_trainable: true
  n_layers: 3  # reduced from 5 to 3 layers
  patch_size: 8  # reduced from 14 to 8
  patch_stride: 4
  input_network:
    n_input_layers: 1
    input_layer_sizes:
      - 512
    input_trainable: true
    input_layer_dropout: 0.1  # reduced dropout
# Run mode and mixed precision.
mode: train
use_amp: true

# TPU distributed-training settings.
use_tpu: true
num_tpu_cores: 8
# Gradient accumulation raised to compensate for the small batch size.
gradient_accumulation_steps: 4

# Output locations and checkpointing.
output_dir: trained_models/simple_rnn
checkpoint_dir: trained_models/simple_rnn/checkpoint
init_from_checkpoint: false
save_best_checkpoint: true
save_val_metrics: true

# Training length: start by testing with 1000 batches.
num_training_batches: 1000

# Learning-rate schedule.
lr_scheduler_type: cosine
lr_max: 0.003  # learning rate lowered slightly
lr_min: 0.0001
lr_decay_steps: 1000
lr_warmup_steps: 100

# Same schedule fields with a `_day` suffix.
# NOTE(review): presumably these apply to day-specific parameters; confirm
# against the trainer.
lr_max_day: 0.003
lr_min_day: 0.0001
lr_decay_steps_day: 1000
lr_warmup_steps_day: 100

# Optimizer hyperparameters and regularization.
beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0

seed: 10
grad_norm_clip_value: 5  # reduced gradient clipping

# Logging and validation cadence.
batches_per_train_log: 50  # more frequent logging
batches_per_val_step: 200
log_individual_day_val_PER: true
# Adversarial training is switched off for quick testing.
adversarial:
  enabled: false  # disabled for now
# Dataset and data-loading settings.
# NOTE(review): nesting was restored from a copy whose indentation had been
# stripped; confirm the key layout against the config loader.
dataset:
  data_transforms:
    white_noise_std: 0.5  # reduced data augmentation
    constant_offset_std: 0.1
    random_walk_std: 0.0
    random_walk_axis: -1
    static_gain_std: 0.0
    random_cut: 1  # reduced random cut
    smooth_kernel_size: 50  # reduced smoothing kernel size
    smooth_data: true
    smooth_kernel_std: 1

  neural_dim: 512
  batch_size: 16  # batch size reduced from 32 to 16
  n_classes: 41
  max_seq_elements: 300  # reduced sequence length
  days_per_batch: 2  # fewer days per batch
  # Dataset-level seed; kept inside `dataset` so it does not collide with the
  # top-level `seed` key.
  seed: 1
  num_dataloader_workers: 0
  loader_shuffle: false
  test_percentage: 0.1

  dataset_dir: ../data/hdf5_data_final
  # Only a subset of sessions is used, for quick testing.
  sessions:
    - t15.2023.08.11
    - t15.2023.08.13
    - t15.2023.08.18
    - t15.2023.08.20
  # NOTE(review): four entries, presumably one per session in the same order
  # as `sessions`; confirm against the data loader.
  dataset_probability_val:
    - 0
    - 1
    - 1
    - 1