Commit c3dc2a23 authored by Julian Treiber's avatar Julian Treiber

added metric loss accuracy_ignore_label

parent 2990a673
...@@ -87,15 +87,35 @@ class DiceLoss(gluon.loss.Loss): ...@@ -87,15 +87,35 @@ class DiceLoss(gluon.loss.Loss):
diceloss = self.dice_loss(F, pred, label) diceloss = self.dice_loss(F, pred, label)
return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
def __init__(self, axis=-1, from_logits=False, weight=None,
batch_axis=0, ignore_label=255, **kwargs):
super(SoftmaxCrossEntropyLossIgnoreLabel, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._from_logits = from_logits
self._ignore_label = ignore_label
def hybrid_forward(self, F, output, label, sample_weight=None):
if not self._from_logits:
output = F.log_softmax(output, axis=self._axis)
valid_label_map = (label != self._ignore_label)
loss = -(F.pick(output, label, axis=self._axis, keepdims=True) * valid_label_map )
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.sum(loss) / F.sum(valid_label_map)
@mx.metric.register @mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric): class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
def __init__(self, axis=1, ignore_label=255, name='accuracy', """Ignores a label when computing accuracy.
"""
def __init__(self, axis=1, metric_ignore_label=255, name='accuracy',
output_names=None, label_names=None): output_names=None, label_names=None):
super(ACCURACY_IGNORE_LABEL, self).__init__( super(ACCURACY_IGNORE_LABEL, self).__init__(
name, axis=axis, name, axis=axis,
output_names=output_names, label_names=label_names) output_names=output_names, label_names=label_names)
self.axis = axis self.axis = axis
self.ignore_label = ignore_label self.ignore_label = metric_ignore_label
def update(self, labels, preds): def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds) mx.metric.check_label_shapes(labels, preds)
...@@ -328,6 +348,10 @@ class ${tc.fileNameWithoutEnding}: ...@@ -328,6 +348,10 @@ class ${tc.fileNameWithoutEnding}:
elif loss == 'dice_loss': elif loss == 'dice_loss':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_label':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_ignore_label = loss_params['loss_ignore_label'] if 'loss_ignore_label' in loss_params else None
loss_function = SoftmaxCrossEntropyLossIgnoreLabel(axis=loss_axis, ignore_label=loss_ignore_label, weight=loss_weight, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -71,7 +71,7 @@ if __name__ == "__main__": ...@@ -71,7 +71,7 @@ if __name__ == "__main__":
'axis': ${config.evalMetric.axis}, 'axis': ${config.evalMetric.axis},
</#if> </#if>
<#if (config.evalMetric.exclude)??> <#if (config.evalMetric.exclude)??>
'ignore_label': ${config.evalMetric.ignore_label}, 'metric_ignore_label': ${config.evalMetric.metric_ignore_label},
</#if> </#if>
}, },
</#if> </#if>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment