Commit dd274e3f authored by Julian Treiber

updated tests for DiceLoss

parent 55bed0c3
......@@ -50,13 +50,42 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
    """Softmax cross-entropy loss with an additional negative-log Dice term.

    The forward pass computes the cross-entropy between ``pred`` and
    ``label``, applies the optional weighting, averages it over every axis
    except ``batch_axis`` and adds ``-log(dice_score)``.

    Parameters
    ----------
    axis : int, default -1
        Axis of ``pred`` along which the classes are laid out.
    sparse_label : bool, default True
        If True, ``label`` holds class indices. If False, ``label`` holds
        per-class values and is reshaped to the shape of ``pred``.
    from_logits : bool, default False
        If False, ``log_softmax`` is applied to ``pred`` before the loss is
        computed; if True, ``pred`` is used as given.
    weight : float or None
        Global weight forwarded to ``gluon.loss.Loss``.
    batch_axis : int, default 0
        Axis that is kept when the loss is averaged.
    """

    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def dice_loss(self, F, pred, label):
        """Return ``-log`` of the smoothed Dice score of ``argmax(pred)`` vs. ``label``.

        ``label`` must contain class indices with the same shape as
        ``argmax(pred, axis=self._axis)``; the two are multiplied element-wise.
        """
        # Additive smoothing keeps the ratio defined when both means are zero.
        smooth = 1.
        pred_y = F.argmax(pred, axis=self._axis)
        intersection = pred_y * label
        numerator = 2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth
        denominator = (F.mean(label, axis=self._batch_axis, exclude=True)
                       + F.mean(pred_y, axis=self._batch_axis, exclude=True)
                       + smooth)
        return -F.log(numerator / denominator)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Compute the weighted cross-entropy plus the Dice term.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon (``mx.nd`` or ``mx.sym``).
        pred : tensor
            Predictions; treated as log-probabilities when ``from_logits``
            is True, otherwise passed through ``log_softmax`` first.
        label : tensor
            Class indices (sparse) or per-class values (dense).
        sample_weight : tensor or None
            Optional weighting applied to the cross-entropy part only.
        """
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
            dice_label = label
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
            # A dense label has pred's shape, while dice_loss compares against
            # argmax(pred), which has the class axis removed. Collapse the
            # dense label to class indices so both operands share one shape.
            dice_label = F.argmax(label, axis=self._axis)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        diceloss = self.dice_loss(F, pred, dice_label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -244,14 +273,18 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......@@ -323,7 +356,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
train_test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......@@ -394,7 +427,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......
......@@ -50,13 +50,42 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
    """Softmax cross-entropy loss with an additional negative-log Dice term.

    The forward pass computes the cross-entropy between ``pred`` and
    ``label``, applies the optional weighting, averages it over every axis
    except ``batch_axis`` and adds ``-log(dice_score)``.

    Parameters
    ----------
    axis : int, default -1
        Axis of ``pred`` along which the classes are laid out.
    sparse_label : bool, default True
        If True, ``label`` holds class indices. If False, ``label`` holds
        per-class values and is reshaped to the shape of ``pred``.
    from_logits : bool, default False
        If False, ``log_softmax`` is applied to ``pred`` before the loss is
        computed; if True, ``pred`` is used as given.
    weight : float or None
        Global weight forwarded to ``gluon.loss.Loss``.
    batch_axis : int, default 0
        Axis that is kept when the loss is averaged.
    """

    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def dice_loss(self, F, pred, label):
        """Return ``-log`` of the smoothed Dice score of ``argmax(pred)`` vs. ``label``.

        ``label`` must contain class indices with the same shape as
        ``argmax(pred, axis=self._axis)``; the two are multiplied element-wise.
        """
        # Additive smoothing keeps the ratio defined when both means are zero.
        smooth = 1.
        pred_y = F.argmax(pred, axis=self._axis)
        intersection = pred_y * label
        numerator = 2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth
        denominator = (F.mean(label, axis=self._batch_axis, exclude=True)
                       + F.mean(pred_y, axis=self._batch_axis, exclude=True)
                       + smooth)
        return -F.log(numerator / denominator)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Compute the weighted cross-entropy plus the Dice term.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon (``mx.nd`` or ``mx.sym``).
        pred : tensor
            Predictions; treated as log-probabilities when ``from_logits``
            is True, otherwise passed through ``log_softmax`` first.
        label : tensor
            Class indices (sparse) or per-class values (dense).
        sample_weight : tensor or None
            Optional weighting applied to the cross-entropy part only.
        """
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
            dice_label = label
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
            # A dense label has pred's shape, while dice_loss compares against
            # argmax(pred), which has the class axis removed. Collapse the
            # dense label to class indices so both operands share one shape.
            dice_label = F.argmax(label, axis=self._axis)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        diceloss = self.dice_loss(F, pred, dice_label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -260,11 +289,15 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment