Commit 6e6da4e2 authored by Julian Treiber's avatar Julian Treiber

adjusted accumulation of predictions for eval metric to work with semantic segmentation task

parent d6eb626f
......@@ -393,10 +393,6 @@ class ${tc.fileNameWithoutEnding}:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -512,10 +512,6 @@ class CNNSupervisedTrainer_Alexnet:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -512,10 +512,6 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -497,10 +497,6 @@ class CNNSupervisedTrainer_Invariant:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -485,10 +485,6 @@ class CNNSupervisedTrainer_MultipleStreams:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -599,10 +599,6 @@ class CNNSupervisedTrainer_RNNencdec:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -596,10 +596,6 @@ class CNNSupervisedTrainer_RNNsearch:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -565,10 +565,6 @@ class CNNSupervisedTrainer_RNNtest:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -470,10 +470,6 @@ class CNNSupervisedTrainer_ResNeXt50:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -587,10 +587,6 @@ class CNNSupervisedTrainer_Show_attend_tell:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -476,10 +476,6 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
......@@ -512,10 +512,6 @@ class CNNSupervisedTrainer_VGG16:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment