Commit 55bed0c3 authored by Julian Treiber's avatar Julian Treiber

adjusted target_code for semantic segmentation task

parent 9fbfa58f
......@@ -512,11 +512,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment