Commit f02920c0 authored by Sebastian Nickels

Implemented multiple outputs

parent 5196adbc
Pipeline #147998 failed in 4 minutes and 54 seconds
@@ -53,7 +53,7 @@ public class IntegrationGluonTest extends IntegrationTest {
         assertTrue(Log.getFindings().isEmpty());
     }
 
-    /*@Test
+    @Test
     public void testMultipleOutputs() {
         Log.getFindings().clear();
@@ -63,7 +63,7 @@ public class IntegrationGluonTest extends IntegrationTest {
         EMADLGeneratorCli.main(args);
         assertTrue(Log.getFindings().isEmpty());
-    }*/
+    }
 
     private void deleteHashFile(Path hashFile) {
         try {
@@ -71,6 +71,7 @@ class NoNormalization(gluon.HybridBlock):
 class Net(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
         super(Net, self).__init__(**kwargs)
+        self.last_layers = {}
         with self.name_scope():
             if data_mean:
                 assert(data_std)
@@ -109,8 +110,7 @@ class Net(gluon.HybridBlock):
             self.fc3_ = gluon.nn.Dense(units=10, use_bias=True)
             # fc3_, output shape: {[10,1,1]}
 
-        self.last_layer = 'softmax'
+        self.last_layers['predictions'] = 'softmax'
 
     def hybrid_forward(self, F, image):
@@ -124,3 +124,4 @@ class Net(gluon.HybridBlock):
         relu2_ = self.relu2_(fc2_)
         fc3_ = self.fc3_(relu2_)
         return fc3_
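
The diff above replaces the single last_layer attribute with a last_layers dict keyed by output name, which is what the trainer below uses to pick one loss function per output. A minimal sketch of the same pattern for a hypothetical net with two outputs (all names and sizes here are invented for illustration):

    import mxnet as mx
    from mxnet import gluon

    class MultiOutputNet(gluon.HybridBlock):
        # Hypothetical two-headed net: one classification, one regression output.
        def __init__(self, **kwargs):
            super(MultiOutputNet, self).__init__(**kwargs)
            self.last_layers = {}  # output name -> kind of last layer
            with self.name_scope():
                self.fc_shared = gluon.nn.Dense(units=128)
                self.fc_class = gluon.nn.Dense(units=10)
                self.fc_value = gluon.nn.Dense(units=1)
            self.last_layers['class'] = 'softmax'
            self.last_layers['value'] = 'linear'

        def hybrid_forward(self, F, x):
            h = F.relu(self.fc_shared(x))
            # One tensor per named output, in a fixed order.
            return self.fc_class(h), self.fc_value(h)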
@@ -69,15 +69,18 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         trainer = mx.gluon.Trainer(self._net.collect_params(), optimizer, optimizer_params)
 
-        if self._net.last_layer == 'softmax':
-            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
-        elif self._net.last_layer == 'sigmoid':
-            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
-        elif self._net.last_layer == 'linear':
-            loss_function = mx.gluon.loss.L2Loss()
-        else:
-            loss_function = mx.gluon.loss.L2Loss()
-            logging.warning("Invalid last_layer, defaulting to L2 loss")
+        loss_functions = {}
+
+        for output_name, last_layer in self._net.last_layers.items():
+            if last_layer == 'softmax':
+                loss_functions[output_name] = mx.gluon.loss.SoftmaxCrossEntropyLoss()
+            elif last_layer == 'sigmoid':
+                loss_functions[output_name] = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
+            elif last_layer == 'linear':
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+            else:
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+                logging.warning("Invalid last layer, defaulting to L2 loss")
 
         speed_period = 50
         tic = None
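
For this generated MNIST trainer the loop yields exactly one entry, since Net registers a single output above; a sketch of the resulting dict (assuming mxnet is imported as in the generated file):

    import mxnet as mx

    # self._net.last_layers == {'predictions': 'softmax'}, so the loop produces:
    loss_functions = {'predictions': mx.gluon.loss.SoftmaxCrossEntropyLoss()}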
@@ -85,12 +88,13 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
-                image = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                image_data = batch.data[0].as_in_context(mx_context)
+                predictions_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
-                    output = self._net(image)
-                    loss = loss_function(output, label)
+                    predictions_output = self._net(image_data)
+                    loss = loss_functions['predictions'](predictions_output, predictions_label)
 
                 loss.backward()
                 trainer.step(batch_size)
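
The training step still hard-codes the 'predictions' output. A plausible generalization to several outputs, assuming the net returns one tensor per entry of last_layers, would sum one loss term per output inside the autograd scope; this is a sketch, not what the generator emits:

    import mxnet as mx
    from mxnet import autograd, gluon

    net = MultiOutputNet()  # the invented example net from above
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'adam')

    loss_functions = {
        'class': gluon.loss.SoftmaxCrossEntropyLoss(),
        'value': gluon.loss.L2Loss(),
    }

    x = mx.nd.random.uniform(shape=(8, 20))
    labels = {'class': mx.nd.zeros((8,)), 'value': mx.nd.zeros((8, 1))}

    with autograd.record():
        class_out, value_out = net(x)
        # Sum the per-output losses into a single per-sample loss vector.
        loss = (loss_functions['class'](class_out, labels['class'])
                + loss_functions['value'](value_out, labels['value']))
    loss.backward()
    trainer.step(batch_size=8)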
@@ -113,23 +117,37 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
             train_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
-                image = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                image_data = batch.data[0].as_in_context(mx_context)
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(image_data)
 
-                output = self._net(image)
-                predictions = mx.nd.argmax(output, axis=1)
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
 
-                metric.update(preds=predictions, labels=label)
+                metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
 
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
-                image = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                image_data = batch.data[0].as_in_context(mx_context)
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(image_data)
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
 
-                output = self._net(image)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                metric.update(preds=predictions, labels=labels)
 
             test_metric_score = metric.get()[1]
 
             logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
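
Wrapping labels and predictions in lists matches MXNet's metric API: EvalMetric.update takes parallel lists of label and prediction arrays, one entry per output, so the single-output case becomes one-element lists. A tiny self-contained check:

    import mxnet as mx

    metric = mx.metric.create('accuracy')
    preds = [mx.nd.array([0, 1, 1])]   # argmax'd class indices for one output
    labels = [mx.nd.array([0, 1, 0])]
    metric.update(preds=preds, labels=labels)
    print(metric.get())  # ('accuracy', 0.666...)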