Commit 8d4751fd authored by Julian Dierkes's avatar Julian Dierkes

added entry to specify constraint distributions and created new Map like type in CNNTrain grammar

parent 1422c428
......@@ -24,7 +24,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
......@@ -255,6 +254,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// GANs Extensions
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
......@@ -263,9 +266,16 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue;
interface NoiseDistributionValue extends MultiParamValue;
interface NoiseDistributionGaussianEntry extends Entry;
interface NoiseDistributionParamEntry extends Entry;
interface NoiseDistributionGaussianEntry extends NoiseDistributionParamEntry;
NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?;
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" distribution:NoiseDistributionValue;
}
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.List;
/**
*
*/
public interface ASTMultiParamValueMapParamValue extends ASTMultiParamValueMapParamValueTOP {
default List<? extends ASTConfigValue> getParamsList() {
return new ArrayList<>();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
*/
public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP {
default ASTNoiseDistributionValue getDistribution() {
return null;
}
default ASTStringValue getName() {
return new ASTStringValue();
}
}
......@@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet
.<Class<? extends ASTEntry>>builder()
.add(ASTOptimizerParamEntry.class)
.add(ASTNoiseDistributionParamEntry.class)
.build();
......
......@@ -121,7 +121,8 @@ class ParameterAlgorithmMapping {
private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList(
ASTDiscriminatorNetworkEntry.class,
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class
ASTNoiseDistributionEntry.class,
ASTConstraintDistributionEntry.class
);
ParameterAlgorithmMapping() {
......
......@@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.ast.ASTCNode;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._parser.CNNTrainAntlrParser;
import de.monticore.symboltable.ArtifactScope;
import de.monticore.symboltable.ImportStatement;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration;
import de.se_rwth.commons.logging.Log;
import org.antlr.v4.runtime.misc.Pair;
import java.util.*;
import java.util.stream.Collectors;
......@@ -512,6 +514,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTReplayMemoryEntry node) {
processMultiParamConfigVisit(node, node.getValue().getName());
......@@ -544,6 +547,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void visit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution;
......@@ -560,6 +573,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void visit(ASTRewardFunctionEntry node) {
RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName());
......@@ -642,7 +656,36 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
retrievePrimitiveValueByConfigValue(nodeParam.getValue()));
}
}
private void processMultiParamMapConfigVisit(ASTMultiParamValueMapConfigEntry node, Object value) {
EntrySymbol entry = new EntrySymbol(node.getName());
MultiParamValueMapSymbol valueSymbol = new MultiParamValueMapSymbol();
valueSymbol.setValue(value);
entry.setValue(valueSymbol);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
private void processMultiParamMapConfigEndVisit(ASTMultiParamValueMapConfigEntry node) {
ValueSymbol valueSymbol = configuration.getEntryMap().get(node.getName()).getValue();
assert valueSymbol instanceof MultiParamValueMapSymbol : "Value symbol is not a multi parameter symbol";
MultiParamValueMapSymbol multiParamValueMapSymbol = (MultiParamValueMapSymbol)valueSymbol;
for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) {
ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam);
ASTStringValue name = tuple.getName();
ASTMultiParamValue distribution = tuple.getDistribution();
String distrName = distribution.getName();
multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), distrName);
HashMap<String, Object> mapEntry = new HashMap<>();
for (ASTEntry param : distribution.getParamsList()) {
String distrEntryName = param.getName();
Object res = retrievePrimitiveValueByConfigValue(param.getValue());
mapEntry.put(distrEntryName, res);
}
multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry);
}
}
private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) {
if (configValue instanceof ASTIntegerValue) {
return getIntegerFromNumber((ASTIntegerValue)configValue);
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
public class MultiParamValueMapSymbol extends ValueSymbol {
public static final MultiParamValueMapSymbolKind KIND = new MultiParamValueMapSymbolKind();
private Map<String, Map<String,Object>> parameters;
private Map<String, String> multiParamValueNames;
public MultiParamValueMapSymbol() {
super("", KIND);
this.parameters = new HashMap<>();
this.multiParamValueNames = new HashMap<>();
}
public Map<String, Map<String,Object>> getParameters() {
return parameters;
}
public Map<String, String> getMultiParamValueNames() { return multiParamValueNames; }
public Object getParameter(final String parameterName) {
return parameters.get(parameterName);
}
public boolean hasParameter(final String parameterName) {
return parameters.containsKey(parameterName);
}
public void addParameter(final String parameterName, final Map<String, Object> value) {
parameters.put(parameterName, value);
}
public void addMultiParamValueName(final String parameterName, final String name) {
multiParamValueNames.put(parameterName, name);
}
@Override
public String toString() {
return super.toString() + '{' + parameters + '}';
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
/**
*
*/
public class MultiParamValueMapSymbolKind extends ValueKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueMapSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || super.isKindOf(kind);
}
}
......@@ -50,4 +50,5 @@ public class ConfigEntryNameConstants {
public static final String QNETWORK_NAME = "qnet_name";
public static final String PREPROCESSING_NAME = "preprocessing_name";
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment