diff --git a/keras.py b/keras.py
index 1911c2c92559a7506d4c455c6bc3ad8afc9b83cd..5248cbe85cc105bbea6b979f402e518696a564e6 100644
--- a/keras.py
+++ b/keras.py
@@ -473,7 +473,10 @@ class PlotMulticlass(TFSummaryCallback):
 
         imgs["roc_curve"] = figure_to_image(
             figure_roc_curve(
-                truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
             )
         )
         imgs["roc_curve_log"] = figure_to_image(
@@ -511,7 +514,10 @@ class PlotMulticlass(TFSummaryCallback):
         )
         imgs["node_activation"] = figure_to_image(
             figure_node_activations(
-                prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight
+                prediction,
+                truth,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
             )
         )
         imgs["node_activation_disjoint_unweighted"] = figure_to_image(
@@ -601,7 +607,12 @@ class CheckpointModel(tf.keras.callbacks.Callback):
 
 class BestTracker(tf.keras.callbacks.Callback):
     def __init__(
-        self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None
+        self,
+        monitor="val_loss",
+        mode="auto",
+        min_delta=0,
+        min_delta_rel=0,
+        baseline=None,
     ):
         pin(locals())
         self.reset()
@@ -684,7 +695,15 @@ class PatientTracker(BestTracker):
 
 class ScaleOnPlateau(PatientTracker):
     def __init__(
-        self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs
+        self,
+        target,
+        factor,
+        min=None,
+        max=None,
+        verbose=0,
+        log_key=None,
+        linearly=False,
+        **kwargs,
     ):
         pin(locals(), kwargs)
         super(ScaleOnPlateau, self).__init__(**kwargs)
@@ -722,7 +741,12 @@ class ScaleOnPlateau(PatientTracker):
                 if self.verbose > 0:
                     print(
                         "\nEpoch %05d: %s scaling %s to %s."
-                        % (epoch + 1, self.__class__.__name__, self.log_key or self.target, new)
+                        % (
+                            epoch + 1,
+                            self.__class__.__name__,
+                            self.log_key or self.target,
+                            new,
+                        )
                     )
 
 
@@ -1186,6 +1210,7 @@ def feature_importance(*args, method="grad", columns=[], **kwargs):
     return {
         k: v
         for k, v in sorted(
-            dict(zip(columns, importance.astype(float))).items(), key=lambda item: item[1]
+            dict(zip(columns, importance.astype(float))).items(),
+            key=lambda item: item[1],
         )
     }