Commit 7ba36d4c authored by Julian Treiber's avatar Julian Treiber

fix for DiceLoss

parent f8468f9d
Pipeline #236672 failed with stages
in 2 minutes and 42 seconds
...@@ -298,7 +298,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -298,7 +298,7 @@ class ${tc.fileNameWithoutEnding}:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_Alexnet:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_Invariant: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_Invariant:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_MultipleStreams: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_MultipleStreams:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNencdec: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNencdec:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNsearch: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNsearch:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNtest: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_RNNtest:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_ResNeXt50: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_ResNeXt50:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_Show_attend_tell: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_Show_attend_tell:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14: ...@@ -290,7 +290,7 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_VGG16: ...@@ -297,7 +297,7 @@ class CNNSupervisedTrainer_VGG16:
elif loss == 'dice_loss': elif loss == 'dice_loss':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis) loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment