Commit 687465cb authored by Julian Treiber's avatar Julian Treiber

added batch_axis parameter for CrossEntropy losses

parent c0f3254a
Pipeline #235630 passed with stages
in 7 minutes and 58 seconds
......@@ -108,6 +108,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
LossAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"loss_axis" ":" value:IntegerValue;
BatchAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"batch_axis" ":" value:IntegerValue;
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
......
......@@ -48,6 +48,7 @@ class ParameterAlgorithmMapping {
ASTLossWeightsEntry.class,
ASTSparseLabelEntry.class,
ASTLossAxisEntry.class,
ASTBatchAxisEntry.class,
ASTFromLogitsEntry.class,
ASTIgnoreIndicesEntry.class,
ASTMarginEntry.class,
......
......@@ -8,6 +8,7 @@ configuration FullConfig{
sparse_label: true
from_logits: true
loss_axis : -1
batch_axis: 0
}
context : gpu
normalize : true
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment