From 8d4751fd844c0e74fcc002e85c37e30daccb386b Mon Sep 17 00:00:00 2001 From: Julian Dierkes Date: Wed, 8 Jan 2020 23:49:43 +0100 Subject: [PATCH] added entry to specify constraint distributions and created new Map like type in CNNTrain grammar --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 14 ++++- .../_ast/ASTMultiParamValueMapParamValue.java | 20 +++++++ .../_ast/ASTMultiParamValueMapTupleValue.java | 25 +++++++++ .../cnntrain/_cocos/CheckEntryRepetition.java | 1 + .../_cocos/ParameterAlgorithmMapping.java | 3 +- .../CNNTrainSymbolTableCreator.java | 45 +++++++++++++++- .../MultiParamValueMapSymbol.java | 54 +++++++++++++++++++ .../MultiParamValueMapSymbolKind.java | 27 ++++++++++ .../helper/ConfigEntryNameConstants.java | 1 + 9 files changed, 186 insertions(+), 4 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index ddfe3b3..07bae84 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -24,7 +24,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); ComponentNameValue implements ConfigValue = Name ("."Name)*; DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; - IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")"; IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]"; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; @@ -255,6 +254,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number // GANs Extensions + interface MultiParamValueMapConfigEntry extends ConfigEntry; + interface MultiParamValueMapParamValue extends ConfigValue; + interface MultiParamValueMapTupleValue extends ConfigValue; + DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue; QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue; PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue; @@ -263,9 +266,16 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue; interface NoiseDistributionValue extends MultiParamValue; - interface NoiseDistributionGaussianEntry extends Entry; + interface NoiseDistributionParamEntry extends Entry; + interface NoiseDistributionGaussianEntry extends NoiseDistributionParamEntry; + NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?; MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue; SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue; + + // Constraint Distributions + ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue; + ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?; + ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" distribution:NoiseDistributionValue; } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java new file mode 100644 index 0000000..66bf5fb --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java @@ -0,0 +1,20 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._ast; + +import java.util.ArrayList; +import java.util.List; + +/** + * + */ +public interface ASTMultiParamValueMapParamValue extends ASTMultiParamValueMapParamValueTOP { + default List getParamsList() { + return new ArrayList<>(); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java new file mode 100644 index 0000000..e794004 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java @@ -0,0 +1,25 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._ast; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * + */ +public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP { + default ASTNoiseDistributionValue getDistribution() { + return null; + } + default ASTStringValue getName() { + return new ASTStringValue(); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java index b74004c..1c1ccb9 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java @@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo { private final static Set> REPEATABLE_ENTRIES = ImmutableSet .>builder() .add(ASTOptimizerParamEntry.class) + .add(ASTNoiseDistributionParamEntry.class) .build(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index c02fab0..73b3e91 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -121,7 +121,8 @@ class ParameterAlgorithmMapping { private static final List GENERAL_GAN_PARAMETERS = Lists.newArrayList( ASTDiscriminatorNetworkEntry.class, ASTQNetworkEntry.class, - ASTNoiseDistributionEntry.class + ASTNoiseDistributionEntry.class, + ASTConstraintDistributionEntry.class ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index fb11d48..6c2b41f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable; import de.monticore.ast.ASTCNode; import de.monticore.lang.monticar.cnntrain._ast.*; +import de.monticore.lang.monticar.cnntrain._parser.CNNTrainAntlrParser; import de.monticore.symboltable.ArtifactScope; import de.monticore.symboltable.ImportStatement; import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.ResolvingConfiguration; import de.se_rwth.commons.logging.Log; +import org.antlr.v4.runtime.misc.Pair; import java.util.*; import java.util.stream.Collectors; @@ -512,6 +514,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override public void visit(ASTReplayMemoryEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); @@ -544,6 +547,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override + public void visit(ASTConstraintDistributionEntry node) { + processMultiParamMapConfigVisit(node, node.getName()); + } + + @Override + public void endVisit(ASTConstraintDistributionEntry node) { + processMultiParamMapConfigEndVisit(node); + } + @Override public void visit(ASTNoiseDistributionEntry node) { NoiseDistribution noiseDistribution; @@ -560,6 +573,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override public void visit(ASTRewardFunctionEntry node) { RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName()); @@ -642,7 +656,36 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { retrievePrimitiveValueByConfigValue(nodeParam.getValue())); } } - + + private void processMultiParamMapConfigVisit(ASTMultiParamValueMapConfigEntry node, Object value) { + EntrySymbol entry = new EntrySymbol(node.getName()); + MultiParamValueMapSymbol valueSymbol = new MultiParamValueMapSymbol(); + valueSymbol.setValue(value); + entry.setValue(valueSymbol); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + private void processMultiParamMapConfigEndVisit(ASTMultiParamValueMapConfigEntry node) { + ValueSymbol valueSymbol = configuration.getEntryMap().get(node.getName()).getValue(); + assert valueSymbol instanceof MultiParamValueMapSymbol : "Value symbol is not a multi parameter symbol"; + MultiParamValueMapSymbol multiParamValueMapSymbol = (MultiParamValueMapSymbol)valueSymbol; + for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) { + ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam); + ASTStringValue name = tuple.getName(); + ASTMultiParamValue distribution = tuple.getDistribution(); + String distrName = distribution.getName(); + multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), distrName); + HashMap mapEntry = new HashMap<>(); + for (ASTEntry param : distribution.getParamsList()) { + String distrEntryName = param.getName(); + Object res = retrievePrimitiveValueByConfigValue(param.getValue()); + mapEntry.put(distrEntryName, res); + } + multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry); + } + } + private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) { if (configValue instanceof ASTIntegerValue) { return getIntegerFromNumber((ASTIntegerValue)configValue); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java new file mode 100644 index 0000000..3d744a0 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java @@ -0,0 +1,54 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import java.util.HashMap; +import java.util.Map; + +/** + * + */ +public class MultiParamValueMapSymbol extends ValueSymbol { + public static final MultiParamValueMapSymbolKind KIND = new MultiParamValueMapSymbolKind(); + + private Map> parameters; + private Map multiParamValueNames; + + public MultiParamValueMapSymbol() { + super("", KIND); + this.parameters = new HashMap<>(); + this.multiParamValueNames = new HashMap<>(); + } + + public Map> getParameters() { + return parameters; + } + + public Map getMultiParamValueNames() { return multiParamValueNames; } + + public Object getParameter(final String parameterName) { + return parameters.get(parameterName); + } + + public boolean hasParameter(final String parameterName) { + return parameters.containsKey(parameterName); + } + + public void addParameter(final String parameterName, final Map value) { + parameters.put(parameterName, value); + } + + public void addMultiParamValueName(final String parameterName, final String name) { + multiParamValueNames.put(parameterName, name); + } + + @Override + public String toString() { + return super.toString() + '{' + parameters + '}'; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java new file mode 100644 index 0000000..8349bc7 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java @@ -0,0 +1,27 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import de.monticore.symboltable.SymbolKind; + +/** + * + */ +public class MultiParamValueMapSymbolKind extends ValueKind { + private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueMapSymbolKind"; + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isKindOf(SymbolKind kind) { + return NAME.equals(kind.getName()) || super.isKindOf(kind); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java index e90a7da..9888d21 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -50,4 +50,5 @@ public class ConfigEntryNameConstants { public static final String QNETWORK_NAME = "qnet_name"; public static final String PREPROCESSING_NAME = "preprocessing_name"; public static final String NOISE_DISTRIBUTION = "noise_distribution"; + public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions"; } -- GitLab