diff --git a/pom.xml b/pom.xml index d7d5e507a748f9f1aa69ee32d4565bde512f6b8d..ddeaf6685f8e06ba6677bd5d7e9069aa25407300 100644 --- a/pom.xml +++ b/pom.xml @@ -18,7 +18,7 @@ de.monticore.lang.monticar cnn-train - 0.3.8-SNAPSHOT + 0.3.9-SNAPSHOT @@ -28,7 +28,7 @@ 5.0.1 1.7.8 0.0.6 - 0.0.14-20180704.113055-2 + 0.0.19-SNAPSHOT 18.0 diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 29ff89b10e12a0656cdcc6ce01abb6a44e8991c2..4f1d7cb05de8dd72c72c6f22034919d04ccc47ce 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -30,6 +30,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; + CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue; + LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; @@ -52,6 +54,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface BleuEntry extends Entry; ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue; + EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue; + LRPolicyValue implements ConfigValue =(fixed:"fixed" | step:"step" | exp:"exp" @@ -81,6 +85,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface SoftmaxCrossEntropyEntry extends Entry; SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?; + interface SoftmaxCrossEntropyIgnoreIndicesEntry extends Entry; + SoftmaxCrossEntropyIgnoreIndicesLoss implements LossValue = name:"softmax_cross_entropy_ignore_indices" ("{" params:SoftmaxCrossEntropyIgnoreIndicesEntry* "}")?; + SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?; interface HingeEntry extends Entry; @@ -95,8 +102,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface KullbackLeiblerEntry extends Entry; KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?; - SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue; - FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue; + SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue; + FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue; + IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue; MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue; LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue; @@ -140,6 +148,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue; RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue; + UseTeacherForcing implements ConfigEntry = name:"use_teacher_forcing" ":" value:BooleanValue; + // Visual attention Extension SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index b23ebb03e8515cf48f3161eb79c722b6888e33dc..0bd42320fcbaa8e67727d58f20560eefb789b576 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -37,7 +37,10 @@ class ParameterAlgorithmMapping { private static final List EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( ASTBatchSizeEntry.class, ASTLoadCheckpointEntry.class, + ASTCheckpointPeriodEntry.class, + ASTLogPeriodEntry.class, ASTEvalMetricEntry.class, + ASTEvalTrainEntry.class, ASTExcludeBleuEntry.class, ASTNormalizeEntry.class, ASTNumEpochEntry.class, @@ -45,10 +48,12 @@ class ParameterAlgorithmMapping { ASTLossWeightsEntry.class, ASTSparseLabelEntry.class, ASTFromLogitsEntry.class, + ASTIgnoreIndicesEntry.class, ASTMarginEntry.class, ASTLabelFormatEntry.class, ASTRhoEntry.class, ASTPreprocessingEntry.class, + ASTUseTeacherForcing.class, ASTSaveAttentionImage.class ); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 8b27c04e985aa9462b2425c22b13fc4372371ab1..1e65020c27c147a95dca765956cfc5a56a764c8b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -132,6 +132,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void endVisit(ASTCheckpointPeriodEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTLogPeriodEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void endVisit(ASTNormalizeEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -165,6 +181,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override + public void endVisit(ASTEvalTrainEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForBoolean(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTLossEntry node) { LossSymbol loss = new LossSymbol(node.getValue().getName()); @@ -324,6 +348,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { .collect(Collectors.toList()); } + @Override + public void endVisit(ASTUseTeacherForcing node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForBoolean(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void endVisit(ASTSaveAttentionImage node) { EntrySymbol entry = new EntrySymbol(node.getName());