Skip to content
Snippets Groups Projects
Commit 4a51f0e1 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] PlotMulticlass: adds plotting

parent c6454af7
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,9 @@ from .plotting import (
figure_node_activations,
figure_roc_curve,
figure_multihist,
figure_y,
figure_weights,
figure_inputs,
)
# various helper functions
......@@ -431,14 +434,30 @@ class PlotMulticlass(TFSummaryCallback):
figure_multihist(inp, columns=self.columns[part])
)
if self.sample_weight is not None:
for part, inp in zip(self.columns.keys(), inps):
imgs[f"inputs_{part}"] = figure_to_image(
figure_inputs(
inp,
self.truth,
sample_weight=self.sample_weight,
columns=self.columns[part],
class_names=self.class_names,
)
)
imgs["weights"] = figure_to_image(
figure_weights(self.sample_weight, self.truth, class_names=self.class_names)
)
imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
imgs["processes_relative"] = figure_to_image(
figure_y(self.truth, class_names=self.class_names, relative=True)
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
imgs["processes_relative"] = figure_to_image(
figure_y(self.truth, class_names=self.class_names, relative=True)
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
def on_test_end(self, logs=None):
self.on_epoch_end(epoch=0, logs=logs)
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment