diff --git a/model_training_nnn_tpu/rnn_trainer.py b/model_training_nnn_tpu/rnn_trainer.py index bf04428..e4bf28e 100644 --- a/model_training_nnn_tpu/rnn_trainer.py +++ b/model_training_nnn_tpu/rnn_trainer.py @@ -581,11 +581,11 @@ class BrainToTextDecoder_Trainer: val_steps_since_improvement = 0 - # training params + # training params save_best_checkpoint = self.args.get('save_best_checkpoint', True) early_stopping = self.args.get('early_stopping', True) - early_stopping_val_steps = self.args['early_stopping_val_steps'] + early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20) train_start_time = time.time()