diff --git a/keras.py b/keras.py index 129544e5d37329bb978b91534e5d961829335d67..ce2e8951a6434138a013e4344b10467085a1adc4 100644 --- a/keras.py +++ b/keras.py @@ -3,6 +3,7 @@ import gc from collections import OrderedDict, defaultdict import fnmatch import math +import datetime import numpy as np import tensorflow as tf @@ -836,7 +837,8 @@ class ReduceLROnPlateau(ScaleOnPlateau): class EarlyStopping(PatientTracker): - def __init__(self, restore_best_weights=False, verbose=0, do_stop=None, **kwargs): + def __init__(self, runtime=None, restore_best_weights=False, verbose=0, do_stop=None, **kwargs): + stop_time = datetime.datetime.now() + datetime.timedelta(**runtime) if runtime else None pin(locals(), kwargs) super(EarlyStopping, self).__init__(**kwargs) @@ -854,7 +856,9 @@ class EarlyStopping(PatientTracker): if self.restore_best_weights: self.best_weights = self.model.get_weights() self.best_epoch = epoch - elif action == "good": + elif action == "good" or ( + (self.stop_time is not None) and (datetime.datetime.now() > self.stop_time) + ): self.stopped_epoch = epoch self.model.stop_training = True if self.restore_best_weights and self.best_weights is not None: