Skip to content
Snippets Groups Projects
Commit ecedc074 authored by Dennis Noll's avatar Dennis Noll
Browse files

[plotting] black

parent ccaa2da8
No related branches found
No related tags found
No related merge requests found
......@@ -73,7 +73,13 @@ def figure_confusion_matrix(
plt.yticks(tick_marks, class_names)
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
plt.text(
j,
i,
np.around(cm[i, j], decimals=2),
horizontalalignment="center",
size=7,
)
plt.ylabel("True label" + " (normed)" * (normalize == "true"))
plt.xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
......@@ -109,18 +115,18 @@ def figure_history(history_csv_path):
def plot_histories(history_csv_path, path, cut=None, roll=1):
pdf = pd.read_csv(history_csv_path)
pdf = pdf.set_index("epoch")
pdf = pdf.truncate(after=cut)
for col in pdf.columns:
df = pd.read_csv(history_csv_path)
df = df.set_index("epoch")
df = df.truncate(after=cut)
for col in df.columns:
if col.startswith("val_"):
continue
fig = plt.figure()
if "val_" + col in pdf.columns:
if "val_" + col in df.columns:
ind = [col, "val_" + col]
else:
ind = col
value = pdf[ind]
value = df[ind]
ax = value.rolling(roll, min_periods=1).mean().plot()
ax.set_xlabel("Epoch")
ax.set_ylabel(col.capitalize())
......@@ -228,7 +234,13 @@ def figure_node_activations(
def figure_roc_curve(
truth, prediction, indices=[0], class_names=None, sample_weight=None, lw=2, scale="linear"
truth,
prediction,
indices=[0],
class_names=None,
sample_weight=None,
lw=2,
scale="linear",
):
fig = plt.figure()
for index in indices:
......@@ -310,7 +322,13 @@ def figure_inputs(
def figure_weight_study(
class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs
class_inps,
sample_weights=None,
columns=None,
label=None,
log=False,
mode="plain",
**kwargs,
):
multiplot = Multiplot(class_inps[0].shape[1:][::-1])
rows, cols = multiplot.lenghts()
......@@ -352,7 +370,10 @@ def figure_weight_study(
if mode == "weight":
mask = sample_weights[0] > 0
pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask]
neg_feat, neg_weight = class_inps[0][:, feat][~mask], sample_weights[0][~mask]
neg_feat, neg_weight = (
class_inps[0][:, feat][~mask],
sample_weights[0][~mask],
)
val_pos, bins = np.histogram(pos_feat, bins=bins, weights=pos_weight)
val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment