diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl index 01b3ab7dbaa7e58b68128dadc7a6d9b67ef21728..e47a002301975b7b79d999eb12012c3d0c3dd62a 100644 --- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl +++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl @@ -255,16 +255,17 @@ class ${tc.fileNameWithoutEnding}: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py index 5d38cbb0b27163467596847f969c89aea1a61d61..2f17688bd4e481792071b45dd78a0d88d57b15c4 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py @@ -254,16 +254,17 @@ class CNNSupervisedTrainer_Alexnet: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py index 4e52f0d31a47ca8453f0c37a13e5c283f6ce3681..54f4c69af69f6fa5f68bf9ab0cc81e98307e9148 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py @@ -254,16 +254,17 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Invariant.py b/src/test/resources/target_code/CNNSupervisedTrainer_Invariant.py index a31eac6b37df005c713e5620f8a13dbda468acc7..cc63d68919f623789628b31ff137922f303bfe93 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_Invariant.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_Invariant.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_Invariant: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_MultipleStreams.py b/src/test/resources/target_code/CNNSupervisedTrainer_MultipleStreams.py index f8db6d2dfd7d2403aa0923ed157282455562aadb..0903c286d7d289c8bbb24bb291213a7764c4aaec 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_MultipleStreams.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_MultipleStreams.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_MultipleStreams: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_RNNencdec.py b/src/test/resources/target_code/CNNSupervisedTrainer_RNNencdec.py index 963fb7bd94a1e880fb807c5ea45fdf4e2580c8e9..53765c13399c2bfaa684cf9a82d0452afb3a076d 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_RNNencdec.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_RNNencdec.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_RNNencdec: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_RNNsearch.py b/src/test/resources/target_code/CNNSupervisedTrainer_RNNsearch.py index fd12e18f6052fe40823ea1c0e8784110f103e9ea..068ca8c466d1934a8ed05b950fed6917481e5cae 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_RNNsearch.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_RNNsearch.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_RNNsearch: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_RNNtest.py b/src/test/resources/target_code/CNNSupervisedTrainer_RNNtest.py index f7d34328c12925afe98eee62ae0e94ed38a95752..052800cbb9bc67eafc94abc0d9503bde74b8055a 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_RNNtest.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_RNNtest.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_RNNtest: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py b/src/test/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py index b9f60cea5041c744cf73ed614e11ec38a2c8208f..957ae83997a7177262691a59318cdb977ff76106 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_ResNeXt50: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Show_attend_tell.py b/src/test/resources/target_code/CNNSupervisedTrainer_Show_attend_tell.py index c428ed0c1cf9de80004c4562574673ed8f822fd2..a223bc409a85a0d95c66b6244ec272124df2c28e 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_Show_attend_tell.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_Show_attend_tell.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_Show_attend_tell: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py b/src/test/resources/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py index d4c05848f279cedddf450623fdad13faccce7663..7ed3647dc11e8e822364e8b94aae163421438b67 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py @@ -247,16 +247,17 @@ class CNNSupervisedTrainer_ThreeInputCNN_M14: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1': diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py index 9df9611a1ad66077e355dd5dfab981fb122e42da..7f4d8a87e3e11d9950346eed91790d292c4f6a28 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py @@ -254,16 +254,17 @@ class CNNSupervisedTrainer_VGG16: sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else [] loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1 + batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0 if loss == 'softmax_cross_entropy': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'softmax_cross_entropy_ignore_indices': fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False - loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel) + loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'sigmoid_binary_cross_entropy': loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss() elif loss == 'cross_entropy': - loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel) + loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis) elif loss == 'l2': loss_function = mx.gluon.loss.L2Loss() elif loss == 'l1':