From ccaa2da8f36eadb52c95be1e99d7c1c4edca5dbc Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Wed, 2 Sep 2020 10:17:03 +0200 Subject: [PATCH] [keras] black --- keras.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/keras.py b/keras.py index 1911c2c..5248cbe 100644 --- a/keras.py +++ b/keras.py @@ -473,7 +473,10 @@ class PlotMulticlass(TFSummaryCallback): imgs["roc_curve"] = figure_to_image( figure_roc_curve( - truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight + truth, + prediction, + class_names=self.class_names, + sample_weight=self.sample_weight, ) ) imgs["roc_curve_log"] = figure_to_image( @@ -511,7 +514,10 @@ class PlotMulticlass(TFSummaryCallback): ) imgs["node_activation"] = figure_to_image( figure_node_activations( - prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight + prediction, + truth, + class_names=self.class_names, + sample_weight=self.sample_weight, ) ) imgs["node_activation_disjoint_unweighted"] = figure_to_image( @@ -601,7 +607,12 @@ class CheckpointModel(tf.keras.callbacks.Callback): class BestTracker(tf.keras.callbacks.Callback): def __init__( - self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None + self, + monitor="val_loss", + mode="auto", + min_delta=0, + min_delta_rel=0, + baseline=None, ): pin(locals()) self.reset() @@ -684,7 +695,15 @@ class PatientTracker(BestTracker): class ScaleOnPlateau(PatientTracker): def __init__( - self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs + self, + target, + factor, + min=None, + max=None, + verbose=0, + log_key=None, + linearly=False, + **kwargs, ): pin(locals(), kwargs) super(ScaleOnPlateau, self).__init__(**kwargs) @@ -722,7 +741,12 @@ class ScaleOnPlateau(PatientTracker): if self.verbose > 0: print( "\nEpoch %05d: %s scaling %s to %s." - % (epoch + 1, self.__class__.__name__, self.log_key or self.target, new) + % ( + epoch + 1, + self.__class__.__name__, + self.log_key or self.target, + new, + ) ) @@ -1186,6 +1210,7 @@ def feature_importance(*args, method="grad", columns=[], **kwargs): return { k: v for k, v in sorted( - dict(zip(columns, importance.astype(float))).items(), key=lambda item: item[1] + dict(zip(columns, importance.astype(float))).items(), + key=lambda item: item[1], ) } -- GitLab