Commit 419cedce authored by Benjamin Fischer's avatar Benjamin Fischer
Browse files

utils/plot: add plot_cov

parent 715ad7b5
......@@ -272,6 +272,62 @@ def plotratio(
return ax
def plot_cov(data, labels, size=5, step=0.5, text=None):
assert data.shape == data.shape[2::-1]
n = size + step * data.shape[0]
fig, ax = plt.subplots(figsize=(n, n))
im = ax.imshow(data, cmap="PuOr", vmin=-1, vmax=1)
# Create colorbar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("covariance", rotation=-90, va="bottom")
# We want to show all ticks...
ax.set_xticks(np.arange(data.shape[1]))
ax.set_yticks(np.arange(data.shape[0]))
# ... and label them with the respective list entries.
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
# Let the horizontal axes labeling appear on top.
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
# Rotate the tick labels and set their alignment.
# plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax.get_xticklabels(), rotation=-45, ha="right", rotation_mode="anchor")
# Turn spines off and create white grid.
ax.spines[:].set_visible(False)
# ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
# ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
# ax.tick_params(which="minor", bottom=False, left=False)
# ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
if text is None:
text = data.shape[0] <= 30
if text:
colors = ("black", "white")
for idx, val in np.ndenumerate(data):
txt = f"{val:.2f}"
if ".00" in txt:
continue
ax.text(
*idx,
txt.replace("0.", "."),
color=colors[abs(val) > 0.5],
horizontalalignment="center",
verticalalignment="center",
size=7,
)
fig.tight_layout()
return fig
def _label_text_props(halign, valign="baseline", **kwargs):
return dict(
verticalalignment=valign,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment