diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
index a55ed53f6a514161ef6d3298594481d7b592a25d..e213c30b0dc0d166bdd204274c72f12995cc9111 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
@@ -15,8 +15,8 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
         return true;
     }
 
-    /*protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
+    protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
         return true;
-    }*/
+    }
 
 }
diff --git a/src/main/resources/templates/gluon/CNNNet.ftl b/src/main/resources/templates/gluon/CNNNet.ftl
index d032cd4a4c31ae46f03a63e8da356f0386ebb178..0c75c37dded0dd72d7947bb155ae1d4dd8811beb 100644
--- a/src/main/resources/templates/gluon/CNNNet.ftl
+++ b/src/main/resources/templates/gluon/CNNNet.ftl
@@ -71,8 +71,15 @@ class NoNormalization(gluon.HybridBlock):
 class Net(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
         super(Net, self).__init__(**kwargs)
+        self.last_layers = {}
         with self.name_scope():
 ${tc.include(tc.architecture.streams[0], "ARCHITECTURE_DEFINITION")}
 
     def hybrid_forward(self, F, ${tc.join(tc.architectureInputs, ", ")}):
-${tc.include(tc.architecture.streams[0], "FORWARD_FUNCTION")}
\ No newline at end of file
+        <#if tc.architectureOutputs?size gt 1>
+        outputs = []
+        </#if>
+${tc.include(tc.architecture.streams[0], "FORWARD_FUNCTION")}
+        <#if tc.architectureOutputs?size gt 1>
+        return tuple(outputs)
+        </#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
index 04aa5a657b1abfb14643a75233f328307345ac27..08a640b6c66ac05c598fa0bef4e19b093e99854a 100644
--- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
@@ -69,15 +69,18 @@ class ${tc.fileNameWithoutEnding}:
         trainer = mx.gluon.Trainer(self._net.collect_params(), optimizer, optimizer_params)
 
-        if self._net.last_layer == 'softmax':
-            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
-        elif self._net.last_layer == 'sigmoid':
-            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
-        elif self._net.last_layer == 'linear':
-            loss_function = mx.gluon.loss.L2Loss()
-        else:
-            loss_function = mx.gluon.loss.L2Loss()
-            logging.warning("Invalid last_layer, defaulting to L2 loss")
+        loss_functions = {}
+
+        for output_name, last_layer in self._net.last_layers.items():
+            if last_layer == 'softmax':
+                loss_functions[output_name] = mx.gluon.loss.SoftmaxCrossEntropyLoss()
+            elif last_layer == 'sigmoid':
+                loss_functions[output_name] = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
+            elif last_layer == 'linear':
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+            else:
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+                logging.warning("Invalid last layer, defaulting to L2 loss")
 
         speed_period = 50
         tic = None
 
@@ -86,13 +89,16 @@ class ${tc.fileNameWithoutEnding}:
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
                 <#list tc.architectureInputs as input_name>
-                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                ${input_name}_data = batch.data[${input_name?index}].as_in_context(mx_context)
+                </#list>
+                <#list tc.architectureOutputs as output_name>
+                ${output_name}_label = batch.label[${output_name?index}].as_in_context(mx_context)
                 </#list>
-                label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
-                    output = self._net(${tc.join(tc.architectureInputs, ",")})
-                    loss = loss_function(output, label)
+                    ${tc.join(tc.architectureOutputs, ", ", "", "_output")} = self._net(${tc.join(tc.architectureInputs, ", ", "", "_data")})
+
+                    loss = <#list tc.architectureOutputs as output_name>loss_functions['${output_name}'](${output_name}_output, ${output_name}_label)<#sep> + </#list>
 
                 loss.backward()
                 trainer.step(batch_size)
@@ -116,26 +122,40 @@ class ${tc.fileNameWithoutEnding}:
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
                 <#list tc.architectureInputs as input_name>
-                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                ${input_name}_data = batch.data[${input_name?index}].as_in_context(mx_context)
                 </#list>
-                label = batch.label[0].as_in_context(mx_context)
-
-                output = self._net(${tc.join(tc.architectureInputs, ",")})
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+
+                labels = [
+                    <#list tc.architectureOutputs as output_name>batch.label[${output_name?index}].as_in_context(mx_context)<#sep>, </#list>
+                ]
+
+                ${tc.join(tc.architectureOutputs, ", ", "", "_output")} = self._net(${tc.join(tc.architectureInputs, ", ", "", "_data")})
+
+                predictions = [
+                    <#list tc.architectureOutputs as output_name>mx.nd.argmax(${output_name}_output, axis=1)<#sep>, </#list>
+                ]
+
+                metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
 
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
                 <#list tc.architectureInputs as input_name>
-                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                ${input_name}_data = batch.data[${input_name?index}].as_in_context(mx_context)
                 </#list>
-                label = batch.label[0].as_in_context(mx_context)
-
-                output = self._net(${tc.join(tc.architectureInputs, ",")})
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+
+                labels = [
+                    <#list tc.architectureOutputs as output_name>batch.label[${output_name?index}].as_in_context(mx_context)<#sep>, </#list>
+                ]
+
+                ${tc.join(tc.architectureOutputs, ", ", "", "_output")} = self._net(${tc.join(tc.architectureInputs, ", ", "", "_data")})
+
+                predictions = [
+                    <#list tc.architectureOutputs as output_name>mx.nd.argmax(${output_name}_output, axis=1)<#sep>, </#list>
+                ]
+
+                metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
 
             logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
diff --git a/src/main/resources/templates/gluon/elements/Output.ftl b/src/main/resources/templates/gluon/elements/Output.ftl
index 2d18ef6784e2808cf18de62a22b28ca970bb4b2b..0b48633eb5c6472419f9f8c83f3af18198a76834 100644
--- a/src/main/resources/templates/gluon/elements/Output.ftl
+++ b/src/main/resources/templates/gluon/elements/Output.ftl
@@ -2,13 +2,17 @@
 <#assign mode = definition_mode.toString()>
 <#if mode == "ARCHITECTURE_DEFINITION">
 <#if element.softmaxOutput>
-        self.last_layer = 'softmax'
+        self.last_layers['${element.name}'] = 'softmax'
 <#elseif element.logisticRegressionOutput>
-        self.last_layer = 'sigmoid'
+        self.last_layers['${element.name}'] = 'sigmoid'
 <#elseif element.linearRegressionOutput>
-        self.last_layer = 'linear'
+        self.last_layers['${element.name}'] = 'linear'
 </#if>
 </#if>
 <#if mode == "FORWARD_FUNCTION">
+<#if tc.architectureOutputs?size gt 1>
+        outputs.append(${input})
+<#else>
         return ${input}
+</#if>
 </#if>
diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
index 805d9b1a06a1ef4d9dae82f048b191f0feff786c..347041b1fa6f47b3bc65f8c204a7601654c8b854 100644
--- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
+++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
@@ -127,7 +127,7 @@ public class GenerationTest extends AbstractSymtabTest {
         Log.getFindings().clear();
         String[] args = {"-m", "src/test/resources/invalid_tests", "-r", "MultipleOutputs"};
         CNNArch2GluonCli.main(args);
-        assertTrue(Log.getFindings().size() == 3);
+        assertTrue(Log.getFindings().size() == 2);
     }
 
     @Test
diff --git a/src/test/resources/target_code/CNNNet_Alexnet.py b/src/test/resources/target_code/CNNNet_Alexnet.py
index 3c142ca9d3946bf7f57b5f02c2c4806aaa839b78..9d6dc8901bcad040625712ca59eb4763f770e8b3 100644
--- a/src/test/resources/target_code/CNNNet_Alexnet.py
+++ b/src/test/resources/target_code/CNNNet_Alexnet.py
@@ -71,6 +71,7 @@ class NoNormalization(gluon.HybridBlock):
 class Net(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
         super(Net, self).__init__(**kwargs)
+        self.last_layers = {}
         with self.name_scope():
             if data_mean:
                 assert(data_std)
@@ -194,8 +195,7 @@ class Net(gluon.HybridBlock):
             self.fc8_ = gluon.nn.Dense(units=10, use_bias=True)
             # fc8_, output shape: {[10,1,1]}
 
-
-        self.last_layer = 'softmax'
+        self.last_layers['predictions'] = 'softmax'
 
 
     def hybrid_forward(self, F, data):
@@ -261,3 +261,4 @@ class Net(gluon.HybridBlock):
         dropout7_ = self.dropout7_(relu7_)
         fc8_ = self.fc8_(dropout7_)
         return fc8_
+
diff --git a/src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
index 3b1794a054c3c9c0f16e3d9b6579a97b8a7acdca..d4ecbf96c4d42abf842eceb984e6515fa59b22f6 100644
--- a/src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
@@ -71,6 +71,7 @@ class NoNormalization(gluon.HybridBlock):
 class Net(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
         super(Net, self).__init__(**kwargs)
+        self.last_layers = {}
         with self.name_scope():
             if data_mean:
                 assert(data_std)
@@ -345,8 +346,7 @@ class Net(gluon.HybridBlock):
             self.fc32_ = gluon.nn.Dense(units=10, use_bias=True)
             # fc32_, output shape: {[10,1,1]}
 
-
-        self.last_layer = 'softmax'
+        self.last_layers['softmax'] = 'softmax'
 
 
     def hybrid_forward(self, F, data):
@@ -454,3 +454,4 @@ class Net(gluon.HybridBlock):
         dropout31_ = self.dropout31_(fc31_)
         fc32_ = self.fc32_(dropout31_)
         return fc32_
+
diff --git a/src/test/resources/target_code/CNNNet_VGG16.py b/src/test/resources/target_code/CNNNet_VGG16.py
index 6e633348223207514edc342c621bc98ecb33d57e..5fc931c97a0c73ae4d39e22ef50f6e978afa7e46 100644
--- a/src/test/resources/target_code/CNNNet_VGG16.py
+++ b/src/test/resources/target_code/CNNNet_VGG16.py
@@ -71,6 +71,7 @@ class NoNormalization(gluon.HybridBlock):
 class Net(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
         super(Net, self).__init__(**kwargs)
+        self.last_layers = {}
         with self.name_scope():
             if data_mean:
                 assert(data_std)
@@ -222,8 +223,7 @@ class Net(gluon.HybridBlock):
             self.fc15_ = gluon.nn.Dense(units=1000, use_bias=True)
             # fc15_, output shape: {[1000,1,1]}
 
-
-        self.last_layer = 'softmax'
+        self.last_layers['predictions'] = 'softmax'
 
 
     def hybrid_forward(self, F, data):
@@ -281,3 +281,4 @@ class Net(gluon.HybridBlock):
         dropout15_ = self.dropout15_(relu15_)
         fc15_ = self.fc15_(dropout15_)
         return fc15_
+
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
index 01e32cda7cd4a5e1451c24f33f21c606ebef314e..3355e6643cdf81636a24fb7a59cc1c305b0c2a8b 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
@@ -69,15 +69,18 @@ class CNNSupervisedTrainer_Alexnet:
         trainer = mx.gluon.Trainer(self._net.collect_params(), optimizer, optimizer_params)
 
-        if self._net.last_layer == 'softmax':
-            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
-        elif self._net.last_layer == 'sigmoid':
-            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
-        elif self._net.last_layer == 'linear':
-            loss_function = mx.gluon.loss.L2Loss()
-        else:
-            loss_function = mx.gluon.loss.L2Loss()
-            logging.warning("Invalid last_layer, defaulting to L2 loss")
+        loss_functions = {}
+
+        for output_name, last_layer in self._net.last_layers.items():
+            if last_layer == 'softmax':
+                loss_functions[output_name] = mx.gluon.loss.SoftmaxCrossEntropyLoss()
+            elif last_layer == 'sigmoid':
+                loss_functions[output_name] = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
+            elif last_layer == 'linear':
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+            else:
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+                logging.warning("Invalid last layer, defaulting to L2 loss")
 
         speed_period = 50
         tic = None
 
@@ -85,12 +88,13 @@ class CNNSupervisedTrainer_Alexnet:
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+                predictions_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
-                    output = self._net(data)
-                    loss = loss_function(output, label)
+                    predictions_output = self._net(data_data)
+
+                    loss = loss_functions['predictions'](predictions_output, predictions_label)
 
                 loss.backward()
                 trainer.step(batch_size)
@@ -113,23 +117,37 @@ class CNNSupervisedTrainer_Alexnet:
             train_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(data_data)
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
+
+                metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
 
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(data_data)
+
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
 
             logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
index 05d4e708be9052ca376b40767fa99739a788da1a..3a4095116bb681e71b26950682a3b002e8b192a4 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
@@ -69,15 +69,18 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
         trainer = mx.gluon.Trainer(self._net.collect_params(), optimizer, optimizer_params)
 
-        if self._net.last_layer == 'softmax':
-            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
-        elif self._net.last_layer == 'sigmoid':
-            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
-        elif self._net.last_layer == 'linear':
-            loss_function = mx.gluon.loss.L2Loss()
-        else:
-            loss_function = mx.gluon.loss.L2Loss()
-            logging.warning("Invalid last_layer, defaulting to L2 loss")
+        loss_functions = {}
+
+        for output_name, last_layer in self._net.last_layers.items():
+            if last_layer == 'softmax':
+                loss_functions[output_name] = mx.gluon.loss.SoftmaxCrossEntropyLoss()
+            elif last_layer == 'sigmoid':
+                loss_functions[output_name] = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
+            elif last_layer == 'linear':
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+            else:
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+                logging.warning("Invalid last layer, defaulting to L2 loss")
 
         speed_period = 50
         tic = None
 
@@ -85,12 +88,13 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+                softmax_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
-                    output = self._net(data)
-                    loss = loss_function(output, label)
+                    softmax_output = self._net(data_data)
+
+                    loss = loss_functions['softmax'](softmax_output, softmax_label)
 
                 loss.backward()
                 trainer.step(batch_size)
@@ -113,23 +117,37 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
             train_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                softmax_output = self._net(data_data)
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                predictions = [
+                    mx.nd.argmax(softmax_output, axis=1)
+                ]
+
+                metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
 
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                softmax_output = self._net(data_data)
+
+                predictions = [
+                    mx.nd.argmax(softmax_output, axis=1)
+                ]
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
 
             logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
index 90728f4bebc3f78f3915b44161e976d75fbe2d2f..f4c124e2dea236cea4c4f8202571d8ed0aae64cd 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
@@ -69,15 +69,18 @@ class CNNSupervisedTrainer_VGG16:
         trainer = mx.gluon.Trainer(self._net.collect_params(), optimizer, optimizer_params)
 
-        if self._net.last_layer == 'softmax':
-            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
-        elif self._net.last_layer == 'sigmoid':
-            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
-        elif self._net.last_layer == 'linear':
-            loss_function = mx.gluon.loss.L2Loss()
-        else:
-            loss_function = mx.gluon.loss.L2Loss()
-            logging.warning("Invalid last_layer, defaulting to L2 loss")
+        loss_functions = {}
+
+        for output_name, last_layer in self._net.last_layers.items():
+            if last_layer == 'softmax':
+                loss_functions[output_name] = mx.gluon.loss.SoftmaxCrossEntropyLoss()
+            elif last_layer == 'sigmoid':
+                loss_functions[output_name] = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
+            elif last_layer == 'linear':
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+            else:
+                loss_functions[output_name] = mx.gluon.loss.L2Loss()
+                logging.warning("Invalid last layer, defaulting to L2 loss")
 
         speed_period = 50
         tic = None
 
@@ -85,12 +88,13 @@ class CNNSupervisedTrainer_VGG16:
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+                predictions_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
-                    output = self._net(data)
-                    loss = loss_function(output, label)
+                    predictions_output = self._net(data_data)
+
+                    loss = loss_functions['predictions'](predictions_output, predictions_label)
 
                 loss.backward()
                 trainer.step(batch_size)
@@ -113,23 +117,37 @@ class CNNSupervisedTrainer_VGG16:
            train_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(data_data)
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
+
+                metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
 
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
-                data = batch.data[0].as_in_context(mx_context)
-                label = batch.label[0].as_in_context(mx_context)
+                data_data = batch.data[0].as_in_context(mx_context)
+
+                labels = [
+                    batch.label[0].as_in_context(mx_context)
+                ]
+
+                predictions_output = self._net(data_data)
+
+                predictions = [
+                    mx.nd.argmax(predictions_output, axis=1)
+                ]
 
-                output = self._net(data)
-                predictions = mx.nd.argmax(output, axis=1)
-                metric.update(preds=predictions, labels=label)
+                metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
 
             logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))