diff --git a/keras.py b/keras.py index 1911c2c92559a7506d4c455c6bc3ad8afc9b83cd..5248cbe85cc105bbea6b979f402e518696a564e6 100644 --- a/keras.py +++ b/keras.py @@ -473,7 +473,10 @@ class PlotMulticlass(TFSummaryCallback): imgs["roc_curve"] = figure_to_image( figure_roc_curve( - truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight + truth, + prediction, + class_names=self.class_names, + sample_weight=self.sample_weight, ) ) imgs["roc_curve_log"] = figure_to_image( @@ -511,7 +514,10 @@ class PlotMulticlass(TFSummaryCallback): ) imgs["node_activation"] = figure_to_image( figure_node_activations( - prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight + prediction, + truth, + class_names=self.class_names, + sample_weight=self.sample_weight, ) ) imgs["node_activation_disjoint_unweighted"] = figure_to_image( @@ -601,7 +607,12 @@ class CheckpointModel(tf.keras.callbacks.Callback): class BestTracker(tf.keras.callbacks.Callback): def __init__( - self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None + self, + monitor="val_loss", + mode="auto", + min_delta=0, + min_delta_rel=0, + baseline=None, ): pin(locals()) self.reset() @@ -684,7 +695,15 @@ class PatientTracker(BestTracker): class ScaleOnPlateau(PatientTracker): def __init__( - self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs + self, + target, + factor, + min=None, + max=None, + verbose=0, + log_key=None, + linearly=False, + **kwargs, ): pin(locals(), kwargs) super(ScaleOnPlateau, self).__init__(**kwargs) @@ -722,7 +741,12 @@ class ScaleOnPlateau(PatientTracker): if self.verbose > 0: print( "\nEpoch %05d: %s scaling %s to %s." - % (epoch + 1, self.__class__.__name__, self.log_key or self.target, new) + % ( + epoch + 1, + self.__class__.__name__, + self.log_key or self.target, + new, + ) ) @@ -1186,6 +1210,7 @@ def feature_importance(*args, method="grad", columns=[], **kwargs): return { k: v for k, v in sorted( - dict(zip(columns, importance.astype(float))).items(), key=lambda item: item[1] + dict(zip(columns, importance.astype(float))).items(), + key=lambda item: item[1], ) }