Commit cf8723f3 authored by Sascha Dewes's avatar Sascha Dewes

bug fixes

parent d073b942
Pipeline #406986 passed with stage
in 3 minutes and 16 seconds
......@@ -38,8 +38,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
InitializerEntry implements ConfigEntry = (name:"initializer" | name:"actor_initializer") ":" value:InitializerValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
......@@ -78,8 +76,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
interface InitializerParamEntry extends Entry;
interface OptimizerParamEntry extends Entry;
interface LossValue extends MultiParamValue;
......@@ -132,12 +128,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
IgnoreLabelEntry implements SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_ignore_label" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
interface InitializerValue extends ConfigValue;
interface InitializerNormalEntry extends InitializerParamEntry;
NormalInitializer implements InitializerValue = name:"normal" ("{" params:InitializerNormalEntry* "}")?;
InitializerNormalSigma implements InitializerNormalEntry = name:"sigma" ":" value:NumberValue;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends OptimizerParamEntry;
......@@ -207,6 +197,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm");
interface MultiParamConfigEntry extends ConfigEntry;
// Initializer
InitializerEntry implements MultiParamConfigEntry = (name:"initializer" | name:"actor_initializer") ":" value:InitializerValue;
interface InitializerValue extends MultiParamValue;
interface InitializerNormalEntry extends Entry;
InitializerNormalValue implements InitializerValue = name:"normal" ("{" params:InitializerNormalEntry* "}")?;
InitializerNormalSigma implements InitializerNormalEntry = name:"sigma" ":" value:NumberValue;
// Replay Memory
ReplayMemoryEntry implements MultiParamConfigEntry = name:"replay_memory" ":" value:ReplayMemoryValue;
......@@ -279,7 +277,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticInitializerEntry implements ConfigEntry = name:"critic_initializer" ":" value:InitializerValue;
CriticInitializerEntry implements MultiParamConfigEntry = name:"critic_initializer" ":" value:InitializerValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
// TD3 exclusive parameters
......@@ -331,4 +329,4 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ConstraintLossEntry implements MultiParamValueMapConfigEntry = name:"constraint_losses" ":" value:ConstraintLossValue;
ConstraintLossValue implements MultiParamValueMapParamValue = ("{" params:ConstraintLossParam* "}")?;
ConstraintLossParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:LossValue;
}
\ No newline at end of file
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.List;
public interface ASTInitializerValue extends ASTInitializerValueTOP {
String getName();
List<? extends ASTEntry> getParamsList();
}
......@@ -16,6 +16,8 @@ import java.util.List;
class ParameterAlgorithmMapping {
private static final List<Class> GENERAL_PARAMETERS = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTInitializerEntry.class,
ASTInitializerNormalSigma.class,
ASTOptimizerEntry.class,
ASTLearningRateEntry.class,
ASTMinimumLearningRateEntry.class,
......@@ -104,6 +106,7 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_DDPG_PARAMETERS = Lists.newArrayList(
ASTCriticNetworkEntry.class,
ASTSoftTargetUpdateRateEntry.class,
ASTCriticInitializerEntry.class,
ASTCriticOptimizerEntry.class,
ASTStrategyOUMu.class,
ASTStrategyOUTheta.class,
......@@ -114,6 +117,7 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(
ASTCriticNetworkEntry.class,
ASTSoftTargetUpdateRateEntry.class,
ASTCriticInitializerEntry.class,
ASTCriticOptimizerEntry.class,
ASTStrategyOUMu.class,
ASTStrategyOUTheta.class,
......
......@@ -79,38 +79,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
processMultiParamConfigVisit(node, node.getValue().getName());
}
@Override
public void endVisit(ASTInitializerEntry node) {
for (ASTEntry nodeParam : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getInitializer().getInitializerParamMap().put(nodeParam.getName(), param);
}
processMultiParamConfigEndVisit(node);
}
@Override
public void visit(ASTCriticInitializerEntry node) {
InitializerSymbol initializer = new InitializerSymbol(node.getValue().getName());
configuration.setCriticInitializer(initializer);
addToScopeAndLinkWithNode(initializer, node);
processMultiParamConfigVisit(node, node.getValue().getName());
}
@Override
public void endVisit(ASTCriticInitializerEntry node) {
assert configuration.getCriticInitializer().isPresent(): "Critic initializer not present";
for (ASTEntry paramNode : node.getValue().getParamsList()) {
InitializerParamSymbol param = new InitializerParamSymbol();
InitializerParamValueSymbol valueSymbol = (InitializerParamValueSymbol)paramNode.getValue().getSymbolOpt().get();
param.setValue(valueSymbol);
configuration.getCriticInitializer().get().getInitializerParamMap().put(paramNode.getName(), param);
}
processMultiParamConfigEndVisit(node);
}
@Override
......
......@@ -16,8 +16,6 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private InitializerSymbol initializer;
private InitializerSymbol criticInitializer;
private OptimizerSymbol optimizer;
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
......@@ -36,22 +34,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
preprocessingComponentSymbol = null;
trainedArchitecture = null;
}
public InitializerSymbol getInitializer() {
return initializer;
}
public void setInitializer(InitializerSymbol initializer) {
this.initializer = initializer;
}
public void setCriticInitializer(InitializerSymbol criticInitializer) {
this.criticInitializer = criticInitializer;
}
public Optional<InitializerSymbol> getCriticInitializer() {
return Optional.ofNullable(criticInitializer);
}
public OptimizerSymbol getOptimizer() {
return optimizer;
......@@ -191,4 +173,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
assert qnetNameValue instanceof String;
return Optional.of((String)qnetNameValue);
}
}
\ No newline at end of file
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbol extends CommonSymbol {
public static final EntryKind KIND = new EntryKind();
private InitializerParamValueSymbol value;
public InitializerParamSymbol() {
super("", KIND);
}
public InitializerParamSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public InitializerParamValueSymbol getValue() {
return value;
}
public void setValue(InitializerParamValueSymbol value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbol extends CommonSymbol {
public static final InitializerParamValueSymbolKind KIND = new InitializerParamValueSymbolKind();
private Object value;
public InitializerParamValueSymbol() {
super("", KIND);
}
public InitializerParamValueSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toString(){
return getValue().toString();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerParamValueSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerParamValueSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
public class InitializerSymbol extends de.monticore.symboltable.CommonSymbol {
private Map<String, InitializerParamSymbol> initializerParamMap = new HashMap<>();
public static final InitializerSymbolKind KIND = InitializerSymbolKind.INSTANCE;
public InitializerSymbol(String name) {
super(name, KIND);
}
public Map<String, InitializerParamSymbol> getInitializerParamMap() {
return initializerParamMap;
}
public void setInitializerParamMap(Map<String, InitializerParamSymbol> initializerParamMap) {
this.initializerParamMap = initializerParamMap;
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
public class InitializerSymbolKind implements SymbolKind {
public static final InitializerSymbolKind INSTANCE = new InitializerSymbolKind();
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.InitializerSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment