Commit 22388100 authored by Christian Fuß's avatar Christian Fuß
Browse files

added SoftmaxCrossentropyIgnoreIndices loss. Added use_teacher_forcing flag

parent 3889c267
Pipeline #207386 passed with stages
in 9 minutes and 32 seconds
......@@ -81,6 +81,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface SoftmaxCrossEntropyEntry extends Entry;
SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?;
interface SoftmaxCrossEntropyIgnoreIndicesEntry extends Entry;
SoftmaxCrossEntropyIgnoreIndicesLoss implements LossValue = name:"softmax_cross_entropy_ignore_indices" ("{" params:SoftmaxCrossEntropyIgnoreIndicesEntry* "}")?;
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry;
......@@ -95,8 +98,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
......@@ -140,6 +144,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
UseTeacherForcing implements ConfigEntry = name:"use_teacher_forcing" ":" value:BooleanValue;
// Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
......
......@@ -45,10 +45,12 @@ class ParameterAlgorithmMapping {
ASTLossWeightsEntry.class,
ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class,
ASTIgnoreIndicesEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
ASTPreprocessingEntry.class,
ASTUseTeacherForcing.class,
ASTSaveAttentionImage.class
);
......
......@@ -323,6 +323,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.collect(Collectors.toList());
}
@Override
public void endVisit(ASTUseTeacherForcing node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment