From c3dc2a23e0edd60fec9b8d20a08f21ed4528b89c Mon Sep 17 00:00:00 2001
From: "julian.treiber"
Date: Sat, 7 Mar 2020 09:47:30 +0100
Subject: [PATCH] added metric loss accuracy_ignore_label

---
 .../templates/gluon/CNNSupervisedTrainer.ftl | 28 +++++++++++++++++--
 .../resources/templates/gluon/CNNTrainer.ftl |  2 +-
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
index df5b75a..5dfbe4a 100644
--- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
@@ -87,15 +87,35 @@ class DiceLoss(gluon.loss.Loss):
         diceloss = self.dice_loss(F, pred, label)
         return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
 
+class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
+    def __init__(self, axis=-1, from_logits=False, weight=None,
+                 batch_axis=0, ignore_label=255, **kwargs):
+        super(SoftmaxCrossEntropyLossIgnoreLabel, self).__init__(weight, batch_axis, **kwargs)
+        self._axis = axis
+        self._from_logits = from_logits
+        self._ignore_label = ignore_label
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        if not self._from_logits:
+            output = F.log_softmax(output, axis=self._axis)
+
+        valid_label_map = (label != self._ignore_label)
+        loss = -(F.pick(output, label, axis=self._axis, keepdims=True) * valid_label_map)
+
+        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
+        return F.sum(loss) / F.sum(valid_label_map)
+
 @mx.metric.register
 class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
-    def __init__(self, axis=1, ignore_label=255, name='accuracy',
+    """Ignores a label when computing accuracy.
+    """
+    def __init__(self, axis=1, metric_ignore_label=255, name='accuracy',
                  output_names=None, label_names=None):
         super(ACCURACY_IGNORE_LABEL, self).__init__(
             name, axis=axis,
             output_names=output_names, label_names=label_names)
         self.axis = axis
-        self.ignore_label = ignore_label
+        self.ignore_label = metric_ignore_label
 
     def update(self, labels, preds):
         mx.metric.check_label_shapes(labels, preds)
@@ -328,6 +348,10 @@ class ${tc.fileNameWithoutEnding}:
         elif loss == 'dice_loss':
             loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
             loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
+        elif loss == 'softmax_cross_entropy_ignore_label':
+            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
+            loss_ignore_label = loss_params['loss_ignore_label'] if 'loss_ignore_label' in loss_params else None
+            loss_function = SoftmaxCrossEntropyLossIgnoreLabel(axis=loss_axis, ignore_label=loss_ignore_label, weight=loss_weight, batch_axis=batch_axis)
         elif loss == 'l2':
             loss_function = mx.gluon.loss.L2Loss()
         elif loss == 'l1':
diff --git a/src/main/resources/templates/gluon/CNNTrainer.ftl b/src/main/resources/templates/gluon/CNNTrainer.ftl
index 48d5be2..5b5e79b 100644
--- a/src/main/resources/templates/gluon/CNNTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNTrainer.ftl
@@ -71,7 +71,7 @@ if __name__ == "__main__":
             'axis': ${config.evalMetric.axis},
 <#if (config.evalMetric.exclude)??>
-            'ignore_label': ${config.evalMetric.ignore_label},
+            'metric_ignore_label': ${config.evalMetric.metric_ignore_label},
         },
-- 
GitLab
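
Note for reviewers: to sanity-check the masking logic of the new `SoftmaxCrossEntropyLossIgnoreLabel` outside the generated trainer, the same computation can be reproduced with plain NDArray ops. This is a standalone sketch with toy values, not part of the patch; the explicit `clip` before `pick` is only there to keep the example self-contained.

```python
import mxnet as mx
from mxnet import nd

IGNORE_LABEL = 255

# Two "pixels" with three classes each; the second one carries the ignore label.
logits = nd.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3]])
labels = nd.array([0, IGNORE_LABEL])

log_probs = nd.log_softmax(logits, axis=-1)

# 1.0 where a real label is present, 0.0 where the label should be ignored.
valid = (labels != IGNORE_LABEL)

# Clip the label before pick so the out-of-range ignore value is safe to index;
# the picked value at ignored positions is zeroed out by the mask anyway.
picked = nd.pick(log_probs, labels.clip(0, logits.shape[-1] - 1), axis=-1)
loss = -(picked * valid)

# Normalize by the number of valid labels only, as the patched loss does.
mean_loss = nd.sum(loss) / nd.sum(valid)
print(mean_loss.asscalar())  # cross-entropy of the first pixel only
```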
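The rename from `ignore_label` to `metric_ignore_label` in CNNTrainer.ftl presumably keeps the metric's keyword distinct from the loss-side `loss_ignore_label` parameter, assuming the generated trainer forwards `eval_metric_params` as keyword arguments when the metric is instantiated. A rough, illustrative sketch of that call with placeholder values (it assumes the `ACCURACY_IGNORE_LABEL` class from the patched CNNSupervisedTrainer template is already defined in the process, since `@mx.metric.register` exposes it under its lower-cased class name):

```python
import mxnet as mx

# Placeholder values standing in for ${config.evalMetric.*}; the metric class
# must already be registered by importing the generated supervised trainer.
eval_metric_params = {
    'axis': 1,
    'metric_ignore_label': 255,
}

# mx.metric.create passes the params straight into ACCURACY_IGNORE_LABEL.__init__,
# which is why the dict key must match the renamed keyword argument.
metric = mx.metric.create('accuracy_ignore_label', **eval_metric_params)
```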