diff --git a/keras.py b/keras.py index 5500168a33cf720342b6f66bbf9e20eb2ce23ea8..7b45f41512ec3038eeaacaf77067567fe11b8454 100644 --- a/keras.py +++ b/keras.py @@ -488,7 +488,7 @@ class PatientTracker(BestTracker): class ScaleOnPlateau(PatientTracker): - def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, **kwargs): + def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs): pin(locals(), kwargs) super(ScaleOnPlateau, self).__init__(**kwargs) @@ -509,7 +509,10 @@ class ScaleOnPlateau(PatientTracker): if self.log_key is not None: logs.setdefault(self.log_key, cur) if self.patient_step(epoch, logs) == "good": - new = cur * self.factor + if self.linearly: + new = cur + self.factor + else: + new = cur * self.factor if self.min is not None: new = max(new, self.min) if self.max is not None: