Commit b93c50ad authored by Dennis Noll

[keras] feature importance perm: now uses tf data dataset

parent f7da208a
@@ -1315,21 +1315,38 @@ class LBNLayer(tf.keras.layers.Layer):
 def feature_importance_grad(model, x=None, **kwargs):
-    inp = [tf.constant(v) for v in x]
-    with tf.GradientTape() as tape:
-        tape.watch(inp)
-        pred = model(inp, training=False)
-        ix = np.argsort(pred, axis=-1)[:, -1]
-        decision = tf.gather(pred, ix, batch_dims=1)
-    gradients = tape.gradient(decision, inp)  # gradients for decision nodes
-    gradients = [grad if grad is not None else 0 for grad in gradients]  # categorical tensors
-    gradients = [
-        np.array(_g) * _x.std(axis=0) for (_g, _x) in zip(gradients, x)
-    ]  # norm to value ranges
-    if "sample_weight" in kwargs:
-        gradients = [(_g.T * kwargs["sample_weight"]).T for _g in gradients]  # apply ev weight
-    mean_gradients = np.concatenate([np.abs(g).mean(axis=0).flatten() for g in gradients])
+    inp = {f"{i}": _x for (i, _x) in enumerate(x)}
+    if "sample_weight" in kwargs:
+        inp["sample_weight"] = kwargs["sample_weight"]
+    ds = tf.data.Dataset.from_tensor_slices(inp)
+    ds = ds.batch(20000)
+    grad = 0
+    n_batches = tf.data.experimental.cardinality(ds).numpy()
+    with tqdm(total=n_batches) as pbar:
+        for _x in ds:
+            sw = _x.pop("sample_weight", None)  # per-event weight, if provided
+            inp = list(_x.values())
+            with tf.GradientTape() as tape:
+                tape.watch(inp)
+                pred = model(inp, training=False)
+                ix = tf.argsort(pred, axis=-1)[:, -1]
+                decision = tf.gather(pred, ix, batch_dims=1)
+            gradients = tape.gradient(decision, inp)  # gradients for decision nodes
+            gradients = [
+                grad if grad is not None else 0 for grad in gradients
+            ]  # categorical tensors
+            gradients = [
+                np.array(_g) * _x.std(axis=0) for (_g, _x) in zip(gradients, x)
+            ]  # norm to value ranges
+            if sw is not None:
+                gradients = [(_g.T * sw.numpy()).T for _g in gradients]  # apply ev weight
+            grad += np.concatenate([np.abs(g).mean(axis=0).flatten() for g in gradients])
+            pbar.update()
+    mean_gradients = grad / n_batches
     return mean_gradients / mean_gradients.max()
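
For context, here is a minimal usage sketch of the function above. It is not part of the commit: the two-branch Keras model, the feature shapes, and the random sample weights are invented purely for illustration, and it assumes feature_importance_grad is available together with its module's own imports (tf, np, tqdm).

import numpy as np
import tensorflow as tf

n_events = 1000
x = [
    np.random.normal(size=(n_events, 4)).astype(np.float32),  # first input branch
    np.random.normal(size=(n_events, 2)).astype(np.float32),  # second input branch
]

# Toy two-branch classifier standing in for the real model.
inputs = [tf.keras.Input(shape=(4,)), tf.keras.Input(shape=(2,))]
merged = tf.keras.layers.Concatenate()(inputs)
hidden = tf.keras.layers.Dense(16, activation="relu")(merged)
output = tf.keras.layers.Dense(3, activation="softmax")(hidden)
model = tf.keras.Model(inputs=inputs, outputs=output)

# Per-event weights, consumed by the "apply ev weight" step above.
sample_weight = np.random.uniform(0.5, 1.5, size=n_events).astype(np.float32)

scores = feature_importance_grad(model, x=x, sample_weight=sample_weight)
print(scores.shape)  # (6,): one score per flattened input feature
print(scores.max())  # 1.0, since scores are normalized to their maximum

Iterating a batched tf.data.Dataset instead of pushing all events through the tape at once, as this commit does, presumably bounds peak memory: each batch of 20000 events contributes its mean absolute gradient, and the per-batch results are averaged at the end.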