diff --git a/keras.py b/keras.py
index b9aede7225899ff09a4a177a801e99add570bf0b..6497baaef85f184e3a2bbdd825f15e7f303ca191 100644
--- a/keras.py
+++ b/keras.py
@@ -530,17 +530,38 @@ class PlotMulticlass(TFSummaryCallback):
             self.clear_figure(fig)
 
         if self.plot_importance:
-            importance = feature_importance(
-                self.model,
-                x=self.x,
-                y=self.truth,
-                sample_weight=self.sample_weight,
-                method="grad",
-                columns=[c for col in self.columns.values() for c in col],
-            )
-            fig = figure_dict(importance)
-            imgs["importance"] = figure_to_image(fig)
-            self.clear_figure(fig)
+            for method in ["grad", "perm"]:
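+                # importance is estimated on at most the first 10000 events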
+                importance = feature_importance(
+                    self.model,
+                    x=[a[:10000] for a in self.x],
+                    y=self.truth[:10000],
+                    sample_weight=self.sample_weight[:10000],
+                    method=method,
+                    columns=[c for col in self.columns.values() for c in col],
+                )
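+                # split importances into the low-level four-vector components
+                # ("_ll") and the remaining high-level features ("_hl")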
+                particles = ["px", "py", "pz", "energy"]
+                fig = figure_dict(
+                    {
+                        key: val
+                        for key, val in importance.items()
+                        if any(key.endswith(p) for p in particles)
+                    }
+                )
+                imgs[f"importance_{method}_ll"] = figure_to_image(fig)
+                self.clear_figure(fig)
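+                # remaining (high-level) features go into a separate figure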
+                fig = figure_dict(
+                    {
+                        key: val
+                        for key, val in importance.items()
+                        if not any(key.endswith(p) for p in particles)
+                    }
+                )
+                imgs[f"importance_{method}_hl"] = figure_to_image(fig)
+                self.clear_figure(fig)
 
         for name, img in imgs.items():
             with self.writer.as_default():
@@ -1358,12 +1375,12 @@ class LBNLayer(tf.keras.layers.Layer):
         return feats
 
 
-def feature_importance_grad(model, x=None, **kwargs):
+def feature_importance_grad(model, x=None, filter_types=("int32", "int64"), **kwargs):
     inp = {f"{i}": _x for (i, _x) in enumerate(x)}
     inp["sample_weight"] = kwargs["sample_weight"]
-
     ds = tf.data.Dataset.from_tensor_slices((inp))
-    ds = ds.batch(10000)
+    ds = ds.batch(256)
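+    # accumulate per-batch gradients; averaged over n_batches after the loop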
     grad = 0
     n_batches = tf.data.experimental.cardinality(ds).numpy()
 
@@ -1371,26 +1387,41 @@ def feature_importance_grad(model, x=None, **kwargs):
         pbar.set_description("Importance grad")
         for _x in ds:
             if "sample_weight" in _x:
-                sw = tf.cast(_x["sample_weight"], tf.float32)
-                _x.pop("sample_weight")
+                sw = tf.cast(_x.pop("sample_weight"), tf.float32)
             inp = list(_x.values())
-            with tf.GradientTape() as tape:
-                tape.watch(inp)
+            with tf.GradientTape(watch_accessed_variables=False) as tape:
+                # watch only the float inputs; integer (categorical) ones are skipped
+                for i in inp:
+                    if i.dtype not in filter_types:
+                        tape.watch(i)
                 pred = model(inp, training=False)
-                ix = tf.argsort(pred, axis=-1)[:, -1]
+                ix = tf.argmax(pred, axis=-1)  # index of the top-scoring class
                 decision = tf.gather(pred, ix, batch_dims=1)
-            gradients = tape.gradient(decision, inp)  # gradients for decision nodes
+            # take gradients w.r.t. every input so the list stays aligned with
+            # inp below; unwatched integer inputs simply come back as None
+            gradients = tape.gradient(decision, inp)  # gradients for decision nodes
             gradients = [
                 grad if grad is not None else tf.constant(0.0, dtype=tf.float32)
                 for grad in gradients
             ]  # categorical tensors
-            gradients = [tf.transpose(tf.transpose(_g) * sw) for _g in gradients]
-            gradients = [tf.math.reduce_mean(tf.math.abs(_g), axis=0) for _g in gradients]
-            gradients = [
-                _g * tf.math.reduce_std(i, axis=0) for (_g, i) in zip(gradients, inp)
-            ]  # norm to value ranges
-            grad += tf.concat([tf.keras.backend.flatten(_g) for _g in gradients], axis=-1)
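+            # per-event weighting: the double transpose broadcasts sw over features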
+            gradients = [tf.transpose(tf.transpose(g) * sw) for g in gradients]
+            gradients = [tf.math.reduce_mean(tf.math.abs(g), axis=0) for g in gradients]
+
+            # normalize gradients by each feature's standard deviation so inputs
+            # on different scales stay comparable; integer features become NaN
+            normed_gradients = []
+            for g, i in zip(gradients, inp):
+                if i.dtype in filter_types:
+                    # reduce_std needs a float tensor; the result only shapes the NaNs
+                    val = np.nan * tf.math.reduce_std(tf.cast(i, tf.float32), axis=0)
+                else:
+                    val = g * tf.math.reduce_std(i, axis=0)
+                normed_gradients.append(val)
 
+            grad += tf.concat(
+                [tf.keras.backend.flatten(g) for g in normed_gradients], axis=-1
+            )
             pbar.update()
 
     mean_gradients = grad.numpy() / n_batches
@@ -1398,16 +1421,21 @@ def feature_importance_grad(model, x=None, **kwargs):
-    return mean_gradients / mean_gradients.max()
+    # nanmax: the NaN entries of integer features would otherwise poison .max()
+    return mean_gradients / np.nanmax(mean_gradients)
 
 
-def feature_importance_perm(model, x=None, **kwargs):
+def feature_importance_perm(model, x=None, filter_types=("int32", "int64"), **kwargs):
     inp_list = list(x)
     feat = 0  # loss
     ref = model.evaluate(x=x, **kwargs)[feat]
     vals = []
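+    # shuffle one flattened feature at a time and record the resulting loss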
     for index, tensor in enumerate(inp_list):
+        if tensor.dtype in filter_types:
+            # integer (categorical) inputs: one NaN placeholder per feature
+            vals.extend([np.nan] * int(np.prod(tensor.shape[1:])))
+            continue
         s = tensor.shape
         for i in range(np.prod(s[1:])):
             arr = tensor.reshape((-1, np.prod(s[1:])))
-
             slice_before = arr[:, :i]
             slice_shuffled = np.random.permutation(arr[:, i : i + 1])
             slice_after = arr[:, i + 1 :]
@@ -1418,7 +1442,9 @@ def feature_importance_perm(model, x=None, **kwargs):
             valid = inp_list.copy()
             valid[index] = arr_shuffled_reshaped
             vals.append(model.evaluate(x=valid, **kwargs)[feat])
-    return np.array(vals) / ref
+    vals = np.array(vals)
+    # shift by the NaN-ignoring minimum so the least important feature maps to 0
+    return (vals - np.nanmin(vals)) / ref
 
 
 def feature_importance(*args, method="grad", columns=[], **kwargs):
@@ -1434,6 +1458,7 @@ def feature_importance(*args, method="grad", columns=[], **kwargs):
             dict(zip(columns, importance.astype(float))).items(),
             key=lambda item: item[1],
         )
+        if not np.isnan(v)
     }