Skip to content
Snippets Groups Projects
Commit aa7bd212 authored by jan.middendorf@rwth-aachen.de's avatar jan.middendorf@rwth-aachen.de
Browse files

Added linear scaling to ScaleOnPlateau

parent 8ca8ec34
No related branches found
No related tags found
No related merge requests found
...@@ -488,7 +488,7 @@ class PatientTracker(BestTracker): ...@@ -488,7 +488,7 @@ class PatientTracker(BestTracker):
class ScaleOnPlateau(PatientTracker): class ScaleOnPlateau(PatientTracker):
def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, **kwargs): def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs):
pin(locals(), kwargs) pin(locals(), kwargs)
super(ScaleOnPlateau, self).__init__(**kwargs) super(ScaleOnPlateau, self).__init__(**kwargs)
...@@ -509,6 +509,9 @@ class ScaleOnPlateau(PatientTracker): ...@@ -509,6 +509,9 @@ class ScaleOnPlateau(PatientTracker):
if self.log_key is not None: if self.log_key is not None:
logs.setdefault(self.log_key, cur) logs.setdefault(self.log_key, cur)
if self.patient_step(epoch, logs) == "good": if self.patient_step(epoch, logs) == "good":
if self.linearly:
new = cur + self.factor
else:
new = cur * self.factor new = cur * self.factor
if self.min is not None: if self.min is not None:
new = max(new, self.min) new = max(new, self.min)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment