diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index ce96073351c3412831c97de26041f9e76523bb9c..b5dd4c673948fa459be41c925b8b3e8b699b049d 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -24,13 +24,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); ComponentNameValue implements ConfigValue = Name ("."Name)*; DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; - IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")"; IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]"; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; + CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue; + LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; + + ShuffleDataEntry implements ConfigEntry = name:"shuffle_data" ":" value:BooleanValue; + ClipGlobalGradNormEntry implements ConfigEntry = name:"clip_global_grad_norm" ":" value:NumberValue; + OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue; @@ -52,6 +57,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface BleuEntry extends Entry; ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue; + EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue; + LRPolicyValue implements ConfigValue =(fixed:"fixed" | step:"step" | exp:"exp" @@ -62,7 +69,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface OptimizerParamEntry extends Entry; - interface LossValue extends ConfigValue; + interface LossValue extends MultiParamValue; L1Loss implements LossValue = name:"l1" ("{" params:Entry* "}")?; @@ -251,17 +258,33 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number // GANs Extensions + interface MultiParamValueMapConfigEntry extends ConfigEntry; + interface MultiParamValueMapParamValue extends ConfigValue; + interface MultiParamValueMapTupleValue extends ConfigValue; + DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue; + QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue; PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue; - ImgResizeEntry implements ConfigEntry = name:"img_resize" ":" value:IntegerTupelValue; // Noise Distribution Creator NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue; interface NoiseDistributionValue extends MultiParamValue; - interface NoiseDistributionGaussianEntry extends Entry; + interface NoiseDistributionParamEntry extends Entry; + interface NoiseDistributionGaussianEntry extends NoiseDistributionParamEntry; + NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?; MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue; SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue; + + // Constraint Distributions + ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue; + ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?; + ConstraintDistributionParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:NoiseDistributionValue; + + // Constraint losses + ConstraintLossEntry implements MultiParamValueMapConfigEntry = name:"constraint_losses" ":" value:ConstraintLossValue; + ConstraintLossValue implements MultiParamValueMapParamValue = ("{" params:ConstraintLossParam* "}")?; + ConstraintLossParam implements MultiParamValueMapTupleValue = name:StringValue ":" multiParamValue:LossValue; } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java new file mode 100644 index 0000000000000000000000000000000000000000..66bf5fb398a199d9135e94fd4ebab651a9d5e3bd --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapParamValue.java @@ -0,0 +1,20 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._ast; + +import java.util.ArrayList; +import java.util.List; + +/** + * + */ +public interface ASTMultiParamValueMapParamValue extends ASTMultiParamValueMapParamValueTOP { + default List getParamsList() { + return new ArrayList<>(); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java new file mode 100644 index 0000000000000000000000000000000000000000..ad0fa715f5b6adf9b07b8aa92c9e44762ef0b240 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_ast/ASTMultiParamValueMapTupleValue.java @@ -0,0 +1,25 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._ast; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * + */ +public interface ASTMultiParamValueMapTupleValue extends ASTMultiParamValueMapTupleValueTOP { + default ASTMultiParamValue getMultiParamValue() { + return null; + } + default ASTStringValue getName() { + return new ASTStringValue(); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java index b74004c06f28a9bfc19dc6b80994150c47989d7b..1c1ccb94ea8527c5fa7cae5825d86fe5a18bde66 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java @@ -20,6 +20,7 @@ public class CheckEntryRepetition implements CNNTrainASTEntryCoCo { private final static Set> REPEATABLE_ENTRIES = ImmutableSet .>builder() .add(ASTOptimizerParamEntry.class) + .add(ASTNoiseDistributionParamEntry.class) .build(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index c0eb039571e620a7acc5f3d6a55fde98d9f79bb3..2614fb0efb785a73f6263adf6051f9210e4c4887 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -37,7 +37,10 @@ class ParameterAlgorithmMapping { private static final List EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( ASTBatchSizeEntry.class, ASTLoadCheckpointEntry.class, + ASTCheckpointPeriodEntry.class, + ASTLogPeriodEntry.class, ASTEvalMetricEntry.class, + ASTEvalTrainEntry.class, ASTExcludeBleuEntry.class, ASTNormalizeEntry.class, ASTNumEpochEntry.class, @@ -117,8 +120,10 @@ class ParameterAlgorithmMapping { private static final List GENERAL_GAN_PARAMETERS = Lists.newArrayList( ASTDiscriminatorNetworkEntry.class, + ASTQNetworkEntry.class, ASTNoiseDistributionEntry.class, - ASTImgResizeEntry.class + ASTConstraintDistributionEntry.class, + ASTConstraintLossEntry.class ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 1f791f67d698f0258c91eda3ed3fdfd0795969a5..a3b972e41208a7f6fa9d4f2482b5dffd568c745d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -9,11 +9,13 @@ package de.monticore.lang.monticar.cnntrain._symboltable; import de.monticore.ast.ASTCNode; import de.monticore.lang.monticar.cnntrain._ast.*; +import de.monticore.lang.monticar.cnntrain._parser.CNNTrainAntlrParser; import de.monticore.symboltable.ArtifactScope; import de.monticore.symboltable.ImportStatement; import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.ResolvingConfiguration; import de.se_rwth.commons.logging.Log; +import org.antlr.v4.runtime.misc.Pair; import java.util.*; import java.util.stream.Collectors; @@ -52,6 +54,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { CNNTrainCompilationUnitSymbol compilationUnitSymbol = new CNNTrainCompilationUnitSymbol(compilationUnit.getName()); addToScopeAndLinkWithNode(compilationUnitSymbol, compilationUnit); } + @Override public void endVisit(ASTCNNTrainCompilationUnit ast) { @@ -131,6 +134,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void endVisit(ASTCheckpointPeriodEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTLogPeriodEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void endVisit(ASTNormalizeEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -139,6 +158,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void endVisit(ASTShuffleDataEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForBoolean(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTClipGlobalGradNormEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTTrainContextEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -164,6 +199,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override + public void endVisit(ASTEvalTrainEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForBoolean(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTLossEntry node) { LossSymbol loss = new LossSymbol(node.getValue().getName()); @@ -472,7 +515,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void visit(ASTPreprocessingEntry node) { + public void visit(ASTQNetworkEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); addToScopeAndLinkWithNode(entry, node); @@ -480,19 +523,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void visit(ASTImgResizeEntry node) { - EntrySymbol width_entry = new EntrySymbol(node.getName()); - EntrySymbol height_entry = new EntrySymbol(node.getName()); - - width_entry.setValue(getValueSymbolForInteger(node.getValue().getFirst())); - height_entry.setValue(getValueSymbolForInteger(node.getValue().getSecond())); - addToScopeAndLinkWithNode(width_entry, node); - addToScopeAndLinkWithNode(height_entry, node); - - configuration.getEntryMap().put(node.getName() + "_width", width_entry); - configuration.getEntryMap().put(node.getName() + "_height", height_entry); + public void visit(ASTPreprocessingEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); } + @Override public void visit(ASTReplayMemoryEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); @@ -525,6 +563,26 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override + public void visit(ASTConstraintDistributionEntry node) { + processMultiParamMapConfigVisit(node, node.getName()); + } + + @Override + public void endVisit(ASTConstraintDistributionEntry node) { + processMultiParamMapConfigEndVisit(node); + } + + @Override + public void visit(ASTConstraintLossEntry node) { + processMultiParamMapConfigVisit(node, node.getName()); + } + + @Override + public void endVisit(ASTConstraintLossEntry node) { + processMultiParamMapConfigEndVisit(node); + } + @Override public void visit(ASTNoiseDistributionEntry node) { NoiseDistribution noiseDistribution; @@ -541,6 +599,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { processMultiParamConfigEndVisit(node); } + @Override public void visit(ASTRewardFunctionEntry node) { RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName()); @@ -623,7 +682,36 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { retrievePrimitiveValueByConfigValue(nodeParam.getValue())); } } - + + private void processMultiParamMapConfigVisit(ASTMultiParamValueMapConfigEntry node, Object value) { + EntrySymbol entry = new EntrySymbol(node.getName()); + MultiParamValueMapSymbol valueSymbol = new MultiParamValueMapSymbol(); + valueSymbol.setValue(value); + entry.setValue(valueSymbol); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + private void processMultiParamMapConfigEndVisit(ASTMultiParamValueMapConfigEntry node) { + ValueSymbol valueSymbol = configuration.getEntryMap().get(node.getName()).getValue(); + assert valueSymbol instanceof MultiParamValueMapSymbol : "Value symbol is not a multi parameter symbol"; + MultiParamValueMapSymbol multiParamValueMapSymbol = (MultiParamValueMapSymbol)valueSymbol; + for (ASTConfigValue nodeParam : ((ASTMultiParamValueMapParamValue)node.getValue()).getParamsList()) { + ASTMultiParamValueMapTupleValue tuple = ((ASTMultiParamValueMapTupleValue)nodeParam); + ASTStringValue name = tuple.getName(); + ASTMultiParamValue multiValue = tuple.getMultiParamValue(); + String valueName = multiValue.getName(); + multiParamValueMapSymbol.addMultiParamValueName(getStringFromStringValue(name), valueName); + HashMap mapEntry = new HashMap<>(); + for (ASTEntry param : multiValue.getParamsList()) { + String valueEntryName = param.getName(); + Object res = retrievePrimitiveValueByConfigValue(param.getValue()); + mapEntry.put(valueEntryName, res); + } + multiParamValueMapSymbol.addParameter(getStringFromStringValue(name), mapEntry); + } + } + private Object retrievePrimitiveValueByConfigValue(final ASTConfigValue configValue) { if (configValue instanceof ASTIntegerValue) { return getIntegerFromNumber((ASTIntegerValue)configValue); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index a35fc052db14287f68559337fa603b18f2c79099..b55671ed04c8181e87dcc6ec07825562bf888091 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -23,6 +23,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private NNArchitectureSymbol trainedArchitecture; private NNArchitectureSymbol criticNetwork; private NNArchitectureSymbol discriminatorNetwork; + private NNArchitectureSymbol qNetwork; public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); @@ -80,6 +81,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.ofNullable(discriminatorNetwork); } + public Optional getQNetwork() { + return Optional.ofNullable(qNetwork); + } + public void setCriticNetwork(NNArchitectureSymbol criticNetwork) { this.criticNetwork = criticNetwork; } @@ -88,6 +93,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { this.discriminatorNetwork = discriminatorNetwork; } + public void setQNetwork(NNArchitectureSymbol qNetwork) { + this.qNetwork = qNetwork; + } + public Map getEntryMap() { return entryMap; } @@ -121,6 +130,10 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return getEntryMap().containsKey(DISCRIMINATOR_NAME); } + public boolean hasQNetwork() { + return getEntryMap().containsKey(QNETWORK_NAME); + } + public Optional getCriticName() { if (!hasCritic()) { return Optional.empty(); @@ -150,4 +163,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { assert discriminatorNameValue instanceof String; return Optional.of((String)discriminatorNameValue); } + + public Optional getQNetworkName() { + if (!hasQNetwork()) { + return Optional.empty(); + } + + final Object qnetNameValue = getEntry(QNETWORK_NAME).getValue().getValue(); + assert qnetNameValue instanceof String; + return Optional.of((String)qnetNameValue); + } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java new file mode 100644 index 0000000000000000000000000000000000000000..3d744a06e2f0f1e972b1cd992e4235eefc1fab0a --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbol.java @@ -0,0 +1,54 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import java.util.HashMap; +import java.util.Map; + +/** + * + */ +public class MultiParamValueMapSymbol extends ValueSymbol { + public static final MultiParamValueMapSymbolKind KIND = new MultiParamValueMapSymbolKind(); + + private Map> parameters; + private Map multiParamValueNames; + + public MultiParamValueMapSymbol() { + super("", KIND); + this.parameters = new HashMap<>(); + this.multiParamValueNames = new HashMap<>(); + } + + public Map> getParameters() { + return parameters; + } + + public Map getMultiParamValueNames() { return multiParamValueNames; } + + public Object getParameter(final String parameterName) { + return parameters.get(parameterName); + } + + public boolean hasParameter(final String parameterName) { + return parameters.containsKey(parameterName); + } + + public void addParameter(final String parameterName, final Map value) { + parameters.put(parameterName, value); + } + + public void addMultiParamValueName(final String parameterName, final String name) { + multiParamValueNames.put(parameterName, name); + } + + @Override + public String toString() { + return super.toString() + '{' + parameters + '}'; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java new file mode 100644 index 0000000000000000000000000000000000000000..8349bc7a268738e11eb7936ecc98e2ae0bddcfac --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueMapSymbolKind.java @@ -0,0 +1,27 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import de.monticore.symboltable.SymbolKind; + +/** + * + */ +public class MultiParamValueMapSymbolKind extends ValueKind { + private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueMapSymbolKind"; + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isKindOf(SymbolKind kind) { + return NAME.equals(kind.getName()) || super.isKindOf(kind); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java index aef5d1c28ba4fbd59592a50da05b916a080451c6..d6ffe343690a9e4a17afbe797b69cd016fa85c69 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -47,9 +47,9 @@ public class ConfigEntryNameConstants { public static final String CRITIC = "critic"; public static final String DISCRIMINATOR_NAME = "discriminator_name"; + public static final String QNETWORK_NAME = "qnet_name"; public static final String PREPROCESSING_NAME = "preprocessing_name"; public static final String NOISE_DISTRIBUTION = "noise_distribution"; - public static final String IMG_RESIZE = "img_resize"; - public static final String IMG_RESIZE_WIDTH = "img_resize_width"; - public static final String IMG_RESIZE_HEIGHT = "img_resize_height"; + public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions"; + public static final String CONSTRAINT_LOSS = "constraint_losses"; }