Skip to content
Snippets Groups Projects
Commit aa7bd212 authored by jan.middendorf@rwth-aachen.de's avatar jan.middendorf@rwth-aachen.de
Browse files

Added linear scaling to ScaleOnPlateau

parent 8ca8ec34
No related branches found
No related tags found
No related merge requests found
......@@ -488,7 +488,7 @@ class PatientTracker(BestTracker):
class ScaleOnPlateau(PatientTracker):
def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, **kwargs):
def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs):
pin(locals(), kwargs)
super(ScaleOnPlateau, self).__init__(**kwargs)
......@@ -509,7 +509,10 @@ class ScaleOnPlateau(PatientTracker):
if self.log_key is not None:
logs.setdefault(self.log_key, cur)
if self.patient_step(epoch, logs) == "good":
new = cur * self.factor
if self.linearly:
new = cur + self.factor
else:
new = cur * self.factor
if self.min is not None:
new = max(new, self.min)
if self.max is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment