Commit 87f84f4c authored by Sebastian N.'s avatar Sebastian N.

Added shuffle_data and clip_global_grad_norm params

parent 1422c428
Pipeline #225094 failed with stages
in 6 seconds
......@@ -33,6 +33,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
......
......@@ -156,6 +156,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTShuffleDataEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTClipGlobalGradNormEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTTrainContextEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment