diff --git a/keras.py b/keras.py
index 5e6a4f95e081978f1c0e3f52ad334ae0da0143ea..385a48265946f0024a443b6797771c1388144eba 100644
--- a/keras.py
+++ b/keras.py
@@ -133,7 +133,7 @@ class KFeed(object):
                 )
             )
             if len(keys) > 1
-            else src[keys[0]].batch_generator(batch_size, rng=rng)
+            else src[list(keys)[0]].batch_generator(batch_size, rng=rng)
         )
         gs = float(batch_size // len(keys))
         steps = int(auto_steps([src[k].blen / gs for k in keys])) or None
@@ -1123,3 +1123,58 @@ class LBNLayer(tf.keras.layers.Layer):
         feats = self.lbn_layer(ll)
         feats = self.batch_norm(feats)
         return feats
+
+
def feature_importance_grad(model, x=None, **kwargs):
    """Gradient-based feature importance for a Keras model.

    Wraps the inputs in ``tf.Variable``s so the tape can track them,
    takes the model's top-scoring output per sample as the "loss", and
    measures how strongly each input feature influences that score.

    Args:
        model: a callable Keras model.
        x: list of input arrays/tensors, one per model input.
        **kwargs: accepted for interface parity with the other
            ``feature_importance_*`` backends; unused here.

    Returns:
        1-D ``np.ndarray`` of per-feature importances, normalised so the
        largest value is 1.0.
    """
    with tf.GradientTape() as tape:
        tracked = [tf.Variable(t) for t in x]
        predictions = model(tracked)
        # Index of the highest-scoring class per sample (first column of
        # the descending argsort).
        top_class = np.argsort(-predictions, axis=-1)[:, 0]
        loss = tf.gather(predictions, top_class, batch_dims=1)

    # Gradients of the winning score w.r.t. each input tensor,
    # scaled by the input values themselves (gradient * input).
    grads = tape.gradient(loss, tracked)
    scaled = [g * v for g, v in zip(grads, x)]

    per_feature = [np.abs(s.numpy()).mean(axis=0).flatten() for s in scaled]
    merged = np.concatenate(per_feature)
    return merged / merged.max()
+
+
def feature_importance_perm(model, x=None, **kwargs):
    """Permutation-based feature importance.

    For every scalar feature of every input tensor, that feature's column
    is shuffled across the batch (everything else held fixed) and the
    model is re-evaluated.  Importance is the ratio of the reference
    metric to the permuted metric.

    Args:
        model: a compiled model exposing ``evaluate`` whose result is a
            sequence of metrics (index 1 is used — presumably accuracy;
            TODO confirm against the compile() metric order).
        x: list of numpy input arrays.  Must support ``.reshape`` —
            assumes numpy, not tf tensors; verify at the call sites.
        **kwargs: forwarded verbatim to ``model.evaluate``.

    Returns:
        ``np.ndarray`` of shape ``(total_features,)`` holding
        ``reference_metric / permuted_metric`` per feature.
    """
    inputs = list(x)
    metric_index = 1  # position of the tracked metric (acc) in evaluate() output
    reference = model.evaluate(x=x, **kwargs)[metric_index]

    permuted_scores = []
    for tensor_index, tensor in enumerate(inputs):
        shape = tensor.shape
        n_features = int(np.prod(shape[1:]))
        # Hoisted out of the inner loop: the flattened 2-D view is
        # invariant in i (the original recomputed it per feature).
        flat = tensor.reshape((-1, n_features))
        for i in range(n_features):
            # Shuffle only column i across the batch dimension; the
            # concatenation builds a fresh array, so ``tensor`` is untouched.
            shuffled_col = np.random.permutation(flat[:, i : i + 1])
            permuted = np.concatenate(
                [flat[:, :i], shuffled_col, flat[:, i + 1 :]], axis=-1
            )

            candidate = inputs.copy()
            candidate[tensor_index] = permuted.reshape(shape)
            permuted_scores.append(model.evaluate(x=candidate, **kwargs)[metric_index])

    return reference / np.array(permuted_scores)
+
+
def feature_importance(*args, method="grad", columns=(), **kwargs):
    """Compute per-feature importances and map them to column names.

    Args:
        *args: forwarded to the chosen backend (typically ``model, x``).
        method: ``"grad"`` for gradient-based or ``"perm"`` for
            permutation-based importance.
        columns: feature names, zipped positionally with the importance
            values.  Default is an immutable tuple rather than the
            original mutable ``[]`` (same behaviour, no shared-state
            pitfall).
        **kwargs: forwarded to the backend.

    Returns:
        dict mapping column name -> importance, sorted ascending by value.

    Raises:
        NotImplementedError: if ``method`` is not recognised.
    """
    if method == "grad":
        importance = feature_importance_grad(*args, **kwargs)
    elif method == "perm":
        importance = feature_importance_perm(*args, **kwargs)
    else:
        raise NotImplementedError("Feature importance method not implemented")
    # dict preserves insertion order (Python 3.7+), so sorting the items
    # once is enough — the original's outer dict comprehension was a
    # redundant identity pass.  Building the inner dict first keeps the
    # original "last duplicate column wins" semantics.
    return dict(
        sorted(
            dict(zip(columns, importance.astype(float))).items(),
            key=lambda item: item[1],
        )
    )