Commit fa60d6bc authored by Sebastian N.'s avatar Sebastian N.

Added parameters checkpoint_period, log_period and eval_train

parent 14beccf5
Pipeline #221317 failed with stages
......@@ -30,6 +30,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
......@@ -52,6 +54,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......
......@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList(
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
ASTCheckpointPeriodEntry.class,
ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class,
ASTEvalTrainEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
......
......@@ -131,6 +131,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTCheckpointPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLogPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTNormalizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -164,6 +180,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void endVisit(ASTEvalTrainEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLossEntry node) {
LossSymbol loss = new LossSymbol(node.getValue().getName());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment