Skip to content
Snippets Groups Projects
Commit 579382cc authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] GroupedXEnt: now supports one-element groups

parent 90d8f7a9
No related branches found
No related tags found
No related merge requests found
......@@ -1265,20 +1265,25 @@ def grouped_cross_entropy_t(
if sample_weight is not None:
losses *= sample_weight[:, tf.newaxis]
# create grouped labels and predictions
labels_grouped = tf.concat(
[
labels_grouped = []
for _, ids in group_ids:
labels_grouped.append(
tf.reduce_sum(tf.gather(labels, ids, axis=-1), axis=-1, keepdims=True)
for _, ids in group_ids
],
axis=-1,
)
labels_grouped = (
tf.concat(labels_grouped, axis=-1) if len(labels_grouped) > 1 else labels_grouped[0]
)
predictions_grouped = tf.concat(
[
predictions_grouped = []
for _, ids in group_ids:
predictions_grouped.append(
tf.reduce_sum(tf.gather(predictions, ids, axis=-1), axis=-1, keepdims=True)
for _, ids in group_ids
],
axis=-1,
)
predictions_grouped = (
tf.concat(predictions_grouped, axis=-1)
if len(predictions_grouped) > 1
else predictions_grouped[0]
)
predictions_grouped = tf.clip_by_value(predictions_grouped, epsilon, 1 - epsilon)
# grouped true-negative component
tn_grouped = labels_grouped * tf.math.log(predictions_grouped)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment