---
# Simplified TPU training config - tuned for faster compilation.
model:
  n_input_features: 512
  n_units: 384  # reduced from 768 to 384
  rnn_dropout: 0.2  # reduced dropout
  rnn_trainable: true
  n_layers: 3  # reduced from 5 to 3 layers
  patch_size: 8  # reduced from 14 to 8
  patch_stride: 4

  input_network:
    n_input_layers: 1
    input_layer_sizes:
      - 512
    input_trainable: true
    input_layer_dropout: 0.1  # reduced dropout

mode: train
use_amp: true

# TPU distributed training settings
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 4  # more accumulation to offset the smaller batch

output_dir: trained_models/simple_rnn
checkpoint_dir: trained_models/simple_rnn/checkpoint
init_from_checkpoint: false
save_best_checkpoint: true
save_val_metrics: true

num_training_batches: 1000  # short trial run: 1000 batches first
lr_scheduler_type: cosine
lr_max: 0.003  # slightly lower learning rate
lr_min: 0.0001
lr_decay_steps: 1000
lr_warmup_steps: 100

lr_max_day: 0.003
lr_min_day: 0.0001
lr_decay_steps_day: 1000
lr_warmup_steps_day: 100

beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 5  # lower gradient-norm clip threshold (tighter clipping)

batches_per_train_log: 50  # log more frequently
batches_per_val_step: 200
log_individual_day_val_PER: true

# Adversarial training disabled for quick testing
adversarial:
  enabled: false  # disabled for now

dataset:
  data_transforms:
    white_noise_std: 0.5  # reduced data augmentation
    constant_offset_std: 0.1
    random_walk_std: 0.0
    random_walk_axis: -1
    static_gain_std: 0.0
    random_cut: 1  # reduced random cut
    smooth_kernel_size: 50  # smaller smoothing kernel
    smooth_data: true
    smooth_kernel_std: 1

  neural_dim: 512
  batch_size: 16  # reduced from 32 to 16
  n_classes: 41
  max_seq_elements: 300  # shorter max sequence length
  days_per_batch: 2  # fewer days per batch
  seed: 1
  num_dataloader_workers: 0
  loader_shuffle: false
  test_percentage: 0.1
  dataset_dir: ../data/hdf5_data_final

  # Only a subset of sessions, for quick testing
  sessions:
    - t15.2023.08.11
    - t15.2023.08.13
    - t15.2023.08.18
    - t15.2023.08.20

  # NOTE(review): presumably one entry per session in `sessions` (same
  # order); keep the two lists the same length when editing either.
  dataset_probability_val:
    - 0
    - 1
    - 1
    - 1