From cafac5218568ccf4a1e7787b4b474323a777fd4c Mon Sep 17 00:00:00 2001 From: Julian Dierkes Date: Fri, 6 Mar 2020 13:08:39 +0100 Subject: [PATCH] added new parameters for GAN --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 7 +++- .../_cocos/ParameterAlgorithmMapping.java | 8 +++- .../CNNTrainSymbolTableCreator.java | 42 +++++++++++++++++++ .../helper/ConfigEntryNameConstants.java | 6 +++ 4 files changed, 61 insertions(+), 2 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index b5dd4c6..191a0cb 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -256,13 +256,18 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue; PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; - // GANs Extensions + KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue; + GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue; + ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue; + NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue; + interface MultiParamValueMapConfigEntry extends ConfigEntry; interface MultiParamValueMapParamValue extends ConfigValue; interface MultiParamValueMapTupleValue extends ConfigValue; DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue; + DiscriminatorOptimizerEntry implements ConfigEntry = name:"discriminator_optimizer" ":" value:OptimizerValue; QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue; PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 2614fb0..531b641 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -123,7 +123,13 @@ class ParameterAlgorithmMapping { ASTQNetworkEntry.class, ASTNoiseDistributionEntry.class, ASTConstraintDistributionEntry.class, - ASTConstraintLossEntry.class + ASTConstraintLossEntry.class, + ASTDiscriminatorOptimizerEntry.class, + ASTKValueEntry.class, + ASTGeneratorLossEntry.class, + ASTConditionalInputEntry.class, + ASTNoiseInputEntry.class + ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index a3b972e..c059a9f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -110,6 +110,24 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } } + @Override + public void visit(ASTDiscriminatorOptimizerEntry node) { + OptimizerSymbol optimizerSymbol = new OptimizerSymbol(node.getValue().getName()); + configuration.setCriticOptimizer(optimizerSymbol); + addToScopeAndLinkWithNode(optimizerSymbol, node); + } + + @Override + public void endVisit(ASTDiscriminatorOptimizerEntry node) { + assert configuration.getCriticOptimizer().isPresent(): "Critic optimizer not present"; + for (ASTEntry paramNode : node.getValue().getParamsList()) { + OptimizerParamSymbol param = new OptimizerParamSymbol(); + OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol)paramNode.getValue().getSymbolOpt().get(); + param.setValue(valueSymbol); + configuration.getCriticOptimizer().get().getOptimizerParamMap().put(paramNode.getName(), param); + } + } + @Override public void endVisit(ASTNumEpochEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -118,6 +136,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void endVisit(ASTKValueEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTGeneratorLossEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForString(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void endVisit(ASTConditionalInputEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForString(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void endVisit(ASTBatchSizeEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java index d6ffe34..19e6259 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -52,4 +52,10 @@ public class ConfigEntryNameConstants { public static final String NOISE_DISTRIBUTION = "noise_distribution"; public static final String CONSTRAINT_DISTRIBUTION = "constraint_distributions"; public static final String CONSTRAINT_LOSS = "constraint_losses"; + public static final String DISCRIMINATOR_OPTIMIZER = "discriminator_optimizer"; + public static final String K_VALUE = "k_value"; + public static final String GENERATOR_LOSS = "generator_loss"; + public static final String CONDITIONAL_INPUT = "conditional_input"; + public static final String NOISE_INPUT = "noise_input"; } + -- GitLab