diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index fe392d52ddb44e928a7dda9c2235d6dd8a05ce2f..97540cde6bc51efcb6b81d241b4088db1caf0083 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -259,10 +259,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; // GANs Extensions + KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue; - GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue; - ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue; + GeneratorTargetNameEntry implements ConfigEntry = name:"generator_target_name" ":" value:StringValue; + GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:GeneratorLossValue; + GeneratorLossValue implements ConfigValue = (l1: "l1" | l2: "l2"); NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue; + GeneratorLossWeightEntry implements ConfigEntry = name:"generator_loss_weight" ":" value:NumberValue; + DiscriminatorLossWeightEntry implements ConfigEntry = name:"discriminator_loss_weight" ":" value:NumberValue; + + SpeedPeriodEntry implements ConfigEntry = name:"speed_period" ":" value:IntegerValue; + PrintImagesEntry implements ConfigEntry = name:"print_images" ":" value:BooleanValue; interface MultiParamValueMapConfigEntry extends ConfigEntry; interface MultiParamValueMapParamValue extends ConfigValue; @@ -285,6 +292,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue; SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue; + NoiseDistributionUniformValue implements NoiseDistributionValue = name:"uniform" ("{" "}")?; + // Constraint Distributions ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue; ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 14e67217cc3332b3c89a48ea06aa87e15ec3579d..c2ee74a36092a3a7014dfb7b49ff37d98a5f50e5 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -51,4 +51,11 @@ public class CNNTrainCocos { .addCoCo(new CheckCriticNetworkInputs()); checker.checkAll(configurationSymbol); } + + public static void checkGANCocos(final ConfigurationSymbol configurationSymbol) { + CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() + .addCoCo(new CheckGANNetworkPorts()) + .addCoCo(new CheckGANConfigurationDependencies()); + checker.checkAll(configurationSymbol); + } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANConfigurationDependencies.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANConfigurationDependencies.java new file mode 100644 index 0000000000000000000000000000000000000000..ff2b314934f984f64efc1852ba0c829aa05d4565 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANConfigurationDependencies.java @@ -0,0 +1,60 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; +import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry; +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; +import sun.security.krb5.internal.ccache.CredentialsCache; + +public class CheckGANConfigurationDependencies implements CNNTrainConfigurationSymbolCoCo{ + + public CheckGANConfigurationDependencies() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + if(configurationSymbol.getLearningMethod() == LearningMethod.GAN) { + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) != null) + if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) == null) + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Generator loss specified but conditional input is missing"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) != null) + if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) == null) + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Conditional input specified but generator loss is missing"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.LOSS) != null) + Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + + " Loss parameter not valid for GAN learning"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) != null) + if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_DISTRIBUTION) == null) + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Noise input specified but noise distribution parameter is missing"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_DISTRIBUTION) != null) + if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null) + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Constraint distributions are given but q-network is missing"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_LOSS) != null) + if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null) + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Constraint losses are given but q-network is missing"); + + if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) == null) + Log.warn(" No noise input specified. Are you sure this is correct?"); + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java new file mode 100644 index 0000000000000000000000000000000000000000..ef9a024a06eadf05815c516e6aa3e8d2f37b4653 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckGANNetworkPorts.java @@ -0,0 +1,61 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Optional; + +public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo { + + public void CheckGANNetworkPorts() { } + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + + NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get(); + NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get(); + Optional qnet = configurationSymbol.getQNetwork(); + + if(gen.getOutputs().size() != 1) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " + + "but is supposed to only have one"); + + if(qnet.isPresent() && qnet.get().getInputs().size() != 1) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " + + "but is supposed to only have one"); + + if(qnet.isPresent() && dis.getOutputs().size() != 2) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 2 output " + + "ports when q-network is given"); + + if(!qnet.isPresent() && dis.getOutputs().size() != 1) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 1 output " + + "port when no q-network is given"); + + if(qnet.isPresent() && dis.getOutputs().size() == 2) + if(!dis.getOutputs().get(1).equals("features")) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Second output of discriminator network " + + "has to be named features when " + + "q-network is given"); + + if(qnet.isPresent() && !qnet.get().getInputs().get(0).equals("features")) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Input to q-network needs to be named features"); + + if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0))) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " + + "fit the first discriminators input name"); + + if(qnet.isPresent()) + if(gen.getInputs().contains(qnet.get().getOutputs())) + Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " + + "latent-codes outputted by q-network"); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 98eb48246174b06e696fe1ff723ac6c063317eaa..c82b6de8d3446b48b6c3db16c993245f61cabc73 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -53,16 +53,18 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { = parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass()); final boolean reinforcementLearningParameter = parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass()); + final boolean ganLearningParameter + = parameterAlgorithmMapping.isGANLearningParameter(node.getClass()); - assert (supervisedLearningParameter || reinforcementLearningParameter) : + assert (supervisedLearningParameter || reinforcementLearningParameter || ganLearningParameter) : "Parameter " + node.getName() + " is not checkable, because it is unknown to Condition"; if (supervisedLearningParameter && !reinforcementLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node); } else if(!supervisedLearningParameter && reinforcementLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node); - } } +} private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) { if (isLearningMethodKnown()) { @@ -91,11 +93,10 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private void evaluateLearningMethodEntry(ASTEntry node) { ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue(); LearningMethod evaluatedLearningMethod; - if(learningMethodValue.isPresentReinforcement()) { + if(learningMethodValue.isPresentReinforcement()) evaluatedLearningMethod = LearningMethod.REINFORCEMENT; - } else { + else evaluatedLearningMethod = LearningMethod.SUPERVISED; - } if (isLearningMethodKnown()) { logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod); @@ -127,16 +128,16 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { return parameterAlgorithmMapping.getAllReinforcementParameters(); } + return parameterAlgorithmMapping.getAllSupervisedParameters(); } private void setLearningMethod(final LearningMethod learningMethod) { - if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { + if (learningMethod.equals(LearningMethod.REINFORCEMENT)) setLearningMethodToReinforcement(); - } else { + else setLearningMethodToSupervised(); - } } private void setLearningMethodToSupervised() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 2fd8436a63d4774e43df49849b96056e0b9cdee6..272b49666466d27c289bf17a9b1755f4c994fa0d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -129,9 +129,12 @@ class ParameterAlgorithmMapping { ASTDiscriminatorOptimizerEntry.class, ASTKValueEntry.class, ASTGeneratorLossEntry.class, - ASTConditionalInputEntry.class, - ASTNoiseInputEntry.class - + ASTGeneratorTargetNameEntry.class, + ASTNoiseInputEntry.class, + ASTGeneratorLossWeightEntry.class, + ASTDiscriminatorLossWeightEntry.class, + ASTSpeedPeriodEntry.class, + ASTPrintImagesEntry.class ); ParameterAlgorithmMapping() { @@ -157,7 +160,13 @@ class ParameterAlgorithmMapping { boolean isSupervisedLearningParameter(Class entryClazz) { return GENERAL_PARAMETERS.contains(entryClazz) - || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz); + + } + + boolean isGANLearningParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz) || GENERAL_GAN_PARAMETERS.contains(entryClazz); } @@ -180,6 +189,14 @@ class ParameterAlgorithmMapping { || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz); } + List getAllGANParameters() { + return ImmutableList. builder() + .addAll(GENERAL_PARAMETERS) + .addAll(EXCLUSIVE_SUPERVISED_PARAMETERS) + .addAll(GENERAL_GAN_PARAMETERS) + .build(); + } + List getAllReinforcementParameters() { return ImmutableList. builder() .addAll(GENERAL_PARAMETERS) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index c059a9f603dd8f2312a444a388a3ce571a8e87f2..49bf4d626ccad5c8aeaf419592b8010c926bef07 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -145,7 +145,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void endVisit(ASTGeneratorLossEntry node) { + public void endVisit(ASTNoiseInputEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); entry.setValue(getValueSymbolForString(node.getValue())); addToScopeAndLinkWithNode(entry, node); @@ -153,7 +153,39 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void endVisit(ASTConditionalInputEntry node) { + public void endVisit(ASTGeneratorLossWeightEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTDiscriminatorLossWeightEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTSpeedPeriodEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTPrintImagesEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForBoolean(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTGeneratorTargetNameEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); entry.setValue(getValueSymbolForString(node.getValue())); addToScopeAndLinkWithNode(entry, node); @@ -442,6 +474,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTGeneratorLossEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + ValueSymbol value = new ValueSymbol(); + + if (node.getValue().isPresentL1()) { + value.setValue(GeneratorLoss.L1); + } else if (node.getValue().isPresentL2()) { + value.setValue(GeneratorLoss.L2); + } + + entry.setValue(value); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTRLAlgorithmEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -564,14 +612,6 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } - @Override - public void visit(ASTPreprocessingEntry node) { - EntrySymbol entry = new EntrySymbol(node.getName()); - entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); - addToScopeAndLinkWithNode(entry, node); - configuration.getEntryMap().put(node.getName(), entry); - } - @Override public void visit(ASTReplayMemoryEntry node) { @@ -628,11 +668,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { @Override public void visit(ASTNoiseDistributionEntry node) { NoiseDistribution noiseDistribution; - if(node.getValue().getName().equals("gaussian")) { + if(node.getValue().getName().equals("gaussian")) noiseDistribution = NoiseDistribution.GAUSSIAN; - } else { + else if (node.getValue().getName().equals("uniform")) + noiseDistribution = NoiseDistribution.UNIFORM; + else noiseDistribution = NoiseDistribution.GAUSSIAN; - } + processMultiParamConfigVisit(node, noiseDistribution); } @@ -650,6 +692,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { addToScopeAndLinkWithNode(symbol, node); } + @Override + public void visit(ASTPreprocessingEntry node) { + PreprocessingComponentSymbol symbol = new PreprocessingComponentSymbol(node.getName()); + symbol.setPreprocessingComponentName(node.getValue().getNameList()); + configuration.setPreprocessingComponent(symbol); + addToScopeAndLinkWithNode(symbol, node); + } + @Override public void visit(ASTSoftTargetUpdateRateEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index b55671ed04c8181e87dcc6ec07825562bf888091..d236b7c4534476046ff55655929b7062929dbea1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -20,6 +20,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private OptimizerSymbol criticOptimizer; private LossSymbol loss; private RewardFunctionSymbol rlRewardFunctionSymbol; + private PreprocessingComponentSymbol preprocessingComponentSymbol; private NNArchitectureSymbol trainedArchitecture; private NNArchitectureSymbol criticNetwork; private NNArchitectureSymbol discriminatorNetwork; @@ -30,6 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { public ConfigurationSymbol() { super("", KIND); rlRewardFunctionSymbol = null; + preprocessingComponentSymbol = null; trainedArchitecture = null; } @@ -65,6 +67,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.ofNullable(this.rlRewardFunctionSymbol); } + public void setPreprocessingComponent(PreprocessingComponentSymbol preprocessingComponentSymbol) { + this.preprocessingComponentSymbol = preprocessingComponentSymbol; + } + + public Optional getPreprocessingComponent() { + return Optional.ofNullable(this.preprocessingComponentSymbol); + } + + public boolean hasPreprocessor() { + return this.preprocessingComponentSymbol != null; + } + public Optional getTrainedArchitecture() { return Optional.ofNullable(trainedArchitecture); } @@ -118,10 +132,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return getLearningMethod().equals(LearningMethod.GAN); } - public boolean hasPreprocessor() { - return getEntryMap().containsKey(PREPROCESSING_NAME); - } - public boolean hasCritic() { return getEntryMap().containsKey(CRITIC); } @@ -144,16 +154,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.of((String)criticNameValue); } - public Optional getPreprocessingName() { - if (!hasPreprocessor()) { - return Optional.empty(); - } - - final Object preprocessingNameValue = getEntry(PREPROCESSING_NAME).getValue().getValue(); - assert preprocessingNameValue instanceof String; - return Optional.of((String)preprocessingNameValue); - } - public Optional getDiscriminatorName() { if (!hasDiscriminator()) { return Optional.empty(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/GeneratorLoss.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/GeneratorLoss.java new file mode 100644 index 0000000000000000000000000000000000000000..f00d3ec433b9af4e872e004ba9830ce836d7f958 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/GeneratorLoss.java @@ -0,0 +1,22 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +public enum GeneratorLoss { + L1{ + @Override + public String toString() { + return "l1"; + } + }, + L2{ + @Override + public String toString() { + return "l2"; + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NoiseDistribution.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NoiseDistribution.java index 999857efe819e89bf627ddc69fb8c0f0e2871ba6..af87e7d71ec1f8f8e3d61cfe58680b0e750831ba 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NoiseDistribution.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NoiseDistribution.java @@ -12,5 +12,11 @@ public enum NoiseDistribution { public String toString() { return "gaussian"; } + }, + UNIFORM{ + @Override + public String toString() { + return "uniform"; + } } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbol.java new file mode 100644 index 0000000000000000000000000000000000000000..65d4609b21f5039e32fa764ccff91f9000616f8d --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbol.java @@ -0,0 +1,48 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter; +import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter; +import de.monticore.symboltable.CommonSymbol; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * + */ +public class PreprocessingComponentSymbol extends CommonSymbol { + public static final PreprocessingComponentSymbolKind KIND = new PreprocessingComponentSymbolKind(); + + private List preprocessingComponentName; + private PreprocessingComponentParameter preprocessingComponentParameter; + + public PreprocessingComponentSymbol(String name) { + super(name, KIND); + preprocessingComponentName = new ArrayList<>(); + } + + protected void setPreprocessingComponentName(List preprocessingComponentNamePath) { + this.preprocessingComponentName = Lists.newArrayList(preprocessingComponentNamePath); + } + + public List getPreprocessingComponentName() { + return Lists.newArrayList(preprocessingComponentName); + } + + public void setPreprocessingComponentParameter(PreprocessingComponentParameter preprocessingComponentParameter) { + this.preprocessingComponentParameter = preprocessingComponentParameter; + } + + public Optional getPreprocessingComponentParameter() { + return Optional.ofNullable(preprocessingComponentParameter); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbolKind.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbolKind.java new file mode 100644 index 0000000000000000000000000000000000000000..eac871a798625a33a4814c57376bbcedc80318f8 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/PreprocessingComponentSymbolKind.java @@ -0,0 +1,27 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import de.monticore.symboltable.SymbolKind; + +/** + * + */ +public class PreprocessingComponentSymbolKind implements SymbolKind { + private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.PreprocessingComponentSymbolKind"; + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isKindOf(SymbolKind kind) { + return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/PreprocessingComponentParameter.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/PreprocessingComponentParameter.java new file mode 100644 index 0000000000000000000000000000000000000000..87533a7e67c2f9363b8b17666f9c315388b3a67f --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/PreprocessingComponentParameter.java @@ -0,0 +1,23 @@ +/** + * (c) https://github.com/MontiCore/monticore + * + * The license generally applicable for this project + * can be found under https://github.com/MontiCore/monticore. + */ +/* (c) https://github.com/MontiCore/monticore */ +package de.monticore.lang.monticar.cnntrain.annotations; + +import java.util.List; +import java.util.Optional; + +/** + * + */ +public interface PreprocessingComponentParameter { + List getInputNames(); + List getOutputNames(); + Optional getTypeOfInputPort(String portName); + Optional getTypeOfOutputPort(String portName); + Optional> getInputPortDimensionOfPort(String portName); + Optional> getOutputPortDimensionOfPort(String portName); +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java index 19e6259c9cf5da78c1e850e903dd2af300e52be5..c84b5be8dc0fe76270690f90331ce47372596b9c 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -55,7 +55,11 @@ public class ConfigEntryNameConstants { public static final String DISCRIMINATOR_OPTIMIZER = "discriminator_optimizer"; public static final String K_VALUE = "k_value"; public static final String GENERATOR_LOSS = "generator_loss"; - public static final String CONDITIONAL_INPUT = "conditional_input"; + public static final String GENERATOR_TARGET_NAME = "generator_target_name"; public static final String NOISE_INPUT = "noise_input"; + public static final String GENERATOR_LOSS_WEIGHT = "generator_loss_weight"; + public static final String DISCRIMINATOR_LOSS_WEIGHT = "discriminator_loss_weight"; + public static final String SPEED_PERIOD_ENTRY = "speed_period"; + public static final String PRINT_IMAGES = "print_images"; } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index 55bfb62d9dede9eb262018cc08b5f1e1afef15e6..9fae0f0531fa93610955440f67e03b6ea4dd5d1b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -21,4 +21,5 @@ public class ErrorCodes { public static final String CRITIC_NETWORK_ERROR = "xC7100"; public static final String MISSING_TRAINED_ARCHITECTURE = "xC7101"; public static final String TRAINED_ARCHITECTURE_ERROR = "xC7102"; + public static final String GAN_ARCHITECTURE_ERROR = "xC7103"; } \ No newline at end of file