Skip to content
Snippets Groups Projects
Commit c0eb422d authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] feature_importance: now uses more events + loss

parent 41bdb6d1
No related branches found
No related tags found
No related merge requests found
......@@ -531,9 +531,9 @@ class PlotMulticlass(TFSummaryCallback):
if self.plot_importance:
importance = feature_importance(
self.model,
x=[feat[:5000] for feat in self.x],
y=self.truth[:5000],
sample_weight=self.sample_weight[:5000],
x=[feat[:20000] for feat in self.x],
y=self.truth[:20000],
sample_weight=self.sample_weight[:20000],
method="grad",
columns=[c for col in self.columns.values() for c in col],
)
......@@ -1332,9 +1332,9 @@ def feature_importance_grad(model, x=None, **kwargs):
def feature_importance_perm(model, x=None, **kwargs):
inp_list = list(x)
feat = 1 # acc
feat = 0 # loss
ref = model.evaluate(x=x, **kwargs)[feat]
accs = []
vals = []
for index, tensor in enumerate(inp_list):
s = tensor.shape
for i in range(np.prod(s[1:])):
......@@ -1349,8 +1349,8 @@ def feature_importance_perm(model, x=None, **kwargs):
valid = inp_list.copy()
valid[index] = arr_shuffled_reshaped
accs.append(model.evaluate(x=valid, **kwargs)[feat])
return ref / np.array(accs)
vals.append(model.evaluate(x=valid, **kwargs)[feat])
return np.array(vals) / ref
def feature_importance(*args, method="grad", columns=[], **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment