diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
index bf67e888db8d9d4b657d56350712a7c9b84d810a..057a647071f3ec0aa9434a2d34171a5e00d48819 100644
--- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
@@ -58,6 +58,38 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
             loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
         return loss.mean(axis=self._batch_axis, exclude=True)
 
+class DiceLoss(gluon.loss.Loss):
+    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
+                 batch_axis=0, **kwargs):
+        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._axis = axis
+        self._sparse_label = sparse_label
+        self._from_logits = from_logits
+
+    def dice_loss(self, F, pred, label):
+        # Hard Dice score on the argmax prediction; assumes binary {0, 1} labels.
+        # The smoothing term keeps the score finite when both masks are empty.
+        smooth = 1.
+        pred_y = F.argmax(pred, axis=self._axis)
+        intersection = pred_y * label
+        score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
+            / (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)
+
+        return -F.log(score)
+
+    def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if not self._from_logits:
+            pred = F.log_softmax(pred, self._axis)
+        if self._sparse_label:
+            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
+        else:
+            label = gluon.loss._reshape_like(F, label, pred)
+            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
+        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
+        # Total loss: weighted cross-entropy plus the negative log Dice score.
+        diceloss = self.dice_loss(F, pred, label)
+        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
+
 @mx.metric.register
 class BLEU(mx.metric.EvalMetric):
     N = 4
@@ -261,11 +293,15 @@ class ${tc.fileNameWithoutEnding}:
             loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'softmax_cross_entropy_ignore_indices':
             fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
-            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
+            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'sigmoid_binary_cross_entropy':
             loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
         elif loss == 'cross_entropy':
             loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
+        elif loss == 'dice_loss':
+            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
+            dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
+            loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
         elif loss == 'l2':
             loss_function = mx.gluon.loss.L2Loss()
         elif loss == 'l1':
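
For reviewers, a minimal smoke test of the new loss. This is a sketch, not part of the template: it assumes `DiceLoss` has been brought into scope (e.g. copied from a generated trainer file), that labels are sparse binary class indices, and that the class axis is last. The shapes are made up for the test.

```python
from mxnet import nd

# Sketch only: DiceLoss is assumed to be in scope, e.g. pasted from a
# generated CNNSupervisedTrainer_<network>.py (name hypothetical).
loss_fn = DiceLoss(axis=-1, sparse_label=True, from_logits=False, batch_axis=0)

pred = nd.random.uniform(shape=(2, 16, 16, 2))                        # (batch, H, W, classes) logits
label = nd.random.randint(0, 2, shape=(2, 16, 16)).astype('float32')  # binary segmentation mask

loss = loss_fn(pred, label)   # cross-entropy term plus -log(Dice score)
print(loss.shape)             # (2,): one scalar per batch element
```

Note that `argmax` is not differentiable, so the Dice term raises the reported loss where the hard prediction disagrees with the mask, while the gradient comes from the cross-entropy term.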
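
The new branch is driven entirely by the `loss` string and the `loss_params` dictionary the trainer already receives; only two keys are read for `'dice_loss'`. A hypothetical configuration (key names from the template, values made up):

```python
# Hypothetical training configuration selecting the new branch.
loss = 'dice_loss'
loss_params = {
    'from_logits': False,  # False: the loss applies log_softmax itself
    'dice_weight': 0.5,    # forwarded as gluon's base loss weight
}
```

Because `dice_weight` is forwarded as the `weight` argument of `gluon.loss.Loss`, it rescales only the cross-entropy part of the sum; the `-log(score)` Dice term is added unweighted.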