Skip to content
Snippets Groups Projects
Commit 871e399f authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] feature_importance: fixes grad method

parent ef1f25eb
No related branches found
No related tags found
No related merge requests found
...@@ -1125,13 +1125,13 @@ class LBNLayer(tf.keras.layers.Layer): ...@@ -1125,13 +1125,13 @@ class LBNLayer(tf.keras.layers.Layer):
def feature_importance_grad(model, x=None, **kwargs): def feature_importance_grad(model, x=None, **kwargs):
inp = [tf.Variable(v) for v in x]
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
inp = [tf.Variable(v) for v in x] pred = model(inp, training=False)
pred = model(inp) ix = np.argsort(pred, axis=-1)[:, -1]
ix = np.argsort(-pred, axis=-1)[:, 0] decision = tf.gather(pred, ix, batch_dims=1)
loss = tf.gather(pred, ix, batch_dims=1)
gradients = tape.gradient(loss, inp) # gradients for decision nodes gradients = tape.gradient(decision, inp) # gradients for decision nodes
normed_gradients = [_g * _x for (_g, _x) in zip(gradients, x)] # normed to input values normed_gradients = [_g * _x for (_g, _x) in zip(gradients, x)] # normed to input values
mean_gradients = np.concatenate( mean_gradients = np.concatenate(
......
0% — Loading, or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment