Commit b2cecf35 authored by Julian Treiber's avatar Julian Treiber

added loss_axis to Trainer template

parent 66d4211e
Pipeline #235145 failed with stages
in 1 minute and 11 seconds
...@@ -254,16 +254,17 @@ class ${tc.fileNameWithoutEnding}: ...@@ -254,16 +254,17 @@ class ${tc.fileNameWithoutEnding}:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0 margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
lossAxis = loss_params['lossAxis'] if 'lossAxis' in loss_params else -1
if loss == 'softmax_cross_entropy': if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel) loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=lossAxis, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'softmax_cross_entropy_ignore_indices': elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
elif loss == 'sigmoid_binary_cross_entropy': elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy': elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel) loss_function = CrossEntropyLoss(axis=lossAxis, sparse_label=sparseLabel)
elif loss == 'l2': elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss() loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1': elif loss == 'l1':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment