Commit b9a79c99 authored by Julian Treiber's avatar Julian Treiber

adjusted test for loss_weight gluon

parent 2cc6d4c6
......@@ -282,8 +282,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......
......@@ -295,8 +295,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment