Commit 19d9ed72 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Added several parameters

See merge request !26
parents 6d091d65 7574b6e6
Pipeline #228878 passed with stages
...@@ -24,13 +24,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -24,13 +24,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*; ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]"; IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue;
ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue; LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
...@@ -52,6 +57,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -52,6 +57,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry; interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue; ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed" LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step" | step:"step"
| exp:"exp" | exp:"exp"
...@@ -62,7 +69,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -62,7 +69,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface OptimizerParamEntry extends Entry; interface OptimizerParamEntry extends Entry;
interface LossValue extends ConfigValue; interface LossValue extends MultiParamValue;
L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?; L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?;
...@@ -251,17 +258,33 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -251,17 +258,33 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
// GANs Extensions // GANs Extensions
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue; DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue; PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
ImgResizeEntry implements ConfigEntry = name:"img_resize" ":" value:IntegerTupelValue;
// Noise Distribution Creator // Noise Distribution Creator
NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue; NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue;
interface NoiseDistributionValue extends MultiParamValue; interface NoiseDistributionValue extends MultiParamValue;
interface NoiseDistributionGaussianEntry extends Entry; interface NoiseDistributionParamEntry extends Entry;
interface NoiseDistributionGaussianEntry extends NoiseDistributionParamEntry;
NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?; NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?;
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue; MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue; SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:NoiseDistributionValue;
// Constraint losses
ConstraintLossEntry implements MultiParamValueMapConfigEntry = name:"constraint_losses" ":" value:ConstraintLossValue;
ConstraintLossValue implements MultiParamValueMapParamValue = ("{" params:ConstraintLossParam* "}")?;
ConstraintLossParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:LossValue;
} }
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.List;
/**
*
*/
public interface ASTMultiParamValueMapParamValue extends ASTMultiParamValueMapParamValueTOP {
default List<? extends ASTConfigValue> getParamsList() {
return new ArrayList<>();
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._ast;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
*/
public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP {
default ASTMultiParamValue getMultiParamValue() {
return null;
}
default ASTStringValue getName() {
return new ASTStringValue();
}
}
...@@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo { ...@@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet
.<Class<? extends ASTEntry>>builder() .<Class<? extends ASTEntry>>builder()
.add(ASTOptimizerParamEntry.class) .add(ASTOptimizerParamEntry.class)
.add(ASTNoiseDistributionParamEntry.class)
.build(); .build();
......
...@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping { ...@@ -37,7 +37,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( private static final List<Class> EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList(
ASTBatchSizeEntry.class, ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class, ASTLoadCheckpointEntry.class,
ASTCheckpointPeriodEntry.class,
ASTLogPeriodEntry.class,
ASTEvalMetricEntry.class, ASTEvalMetricEntry.class,
ASTEvalTrainEntry.class,
ASTExcludeBleuEntry.class, ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class, ASTNormalizeEntry.class,
ASTNumEpochEntry.class, ASTNumEpochEntry.class,
...@@ -117,8 +120,10 @@ class ParameterAlgorithmMapping { ...@@ -117,8 +120,10 @@ class ParameterAlgorithmMapping {
private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList( private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList(
ASTDiscriminatorNetworkEntry.class, ASTDiscriminatorNetworkEntry.class,
ASTQNetworkEntry.class,
ASTNoiseDistributionEntry.class, ASTNoiseDistributionEntry.class,
ASTImgResizeEntry.class ASTConstraintDistributionEntry.class,
ASTConstraintLossEntry.class
); );
ParameterAlgorithmMapping() { ParameterAlgorithmMapping() {
......
...@@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable; ...@@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.ast.ASTCNode; import de.monticore.ast.ASTCNode;
import de.monticore.lang.monticar.cnntrain._ast.*; import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._parser.CNNTrainAntlrParser;
import de.monticore.symboltable.ArtifactScope; import de.monticore.symboltable.ArtifactScope;
import de.monticore.symboltable.ImportStatement; import de.monticore.symboltable.ImportStatement;
import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.ResolvingConfiguration; import de.monticore.symboltable.ResolvingConfiguration;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import org.antlr.v4.runtime.misc.Pair;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
...@@ -53,6 +55,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -53,6 +55,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit); addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit);
} }
@Override @Override
public void endVisit(ASTCNNTrainCompilationUnit ast) { public void endVisit(ASTCNNTrainCompilationUnit ast) {
CNNTrainCompilationUnitSymbol compilationUnitSymbol = (CNNTrainCompilationUnitSymbol) ast.getSymbolOpt().get(); CNNTrainCompilationUnitSymbol compilationUnitSymbol = (CNNTrainCompilationUnitSymbol) ast.getSymbolOpt().get();
...@@ -131,6 +134,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -131,6 +134,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void endVisit(ASTCheckpointPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTLogPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void endVisit(ASTNormalizeEntry node) { public void endVisit(ASTNormalizeEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
...@@ -139,6 +158,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -139,6 +158,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void endVisit(ASTShuffleDataEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTClipGlobalGradNormEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTTrainContextEntry node) { public void visit(ASTTrainContextEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
...@@ -164,6 +199,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -164,6 +199,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node); processMultiParamConfigEndVisit(node);
} }
@Override
public void endVisit(ASTEvalTrainEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTLossEntry node) { public void visit(ASTLossEntry node) {
LossSymbol loss = new LossSymbol(node.getValue().getName()); LossSymbol loss = new LossSymbol(node.getValue().getName());
...@@ -472,7 +515,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -472,7 +515,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
@Override @Override
public void visit(ASTPreprocessingEntry node) { public void visit(ASTQNetworkEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node); addToScopeAndLinkWithNode(entry, node);
...@@ -480,19 +523,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -480,19 +523,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
@Override @Override
public void visit(ASTImgResizeEntry node) { public void visit(ASTPreprocessingEntry node) {
EntrySymbol width_entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
EntrySymbol height_entry = new EntrySymbol(node.getName()); entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
width_entry.setValue(getValueSymbolForInteger(node.getValue().getFirst())); configuration.getEntryMap().put(node.getName(), entry);
height_entry.setValue(getValueSymbolForInteger(node.getValue().getSecond()));
addToScopeAndLinkWithNode(width_entry, node);
addToScopeAndLinkWithNode(height_entry, node);
configuration.getEntryMap().put(node.getName() + "_width", width_entry);
configuration.getEntryMap().put(node.getName() + "_height", height_entry);
} }
@Override @Override
public void visit(ASTReplayMemoryEntry node) { public void visit(ASTReplayMemoryEntry node) {
processMultiParamConfigVisit(node, node.getValue().getName()); processMultiParamConfigVisit(node, node.getValue().getName());
...@@ -525,6 +563,26 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -525,6 +563,26 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node); processMultiParamConfigEndVisit(node);
} }
@Override
public void visit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintDistributionEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override
public void visit(ASTConstraintLossEntry node) {
processMultiParamMapConfigVisit(node, node.getName());
}
@Override
public void endVisit(ASTConstraintLossEntry node) {
processMultiParamMapConfigEndVisit(node);
}
@Override @Override
public void visit(ASTNoiseDistributionEntry node) { public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution; NoiseDistribution noiseDistribution;
...@@ -541,6 +599,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -541,6 +599,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node); processMultiParamConfigEndVisit(node);
} }
@Override @Override
public void visit(ASTRewardFunctionEntry node) { public void visit(ASTRewardFunctionEntry node) {
RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName()); RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName());
...@@ -624,6 +683,35 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -624,6 +683,35 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
} }
private void processMultiParamMapConfigVisit(ASTMultiParamValueMapConfigEntry node, Object value) {
EntrySymbol entry = new EntrySymbol(node.getName());
MultiParamValueMapSymbol valueSymbol = new MultiParamValueMapSymbol();
valueSymbol.setValue(value);
entry.setValue(valueSymbol);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
private void processMultiParamMapConfigEndVisit(ASTMultiParamValueMapConfigEntry node) {
ValueSymbol valueSymbol = configuration.getEntryMap().get(node.getName()).getValue();
assert valueSymbol instanceof MultiParamValueMapSymbol : "Value symbol is not a multi parameter symbol";
MultiParamValueMapSymbol multiParamValueMapSymbol = (MultiParamValueMapSymbol)valueSymbol;
for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) {
ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam);
ASTStringValue name = tuple.getName();
ASTMultiParamValue multiValue = tuple.getMultiParamValue();
String valueName = multiValue.getName();
multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), valueName);
HashMap<String, Object> mapEntry = new HashMap<>();
for (ASTEntry param : multiValue.getParamsList()) {
String valueEntryName = param.getName();
Object res = retrievePrimitiveValueByConfigValue(param.getValue());
mapEntry.put(valueEntryName, res);
}
multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry);
}
}
private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) { private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) {
if (configValue instanceof ASTIntegerValue) { if (configValue instanceof ASTIntegerValue) {
return getIntegerFromNumber((ASTIntegerValue)configValue); return getIntegerFromNumber((ASTIntegerValue)configValue);
......
...@@ -23,6 +23,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -23,6 +23,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private NNArchitectureSymbol trainedArchitecture; private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork; private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork; private NNArchitectureSymbol discriminatorNetwork;
private NNArchitectureSymbol qNetwork;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
...@@ -80,6 +81,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -80,6 +81,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(discriminatorNetwork); return Optional.ofNullable(discriminatorNetwork);
} }
public Optional<NNArchitectureSymbol> getQNetwork() {
return Optional.ofNullable(qNetwork);
}
public void setCriticNetwork(NNArchitectureSymbol criticNetwork) { public void setCriticNetwork(NNArchitectureSymbol criticNetwork) {
this.criticNetwork = criticNetwork; this.criticNetwork = criticNetwork;
} }
...@@ -88,6 +93,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -88,6 +93,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
this.discriminatorNetwork = discriminatorNetwork; this.discriminatorNetwork = discriminatorNetwork;
} }
public void setQNetwork(NNArchitectureSymbol qNetwork) {
this.qNetwork = qNetwork;
}
public Map<String, EntrySymbol> getEntryMap() { public Map<String, EntrySymbol> getEntryMap() {
return entryMap; return entryMap;
} }
...@@ -121,6 +130,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -121,6 +130,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getEntryMap().containsKey(DISCRIMINATOR_NAME); return getEntryMap().containsKey(DISCRIMINATOR_NAME);
} }
public boolean hasQNetwork() {
return getEntryMap().containsKey(QNETWORK_NAME);
}
public Optional<String> getCriticName() { public Optional<String> getCriticName() {
if (!hasCritic()) { if (!hasCritic()) {
return Optional.empty(); return Optional.empty();
...@@ -150,4 +163,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -150,4 +163,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
assert discriminatorNameValue instanceof String; assert discriminatorNameValue instanceof String;
return Optional.of((String)discriminatorNameValue); return Optional.of((String)discriminatorNameValue);
} }
public Optional<String> getQNetworkName() {
if (!hasQNetwork()) {
return Optional.empty();
}
final Object qnetNameValue = getEntry(QNETWORK_NAME).getValue().getValue();
assert qnetNameValue instanceof String;
return Optional.of((String)qnetNameValue);
}
} }
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
public class MultiParamValueMapSymbol extends ValueSymbol {
public static final MultiParamValueMapSymbolKind KIND = new MultiParamValueMapSymbolKind();
private Map<String, Map<String,Object>> parameters;
private Map<String, String> multiParamValueNames;
public MultiParamValueMapSymbol() {
super("", KIND);
this.parameters = new HashMap<>();
this.multiParamValueNames = new HashMap<>();
}
public Map<String, Map<String,Object>> getParameters() {
return parameters;
}
public Map<String, String> getMultiParamValueNames() { return multiParamValueNames; }
public Object getParameter(final String parameterName) {
return parameters.get(parameterName);
}
public boolean hasParameter(final String parameterName) {
return parameters.containsKey(parameterName);
}
public void addParameter(final String parameterName, final Map<String, Object> value) {