Commit 687465cb authored by Julian Treiber's avatar Julian Treiber
Browse files

added batch_axis parameter for CrossEntropy losses

parent c0f3254a
Pipeline #235630 passed with stages
in 7 minutes and 58 seconds
...@@ -108,6 +108,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -108,6 +108,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue; SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue; FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
LossAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"loss_axis" ":" value:IntegerValue; LossAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"loss_axis" ":" value:IntegerValue;
BatchAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"batch_axis" ":" value:IntegerValue;
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue; IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue; MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue; LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
......
...@@ -48,6 +48,7 @@ class ParameterAlgorithmMapping { ...@@ -48,6 +48,7 @@ class ParameterAlgorithmMapping {
ASTLossWeightsEntry.class, ASTLossWeightsEntry.class,
ASTSparseLabelEntry.class, ASTSparseLabelEntry.class,
ASTLossAxisEntry.class, ASTLossAxisEntry.class,
ASTBatchAxisEntry.class,
ASTFromLogitsEntry.class, ASTFromLogitsEntry.class,
ASTIgnoreIndicesEntry.class, ASTIgnoreIndicesEntry.class,
ASTMarginEntry.class, ASTMarginEntry.class,
......
...@@ -8,6 +8,7 @@ configuration FullConfig{ ...@@ -8,6 +8,7 @@ configuration FullConfig{
sparse_label: true sparse_label: true
from_logits: true from_logits: true
loss_axis : -1 loss_axis : -1
batch_axis: 0
} }
context : gpu context : gpu
normalize : true normalize : true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment