Commit 5908d1e0 authored by Julian Dierkes's avatar Julian Dierkes

Merge branch 'develop' of...

Merge branch 'develop' of git.rwth-aachen.de:monticore/EmbeddedMontiArc/languages/CNNTrainLang into develop
parents 40a1a42a fa60d6bc
Pipeline #223027 failed with stages
in 20 seconds
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId> <artifactId>cnn-train</artifactId>
<version>0.3.8-SNAPSHOT</version> <version>0.3.9-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
<monticore.version>5.0.1</monticore.version> <monticore.version>5.0.1</monticore.version>
<se-commons.version>1.7.8</se-commons.version> <se-commons.version>1.7.8</se-commons.version>
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version> <mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<Common-MontiCar.version>0.0.14-20180704.113055-2</Common-MontiCar.version> <Common-MontiCar.version>0.0.19-SNAPSHOT</Common-MontiCar.version>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
......
...@@ -30,6 +30,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -30,6 +30,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
...@@ -52,6 +54,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -52,6 +54,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry; interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue; ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed" LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step" | step:"step"
| exp:"exp" | exp:"exp"
...@@ -81,6 +85,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -81,6 +85,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface SoftmaxCrossEntropyEntry extends Entry; interface SoftmaxCrossEntropyEntry extends Entry;
SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?; SoftmaxCrossEntropyLoss implements LossValue = name:"softmax_cross_entropy" ("{" params:SoftmaxCrossEntropyEntry* "}")?;
interface SoftmaxCrossEntropyIgnoreIndicesEntry extends Entry;
SoftmaxCrossEntropyIgnoreIndicesLoss implements LossValue = name:"softmax_cross_entropy_ignore_indices" ("{" params:SoftmaxCrossEntropyIgnoreIndicesEntry* "}")?;
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?; SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry; interface HingeEntry extends Entry;
...@@ -95,8 +102,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -95,8 +102,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface KullbackLeiblerEntry extends Entry; interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?; KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry = name:"sparse_label" ":" value:BooleanValue; SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue; FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue; MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue; LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
...@@ -140,6 +148,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -140,6 +148,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue; ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue; RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
UseTeacherForcing implements ConfigEntry = name:"use_teacher_forcing" ":" value:BooleanValue;
// Visual attention Extension // Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue; SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
......
...@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping { ...@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList(
ASTBatchSizeEntry.class, ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class, ASTLoadCheckpointEntry.class,
ASTCheckpointPeriodEntry.class,
ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class, ASTEvalMetricEntry.class,
ASTEvalTrainEntry.class,
ASTExcludeBleuEntry.class, ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class, ASTNormalizeEntry.class,
ASTNumEpochEntry.class, ASTNumEpochEntry.class,
...@@ -45,10 +48,12 @@ class ParameterAlgorithmMapping { ...@@ -45,10 +48,12 @@ class ParameterAlgorithmMapping {
ASTLossWeightsEntry.class, ASTLossWeightsEntry.class,
ASTSparseLabelEntry.class, ASTSparseLabelEntry.class,
ASTFromLogitsEntry.class, ASTFromLogitsEntry.class,
ASTIgnoreIndicesEntry.class,
ASTMarginEntry.class, ASTMarginEntry.class,
ASTLabelFormatEntry.class, ASTLabelFormatEntry.class,
ASTRhoEntry.class, ASTRhoEntry.class,
ASTPreprocessingEntry.class, ASTPreprocessingEntry.class,
ASTUseTeacherForcing.class,
ASTSaveAttentionImage.class ASTSaveAttentionImage.class
); );
......
...@@ -132,6 +132,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -132,6 +132,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void endVisit(ASTCheckpointPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLogPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void endVisit(ASTNormalizeEntry node) { public void endVisit(ASTNormalizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
...@@ -165,6 +181,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -165,6 +181,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node); processMultiParamConfigEndVisit(node);
} }
@Override
public void endVisit(ASTEvalTrainEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTLossEntry node) { public void visit(ASTLossEntry node) {
LossSymbol loss = new LossSymbol(node.getValue().getName()); LossSymbol loss = new LossSymbol(node.getValue().getName());
...@@ -324,6 +348,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -324,6 +348,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@Override
public void endVisit(ASTUseTeacherForcing node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void endVisit(ASTSaveAttentionImage node) { public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment