Commit 19d9ed72 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Added several parameters

See merge request !26
parents 6d091d65 7574b6e6
Pipeline #228878 passed with stages
......@@ -24,13 +24,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
......@@ -52,6 +57,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
| exp:"exp"
......@@ -62,7 +69,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface OptimizerParamEntry extends Entry;
interface LossValue extends ConfigValue;
interface LossValue extends MultiParamValue;
L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?;
......@@ -251,17 +258,33 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// GANs Extensions
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
ImgResizeEntry implements ConfigEntry = name:"img_resize" ":" value:IntegerTupelValue;
// Noise Distribution Creator
NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue;
interface NoiseDistributionValue extends MultiParamValue;
interface NoiseDistributionGaussianEntry extends Entry;
interface NoiseDistributionParamEntry extends Entry;
interface NoiseDistributionGaussianEntry extends NoiseDistributionParamEntry;
NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?;
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:NoiseDistributionValue;
// Constraint losses
ConstraintLossEntry implements MultiParamValueMapConfigEntry = name:"constraint_losses" ":" value:ConstraintLossValue;
ConstraintLossValue implements MultiParamValueMapParamValue = ("{" params:ConstraintLossParam* "}")?;
ConstraintLossParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:LossValue;
}
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.List;
/**
*
*/
public interface ASTMultiParamValueMapParamValue extends ASTMultiParamValueMapParamValueTOP {
default List<? extends ASTConfigValue> getParamsList() {
return new ArrayList<>();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
*/
public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP {
default ASTMultiParamValue getMultiParamValue() {
return null;
}
default ASTStringValue getName() {
return new ASTStringValue();
}
}
......@@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet
.<Class<? extends ASTEntry>>builder()
.add(ASTOptimizerParamEntry.class)
.add(ASTNoiseDistributionParamEntry.class)
.build();
......
......@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList(
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
ASTCheckpointPeriodEntry.class,
ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class,
ASTEvalTrainEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
......@@ -117,8 +120,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList(
ASTDiscriminatorNetworkEntry.class,
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class,
ASTImgResizeEntry.class
ASTConstraintDistributionEntry.class,
ASTConstraintLossEntry.class
);
ParameterAlgorithmMapping() {
......
......@@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.ast.ASTCNode;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._parser.CNNTrainAntlrParser;
import de.monticore.symboltable.ArtifactScope;
import de.monticore.symboltable.ImportStatement;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration;
import de.se_rwth.commons.logging.Log;
import org.antlr.v4.runtime.misc.Pair;
import java.util.*;
import java.util.stream.Collectors;
......@@ -52,6 +54,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(compilationUnit.getName());
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
}
@Override
public void endVisit(ASTCNNTrainCompilationUnit ast) {
......@@ -131,6 +134,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTCheckpointPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLogPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTNormalizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -139,6 +158,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTShuffleDataEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTClipGlobalGradNormEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTTrainContextEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -164,6 +199,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void endVisit(ASTEvalTrainEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLossEntry node) {
LossSymbol loss = new LossSymbol(node.getValue().getName());
......@@ -472,7 +515,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
@Override
public void visit(ASTPreprocessingEntry node) {
public void visit(ASTQNetworkEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
......@@ -480,19 +523,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
@Override
public void visit(ASTImgResizeEntry node) {
EntrySymbol width_entry = new EntrySymbol(node.getName());
EntrySymbol height_entry = new EntrySymbol(node.getName());
width_entry.setValue(getValueSymbolForInteger(node.getValue().getFirst()));
height_entry.setValue(getValueSymbolForInteger(node.getValue().getSecond()));
addToScopeAndLinkWithNode(width_entry, node);
addToScopeAndLinkWithNode(height_entry, node);
configuration.getEntryMap().put(node.getName() + "_width", width_entry);
configuration.getEntryMap().put(node.getName() + "_height", height_entry);
public void visit(ASTPreprocessingEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTReplayMemoryEntry node) {
processMultiParamConfigVisit(node, node.getValue().getName());
......@@ -525,6 +563,26 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void visit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTConstraintLossEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintLossEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution;
......@@ -541,6 +599,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node);
}
@Override
public void visit(ASTRewardFunctionEntry node) {
RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName());
......@@ -623,7 +682,36 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
retrievePrimitiveValueByConfigValue(nodeParam.getValue()));
}
}
private void processMultiParamMapConfigVisit(ASTMultiParamValueMapConfigEntry node, Object value) {
EntrySymbol entry = new EntrySymbol(node.getName());
MultiParamValueMapSymbol valueSymbol = new MultiParamValueMapSymbol();
valueSymbol.setValue(value);
entry.setValue(valueSymbol);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
private void processMultiParamMapConfigEndVisit(ASTMultiParamValueMapConfigEntry node) {
ValueSymbol valueSymbol = configuration.getEntryMap().get(node.getName()).getValue();
assert valueSymbol instanceof MultiParamValueMapSymbol : "Value symbol is not a multi parameter symbol";
MultiParamValueMapSymbol multiParamValueMapSymbol = (MultiParamValueMapSymbol)valueSymbol;
for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) {
ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam);
ASTStringValue name = tuple.getName();
ASTMultiParamValue multiValue = tuple.getMultiParamValue();
String valueName = multiValue.getName();
multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), valueName);
HashMap<String, Object> mapEntry = new HashMap<>();
for (ASTEntry param : multiValue.getParamsList()) {
String valueEntryName = param.getName();
Object res = retrievePrimitiveValueByConfigValue(param.getValue());
mapEntry.put(valueEntryName, res);
}
multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry);
}
}
private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) {
if (configValue instanceof ASTIntegerValue) {
return getIntegerFromNumber((ASTIntegerValue)configValue);
......
......@@ -23,6 +23,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork;
private NNArchitectureSymbol qNetwork;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
......@@ -80,6 +81,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(discriminatorNetwork);
}
public Optional<NNArchitectureSymbol> getQNetwork() {
return Optional.ofNullable(qNetwork);
}
public void setCriticNetwork(NNArchitectureSymbol criticNetwork) {
this.criticNetwork = criticNetwork;
}
......@@ -88,6 +93,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.discriminatorNetwork = discriminatorNetwork;
}
public void setQNetwork(NNArchitectureSymbol qNetwork) {
this.qNetwork = qNetwork;
}
public Map<String, EntrySymbol> getEntryMap() {
return entryMap;
}
......@@ -121,6 +130,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getEntryMap().containsKey(DISCRIMINATOR_NAME);
}
public boolean hasQNetwork() {
return getEntryMap().containsKey(QNETWORK_NAME);
}
public Optional<String> getCriticName() {
if (!hasCritic()) {
return Optional.empty();
......@@ -150,4 +163,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
assert discriminatorNameValue instanceof String;
return Optional.of((String)discriminatorNameValue);
}
public Optional<String> getQNetworkName() {
if (!hasQNetwork()) {
return Optional.empty();
}
final Object qnetNameValue = getEntry(QNETWORK_NAME).getValue().getValue();
assert qnetNameValue instanceof String;
return Optional.of((String)qnetNameValue);
}
}
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
public class MultiParamValueMapSymbol extends ValueSymbol {
public static final MultiParamValueMapSymbolKind KIND = new MultiParamValueMapSymbolKind();
private Map<String, Map<String,Object>> parameters;
private Map<String, String> multiParamValueNames;
public MultiParamValueMapSymbol() {
super("", KIND);
this.parameters = new HashMap<>();
this.multiParamValueNames = new HashMap<>();
}
public Map<String, Map<String,Object>> getParameters() {
return parameters;
}
public Map<String, String> getMultiParamValueNames() { return multiParamValueNames; }
public Object getParameter(final String parameterName) {
return parameters.get(parameterName);
}
public boolean hasParameter(final String parameterName) {
return parameters.containsKey(parameterName);
}
public void addParameter(final String parameterName, final Map<String, Object> value) {
parameters.put(parameterName, value);
}
public void addMultiParamValueName(final String parameterName, final String name) {
multiParamValueNames.put(parameterName, name);
}
@Override
public String toString() {
return super.toString() + '{' + parameters + '}';
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
/**
*
*/
public class MultiParamValueMapSymbolKind extends ValueKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueMapSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || super.isKindOf(kind);
}
}
......@@ -47,9 +47,9 @@ public class ConfigEntryNameConstants {
public static final String CRITIC = "critic";
public static final String DISCRIMINATOR_NAME = "discriminator_name";
public static final String QNETWORK_NAME = "qnet_name";
public static final String PREPROCESSING_NAME = "preprocessing_name";
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String IMG_RESIZE = "img_resize";
public static final String IMG_RESIZE_WIDTH = "img_resize_width";
public static final String IMG_RESIZE_HEIGHT = "img_resize_height";
public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions";
public static final String CONSTRAINT_LOSS = "constraint_losses";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment