Skip to content
Snippets Groups Projects
Commit a8175b77 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] callbacks: adds callback to plot lbn weights

parent c5b919aa
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,7 @@ from .plotting import (
figure_weights,
figure_inputs,
figure_dict,
figure_lbn_weights,
)
# various helper functions
......@@ -553,6 +554,49 @@ class PlotMulticlassEval(PlotMulticlass):
self.make_eval_plots(0)
class PlotLBN(TFSummaryCallback):
def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs):
self.lbn = lbn_layer
self.inp_particle_names = inp_particle_names
super().__init__(*args, **kwargs)
def on_train_end(self, logs=None):
self.make_plots()
def make_plots(self, epoch=0):
imgs = {}
pkwargs = {"inp_particle_names": self.inp_particle_names}
imgs["lbn_particles"] = figure_to_image(
figure_lbn_weights(
self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs
)
)
imgs["lbn_restframes"] = figure_to_image(
figure_lbn_weights(
self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs
)
)
pkwargs["norm"] = True
imgs["lbn_particles_normed"] = figure_to_image(
figure_lbn_weights(
self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs
)
)
imgs["lbn_restframes_normed"] = figure_to_image(
figure_lbn_weights(
self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs
)
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=epoch)
class PlotLBNEval(PlotLBN):
def on_test_end(self, logs=None):
self.make_plots()
class ModelLH(tf.keras.Model):
def __init__(self, *args, **kwargs):
self.loss_hook = kwargs.pop("loss_hook", None)
......
......@@ -419,3 +419,40 @@ def figure_to_image(figure):
# Add the batch dimension
image = tf.expand_dims(image, 0)
return image
def figure_lbn_weights(
weights,
name,
norm=False,
path="",
cmap="OrRd",
inp_particle_names=None,
**fig_kwargs,
):
# normalize weight tensor to a sum of 100 per row
if norm:
weights = np.abs(weights)
weights = weights / np.sum(weights, axis=0).reshape((1, weights.shape[1])) * 100
# create the figure
ph, pv = weights.shape
fig, ax = plt.subplots(1, 1, figsize=((pv + 1) // 2, (ph + 1) // 2))
# create and style the plot
ax.imshow(weights, cmap=cmap, vmin=0, vmax=100, origin="lower")
ax.set_title("{} weights".format(name), fontdict={"fontsize": 12})
ax.set_xlabel("LBN particle number")
ax.set_xticks(list(range(weights.shape[1])))
ax.set_ylabel("Input particle")
ax.set_yticks(list(range(weights.shape[0])))
if inp_particle_names:
ax.set_yticklabels(inp_particle_names)
# write weights into each bin
for (i, j), val in np.ndenumerate(weights):
ax.text(j, i, int(round(weights[i, j])), fontsize=8, ha="center", va="center", color="k")
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment