Commit a0f8413b authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'weight_initializer' into 'master'

Weight Initializer

See merge request !33
parents ec13dde6 d073b942
Pipeline #403425 passed with stage
in 3 minutes and 48 seconds
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId> <artifactId>cnn-train</artifactId>
<version>0.4.4-SNAPSHOT</version> <version>0.4.5-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
......
...@@ -38,6 +38,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -38,6 +38,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue; ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue; ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
InitializerEntry implements ConfigEntry = (name:"initializer" | name:"actor_initializer") ":" value:InitializerValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
...@@ -75,7 +77,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -75,7 +77,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| sigmoid:"sigmoid"); | sigmoid:"sigmoid");
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu"); TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
interface InitializerParamEntry extends Entry;
interface OptimizerParamEntry extends Entry; interface OptimizerParamEntry extends Entry;
interface LossValue extends MultiParamValue; interface LossValue extends MultiParamValue;
...@@ -128,6 +132,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -128,6 +132,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
IgnoreLabelEntry implements SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_ignore_label" ":" value:IntegerValue; IgnoreLabelEntry implements SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_ignore_label" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue; MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue; LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
interface InitializerValue extends ConfigValue;
interface InitializerNormalEntry extends InitializerParamEntry;
NormalInitializer implements InitializerValue = name:"normal" ("{" params:InitializerNormalEntry* "}")?;
InitializerNormalSigma implements InitializerNormalEntry = name:"sigma" ":" value:NumberValue;
interface OptimizerValue extends ConfigValue; interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry; interface SGDEntry extends OptimizerParamEntry;
...@@ -269,6 +279,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -269,6 +279,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// DDPG and TD3 exclusive parameters // DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue; SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticInitializerEntry implements ConfigEntry = name:"critic_initializer" ":" value:InitializerValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue; CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
// TD3 exclusive parameters // TD3 exclusive parameters
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.List;
public interface ASTInitializerValue extends ASTInitializerValueTOP {
String getName();
List<? extends ASTEntry> getParamsList();
}
...@@ -76,6 +76,42 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -76,6 +76,42 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
public void endVisit(final ASTConfiguration trainingConfiguration) { public void endVisit(final ASTConfiguration trainingConfiguration) {
removeCurrentScope(); removeCurrentScope();
} }
@Override
public void visit(ASTInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
}
@Override
public void endVisit(ASTInitializerEntry node) {
for (ASTEntry nodeParam : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getInitializer().getInitializerParamMap().put(nodeParam.getName(), param);
}
}
@Override
public void visit(ASTCriticInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setCriticInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
}
@Override
public void endVisit(ASTCriticInitializerEntry node) {
assert configuration.getCriticInitializer().isPresent(): "Critic initializer not present";
for (ASTEntry paramNode : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol)paramNode.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getCriticInitializer().get().getInitializerParamMap().put(paramNode.getName(), param);
}
}
@Override @Override
public void visit(ASTOptimizerEntry node) { public void visit(ASTOptimizerEntry node) {
......
...@@ -16,6 +16,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant ...@@ -16,6 +16,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
public class ConfigurationSymbol extends CommonScopeSpanningSymbol { public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>(); private Map<String, EntrySymbol> entryMap = new HashMap<>();
private InitializerSymbol initializer;
private InitializerSymbol criticInitializer;
private OptimizerSymbol optimizer; private OptimizerSymbol optimizer;
private OptimizerSymbol criticOptimizer; private OptimizerSymbol criticOptimizer;
private LossSymbol loss; private LossSymbol loss;
...@@ -34,6 +36,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -34,6 +36,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
preprocessingComponentSymbol = null; preprocessingComponentSymbol = null;
trainedArchitecture = null; trainedArchitecture = null;
} }
public InitializerSymbol getInitializer() {
return initializer;
}
public void setInitializer(InitializerSymbol initializer) {
this.initializer = initializer;
}
public void setCriticInitializer(InitializerSymbol criticInitializer) {
this.criticInitializer = criticInitializer;
}
public Optional<InitializerSymbol> getCriticInitializer() {
return Optional.ofNullable(criticInitializer);
}
public OptimizerSymbol getOptimizer() { public OptimizerSymbol getOptimizer() {
return optimizer; return optimizer;
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private InitializerParamValueSymbol value;
public InitializerParamSymbol() {
super("", KIND);
}
public InitializerParamSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public InitializerParamValueSymbol getValue() {
return value;
}
public void setValue(InitializerParamValueSymbol value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbol extends CommonSymbol {
public static final InitializerParamValueSymbolKind KIND = new InitializerParamValueSymbolKind();
private Object value;
public InitializerParamValueSymbol() {
super("", KIND);
}
public InitializerParamValueSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamValueSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
public class InitializerSymbol extends de.monticore.symboltable.CommonSymbol {
private Map<String, InitializerParamSymbol> initializerParamMap = new HashMap<>();
public static final InitializerSymbolKind KIND = InitializerSymbolKind.INSTANCE;
public InitializerSymbol(String name) {
super(name, KIND);
}
public Map<String, InitializerParamSymbol> getInitializerParamMap() {
return initializerParamMap;
}
public void setInitializerParamMap(Map<String, InitializerParamSymbol> initializerParamMap) {
this.initializerParamMap = initializerParamMap;
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerSymbolKind implements SymbolKind {
public static final InitializerSymbolKind INSTANCE = new InitializerSymbolKind();
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment