From 942f9441e80886bcd8ece51f23e2bc1e76e4f9bd Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Thu, 14 Jan 2021 10:03:29 +0100
Subject: [PATCH] [keras] Importance: Cleaned up importance functions

This commit may be fairly resource intensive; TODO: look into it again at a later time.
---
 keras.py | 83 +++++++++++++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 54 insertions(+), 29 deletions(-)

diff --git a/keras.py b/keras.py
index b9aede7..6497baa 100644
--- a/keras.py
+++ b/keras.py
@@ -530,17 +530,34 @@ class PlotMulticlass(TFSummaryCallback):
             self.clear_figure(fig)
 
         if self.plot_importance:
-            importance = feature_importance(
-                self.model,
-                x=self.x,
-                y=self.truth,
-                sample_weight=self.sample_weight,
-                method="grad",
-                columns=[c for col in self.columns.values() for c in col],
-            )
-            fig = figure_dict(importance)
-            imgs["importance"] = figure_to_image(fig)
-            self.clear_figure(fig)
+            for method in ["grad", "perm"]:
+                importance = feature_importance(
+                    self.model,
+                    x=[a[:10000] for a in self.x],
+                    y=self.truth[:10000],
+                    sample_weight=self.sample_weight[:10000],
+                    method=method,
+                    columns=[c for col in self.columns.values() for c in col],
+                )
+                particles = ["px", "py", "pz", "energy"]
+                fig = figure_dict(
+                    {
+                        key: importance[key]
+                        for key in importance.keys()
+                        if any(key.endswith(p) for p in particles)
+                    }
+                )
+                imgs[f"importance_{method}_ll"] = figure_to_image(fig)
+                self.clear_figure(fig)
+                fig = figure_dict(
+                    {
+                        key: importance[key]
+                        for key in importance.keys()
+                        if not any(key.endswith(p) for p in particles)
+                    }
+                )
+                imgs[f"importance_{method}_hl"] = figure_to_image(fig)
+                self.clear_figure(fig)
 
         for name, img in imgs.items():
             with self.writer.as_default():
@@ -1358,12 +1375,11 @@ class LBNLayer(tf.keras.layers.Layer):
         return feats
 
 
-def feature_importance_grad(model, x=None, **kwargs):
+def feature_importance_grad(model, x=None, filter_types=["int32", "int64"], **kwargs):
     inp = {f"{i}": _x for (i, _x) in enumerate(x)}
     inp["sample_weight"] = kwargs["sample_weight"]
-
     ds = tf.data.Dataset.from_tensor_slices((inp))
-    ds = ds.batch(10000)
+    ds = ds.batch(256)
 
     grad = 0
     n_batches = tf.data.experimental.cardinality(ds).numpy()
@@ -1371,26 +1387,33 @@
         pbar.set_description("Importance grad")
         for _x in ds:
             if "sample_weight" in _x:
-                sw = tf.cast(_x["sample_weight"], tf.float32)
-                _x.pop("sample_weight")
+                sw = tf.cast(_x.pop("sample_weight"), tf.float32)
             inp = list(_x.values())
-            with tf.GradientTape() as tape:
-                tape.watch(inp)
+            with tf.GradientTape(watch_accessed_variables=False) as tape:
+                [tape.watch(i) for i in inp if i.dtype not in filter_types]
                 pred = model(inp, training=False)
-            ix = tf.argsort(pred, axis=-1)[:, -1]
+            ix = tf.keras.layers.Lambda(lambda x: x[:, -1])(tf.argsort(pred, axis=-1))
             decision = tf.gather(pred, ix, batch_dims=1)
-            gradients = tape.gradient(decision, inp)  # gradients for decision nodes
+            gradients = tape.gradient(
+                decision, [i for i in inp if i.dtype not in filter_types]
+            )  # gradients for decision nodes
             gradients = [
                 grad if grad is not None else tf.constant(0.0, dtype=tf.float32)
                 for grad in gradients
             ]  # categorical tensors
 
-            gradients = [tf.transpose(tf.transpose(_g) * sw) for _g in gradients]
-            gradients = [tf.math.reduce_mean(tf.math.abs(_g), axis=0) for _g in gradients]
-            gradients = [
-                _g * tf.math.reduce_std(i, axis=0) for (_g, i) in zip(gradients, inp)
-            ]  # norm to value ranges
-            grad += tf.concat([tf.keras.backend.flatten(_g) for _g in gradients], axis=-1)
+            gradients = [tf.transpose(tf.transpose(g) * sw) for g in gradients]
+            gradients = [tf.math.reduce_mean(tf.math.abs(g), axis=0) for g in gradients]
+
+            # norm gradients to the input std; NaN placeholders for filtered dtypes
+            grad_iter = iter(gradients)
+            normed_gradients = []
+            for i in inp:
+                if i.dtype in filter_types:
+                    val = np.nan * tf.math.reduce_std(tf.cast(i, tf.float32), axis=0)
+                else:
+                    val = next(grad_iter) * tf.math.reduce_std(i, axis=0)
+                normed_gradients.append(val)
+            grad += tf.concat([tf.keras.backend.flatten(g) for g in normed_gradients], axis=-1)
             pbar.update()
 
     mean_gradients = grad.numpy() / n_batches
@@ -1398,16 +1421,17 @@
-    return mean_gradients / mean_gradients.max()
+    return mean_gradients / np.nanmax(mean_gradients)
 
 
-def feature_importance_perm(model, x=None, **kwargs):
+def feature_importance_perm(model, x=None, filter_types=["int32", "int64"], **kwargs):
     inp_list = list(x)
     feat = 0  # loss
     ref = model.evaluate(x=x, **kwargs)[feat]
     vals = []
     for index, tensor in enumerate(inp_list):
+        if tensor.dtype in filter_types:
+            vals.append(np.nan)
+            continue
         s = tensor.shape
         for i in range(np.prod(s[1:])):
             arr = tensor.reshape((-1, np.prod(s[1:])))
-
             slice_before = arr[:, :i]
             slice_shuffled = np.random.permutation(arr[:, i : i + 1])
             slice_after = arr[:, i + 1 :]
@@ -1418,7 +1442,7 @@
             valid = inp_list.copy()
             valid[index] = arr_shuffled_reshaped
             vals.append(model.evaluate(x=valid, **kwargs)[feat])
-    return np.array(vals) / ref
+    return (np.array(vals) - np.nanmin(vals)) / ref
 
 
 def feature_importance(*args, method="grad", columns=[], **kwargs):
@@ -1434,6 +1458,7 @@
             dict(zip(columns, importance.astype(float))).items(),
             key=lambda item: item[1],
         )
+        if not np.isnan(v)
     }
-- 
GitLab
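
P.S. For context, a minimal usage sketch of the reworked feature_importance helper, assuming it is imported from the patched keras.py. The toy model, random arrays, and column names below are illustrative placeholders, not part of the patch:

    import numpy as np
    import tensorflow as tf

    # hypothetical two-input toy model: one float block, one int (categorical) input
    inp_f = tf.keras.Input(shape=(4,), dtype="float32")
    inp_i = tf.keras.Input(shape=(1,), dtype="int32")
    cast_i = tf.keras.layers.Lambda(lambda t: tf.cast(t, tf.float32))(inp_i)
    h = tf.keras.layers.Concatenate()([inp_f, cast_i])
    h = tf.keras.layers.Dense(16, activation="relu")(h)
    out = tf.keras.layers.Dense(2, activation="softmax")(h)
    model = tf.keras.Model([inp_f, inp_i], out)
    # metrics are needed so that model.evaluate returns a list (loss is index 0)
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

    x_float = np.random.normal(size=(1000, 4)).astype("float32")
    x_int = np.random.randint(0, 3, size=(1000, 1)).astype("int32")
    y = tf.keras.utils.to_categorical(np.random.randint(0, 2, size=1000), 2)
    sw = np.ones(1000, dtype="float32")
    columns = ["jet_px", "jet_py", "jet_pz", "jet_energy", "category"]

    for method in ["grad", "perm"]:
        importance = feature_importance(
            model,
            x=[x_float, x_int],
            y=y,
            sample_weight=sw,
            method=method,
            columns=columns,
        )
        # the integer input is skipped via filter_types and dropped from the sorted dict
        print(method, importance)

With method="grad" the score is the weighted mean absolute gradient of the winning output node, scaled by the input spread and normalised to the maximum; with method="perm" it is the change of the loss when one input column is shuffled, relative to the unshuffled reference.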