diff --git a/keras.py b/keras.py
index 5e6a4f95e081978f1c0e3f52ad334ae0da0143ea..385a48265946f0024a443b6797771c1388144eba 100644
--- a/keras.py
+++ b/keras.py
@@ -133,7 +133,7 @@ class KFeed(object):
                 )
             )
             if len(keys) > 1
-            else src[keys[0]].batch_generator(batch_size, rng=rng)
+            else src[list(keys)[0]].batch_generator(batch_size, rng=rng)
         )
         gs = float(batch_size // len(keys))
         steps = int(auto_steps([src[k].blen / gs for k in keys])) or None
@@ -1123,3 +1123,58 @@ class LBNLayer(tf.keras.layers.Layer):
         feats = self.lbn_layer(ll)
         feats = self.batch_norm(feats)
         return feats
+
+
+def feature_importance_grad(model, x=None, **kwargs):
+    with tf.GradientTape() as tape:
+        inp = [tf.Variable(v) for v in x]
+        pred = model(inp)
+        ix = np.argsort(-pred, axis=-1)[:, 0]
+        loss = tf.gather(pred, ix, batch_dims=1)
+
+    gradients = tape.gradient(loss, inp)  # gradients for decision nodes
+    normed_gradients = [_g * _x for (_g, _x) in zip(gradients, x)]  # normed to input values
+
+    mean_gradients = np.concatenate(
+        [np.abs(g.numpy()).mean(axis=0).flatten() for g in normed_gradients]
+    )
+    return mean_gradients / mean_gradients.max()
+
+
+def feature_importance_perm(model, x=None, **kwargs):
+    inp_list = list(x)
+    feat = 1  # acc
+    ref = model.evaluate(x=x, **kwargs)[feat]
+    accs = []
+    for index, tensor in enumerate(inp_list):
+        s = tensor.shape
+        for i in range(np.prod(s[1:])):
+            arr = tensor.reshape((-1, np.prod(s[1:])))
+
+            slice_before = arr[:, :i]
+            slice_shuffled = np.random.permutation(arr[:, i : i + 1])
+            slice_after = arr[:, i + 1 :]
+            arr_shuffled = np.concatenate([slice_before, slice_shuffled, slice_after], axis=-1)
+
+            arr_shuffled_reshaped = arr_shuffled.reshape(s)
+
+            valid = inp_list.copy()
+            valid[index] = arr_shuffled_reshaped
+            accs.append(model.evaluate(x=valid, **kwargs)[feat])
+
+    return ref / np.array(accs)
+
+
+def feature_importance(*args, method="grad", columns=[], **kwargs):
+    if method == "grad":
+        importance = feature_importance_grad(*args, **kwargs)
+    elif method == "perm":
+        importance = feature_importance_perm(*args, **kwargs)
+    else:
+        raise NotImplementedError("Feature importance method not implemented")
+    return {
+        k: v
+        for k, v in sorted(
+            dict(zip(columns, importance.astype(float))).items(), key=lambda item: item[1]
+        )
+    }