Skip to content
Snippets Groups Projects
Commit b26baeec authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] PlotCallback: added feature importance

parent 871e399f
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ from .plotting import (
figure_y,
figure_weights,
figure_inputs,
figure_dict,
)
# various helper functions
......@@ -403,6 +404,7 @@ class PlotMulticlass(TFSummaryCallback):
columns=None,
plot_inputs=False,
signalvsbkg=False,
plot_importance=False,
tag="",
**kwargs,
):
......@@ -416,6 +418,7 @@ class PlotMulticlass(TFSummaryCallback):
self.to_file = to_file
self.signalvsbkg = signalvsbkg
self.tag = tag
self.plot_importance = plot_importance
def on_test_begin(self, logs=None):
self.on_train_begin(logs=logs)
......@@ -467,6 +470,7 @@ class PlotMulticlass(TFSummaryCallback):
prediction = self.model.predict(self.x)
truth = self.truth
imgs = {}
imgs["roc_curve"] = figure_to_image(
figure_roc_curve(
truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight
......@@ -522,6 +526,16 @@ class PlotMulticlass(TFSummaryCallback):
sample_weight=self.sample_weight,
)
)
if self.plot_importance:
importance = feature_importance(
self.model,
x=[feat[:5000] for feat in self.x],
y=self.truth[:5000],
sample_weight=self.sample_weight[:5000],
method="grad",
columns=[c for col in self.columns.values() for c in col],
)
imgs["importance"] = figure_to_image(figure_dict(importance))
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(f"{name}{self.tag}", img, step=epoch)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment