Commit 329c3130 authored by Sascha Dewes's avatar Sascha Dewes
Browse files

added weight initializer symbols

parent ec13dde6
Pipeline #403248 failed with stage
in 3 minutes and 36 seconds
......@@ -38,6 +38,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
InitializerEntry implements ConfigEntry = (name:"initializer" | name:"actor_initializer") ":" value:InitializerValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
......@@ -75,7 +77,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| sigmoid:"sigmoid");
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
interface InitializerParamEntry extends Entry;
interface OptimizerParamEntry extends Entry;
interface LossValue extends MultiParamValue;
......@@ -128,6 +132,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
IgnoreLabelEntry implements SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_ignore_label" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
interface InitializerValue extends ConfigValue;
interface InitializerNormalEntry extends InitializerParamEntry;
NormalInitializer implements InitializerValue = name:"normal" ("{" params:InitializerNormalEntry* "}")?;
InitializerNormalSigma implements InitializerNormalEntry = name:"sigma" ":" value:NumberValue;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry;
......@@ -269,6 +279,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticInitializerEntry implements ConfigEntry = name:"critic_initializer" ":" value:InitializerValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
// TD3 exclusive parameters
......
......@@ -76,6 +76,42 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
public void endVisit(final ASTConfiguration trainingConfiguration) {
removeCurrentScope();
}
@Override
public void visit(ASTInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
}
@Override
public void endVisit(ASTInitializerEntry node) {
for (ASTEntry nodeParam : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getInitializer().getInitializerParamMap().put(nodeParam.getName(), param);
}
}
@Override
public void visit(ASTCriticInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setCriticInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
}
@Override
public void endVisit(ASTCriticInitializerEntry node) {
assert configuration.getCriticInitializer().isPresent(): "Critic initializer not present";
for (ASTEntry paramNode : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol)paramNode.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getCriticInitializer().get().getInitializerParamMap().put(paramNode.getName(), param);
}
}
@Override
public void visit(ASTOptimizerEntry node) {
......
......@@ -16,6 +16,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private InitializerSymbol initializer;
private InitializerSymbol criticInitializer;
private OptimizerSymbol optimizer;
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
......@@ -34,6 +36,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
preprocessingComponentSymbol = null;
trainedArchitecture = null;
}
public InitializerSymbol getInitializer() {
return initializer;
}
public void setInitializer(InitializerSymbol initializer) {
this.initializer = initializer;
}
public void setCriticInitializer(InitializerSymbol criticInitializer) {
this.criticInitializer = criticInitializer;
}
public Optional<InitializerSymbol> getCriticInitializer() {
return Optional.ofNullable(criticInitializer);
}
public OptimizerSymbol getOptimizer() {
return optimizer;
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private InitializerParamValueSymbol value;
public InitializerParamSymbol() {
super("", KIND);
}
public InitializerParamSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public InitializerParamValueSymbol getValue() {
return value;
}
public void setValue(InitializerParamValueSymbol value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbol extends CommonSymbol {
public static final InitializerParamValueSymbolKind KIND = new InitializerParamValueSymbolKind();
private Object value;
public InitializerParamValueSymbol() {
super("", KIND);
}
public InitializerParamValueSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamValueSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
public class InitializerSymbol extends de.monticore.symboltable.CommonSymbol {
private Map<String, InitializerParamSymbol> initializerParamMap = new HashMap<>();
public static final InitializerSymbolKind KIND = InitializerSymbolKind.INSTANCE;
public InitializerSymbol(String name) {
super(name, KIND);
}
public Map<String, InitializerParamSymbol> getInitializerParamMap() {
return initializerParamMap;
}
public void setInitializerParamMap(Map<String, InitializerParamSymbol> initializerParamMap) {
this.initializerParamMap = initializerParamMap;
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerSymbolKind implements SymbolKind {
public static final InitializerSymbolKind INSTANCE = new InitializerSymbolKind();
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment