diff --git a/pom.xml b/pom.xml
index d7d5e507a748f9f1aa69ee32d4565bde512f6b8d..ddeaf6685f8e06ba6677bd5d7e9069aa25407300 100644
--- a/pom.xml
+++ b/pom.xml
@@ -18,7 +18,7 @@
de.monticore.lang.monticar
cnn-train
- 0.3.8-SNAPSHOT
+ 0.3.9-SNAPSHOT
@@ -28,7 +28,7 @@
5.0.1
1.7.8
0.0.6
- 0.0.14-20180704.113055-2
+ 0.0.19-SNAPSHOT
18.0
diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
index 29ff89b10e12a0656cdcc6ce01abb6a44e8991c2..4f1d7cb05de8dd72c72c6f22034919d04ccc47ce 100644
--- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
+++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
@@ -30,6 +30,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
+ CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
+ LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
@@ -52,6 +54,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
+ EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
+
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
@@ -81,6 +85,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface SoftmaxCrossEntropyEntry extends Entry;
SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?;
+ interface SoftmaxCrossEntropyIgnoreIndicesEntry extends Entry;
+ SoftmaxCrossEntropyIgnoreIndicesLoss implements LossValue = name:"softmax_cross_entropy_ignore_indices" ("{" params:SoftmaxCrossEntropyIgnoreIndicesEntry* "}")?;
+
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry;
@@ -95,8 +102,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
- SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue;
- FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
+ SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
+ FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
+ IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
@@ -140,6 +148,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
+ UseTeacherForcing implements ConfigEntry = name:"use_teacher_forcing" ":" value:BooleanValue;
+
// Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
index b23ebb03e8515cf48f3161eb79c722b6888e33dc..0bd42320fcbaa8e67727d58f20560eefb789b576 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping {
private static final List EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList(
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
+ ASTCheckpointPeriodEntry.class,
+ ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class,
+ ASTEvalTrainEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
@@ -45,10 +48,12 @@ class ParameterAlgorithmMapping {
ASTLossWeightsEntry.class,
ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class,
+ ASTIgnoreIndicesEntry.class,
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
ASTPreprocessingEntry.class,
+ ASTUseTeacherForcing.class,
ASTSaveAttentionImage.class
);
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
index 8b27c04e985aa9462b2425c22b13fc4372371ab1..1e65020c27c147a95dca765956cfc5a56a764c8b 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
@@ -132,6 +132,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
+ @Override
+ public void endVisit(ASTCheckpointPeriodEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForInteger(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
+ @Override
+ public void endVisit(ASTLogPeriodEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForInteger(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
@Override
public void endVisit(ASTNormalizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
@@ -165,6 +181,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
+ @Override
+ public void endVisit(ASTEvalTrainEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForBoolean(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
@Override
public void visit(ASTLossEntry node) {
LossSymbol loss = new LossSymbol(node.getValue().getName());
@@ -324,6 +348,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.collect(Collectors.toList());
}
+ @Override
+ public void endVisit(ASTUseTeacherForcing node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForBoolean(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
@Override
public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName());