Commit 6e6da4e2 authored by Julian Treiber's avatar Julian Treiber

adjusted accumulation of predictions for eval metric to work with semantic segmentation task

parent d6eb626f
...@@ -393,11 +393,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -393,11 +393,7 @@ class ${tc.fileNameWithoutEnding}:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_Alexnet:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -497,11 +497,7 @@ class CNNSupervisedTrainer_Invariant: ...@@ -497,11 +497,7 @@ class CNNSupervisedTrainer_Invariant:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -485,11 +485,7 @@ class CNNSupervisedTrainer_MultipleStreams: ...@@ -485,11 +485,7 @@ class CNNSupervisedTrainer_MultipleStreams:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -599,11 +599,7 @@ class CNNSupervisedTrainer_RNNencdec: ...@@ -599,11 +599,7 @@ class CNNSupervisedTrainer_RNNencdec:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -596,11 +596,7 @@ class CNNSupervisedTrainer_RNNsearch: ...@@ -596,11 +596,7 @@ class CNNSupervisedTrainer_RNNsearch:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -565,11 +565,7 @@ class CNNSupervisedTrainer_RNNtest: ...@@ -565,11 +565,7 @@ class CNNSupervisedTrainer_RNNtest:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -470,11 +470,7 @@ class CNNSupervisedTrainer_ResNeXt50: ...@@ -470,11 +470,7 @@ class CNNSupervisedTrainer_ResNeXt50:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -587,11 +587,7 @@ class CNNSupervisedTrainer_Show_attend_tell: ...@@ -587,11 +587,7 @@ class CNNSupervisedTrainer_Show_attend_tell:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -476,11 +476,7 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14: ...@@ -476,11 +476,7 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_VGG16: ...@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_VGG16:
predictions = [] predictions = []
for output_name in outputs: for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1: predictions.append(output_name)
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=labels) metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1] test_metric_score = metric.get()[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment