diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 191a0cb1eb1eec097cc3fdee98c5107ead9f9dd1..fe392d52ddb44e928a7dda9c2235d6dd8a05ce2f 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -107,6 +107,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue; FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue; + LossAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"loss_axis" ":" value:IntegerValue; + BatchAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"batch_axis" ":" value:IntegerValue; IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue; MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue; LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 531b6417f07b1a1656568c1e4f625cdb9a402175..2fd8436a63d4774e43df49849b96056e0b9cdee6 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -47,6 +47,8 @@ class ParameterAlgorithmMapping { ASTLossEntry.class, ASTLossWeightsEntry.class, ASTSparseLabelEntry.class, + ASTLossAxisEntry.class, + ASTBatchAxisEntry.class, ASTFromLogitsEntry.class, ASTIgnoreIndicesEntry.class, ASTMarginEntry.class, diff --git a/src/test/resources/valid_tests/FullConfig.cnnt b/src/test/resources/valid_tests/FullConfig.cnnt index 1d054f6d9d81ed48eb49d910b7b8c109850ebbff..76319ac420cb0cfe3c16876f8983f7bef1ee49c4 100644 --- a/src/test/resources/valid_tests/FullConfig.cnnt +++ b/src/test/resources/valid_tests/FullConfig.cnnt @@ -7,6 +7,8 @@ configuration FullConfig{ loss: softmax_cross_entropy{ sparse_label: true from_logits: true + loss_axis : -1 + batch_axis: 0 } context : gpu normalize : true