Commit 300c11a6 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

beginning to add support for GAN

parent dc20415c
[INFO] Scanning for projects...
[INFO]
[INFO] ----------------< de.monticore.lang.monticar:cnn-train >----------------
[INFO] Building cnn-train 0.3.6-SNAPSHOT
[INFO] --------------------------------[ jar ]---------------------------------
[INFO]
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ cnn-train ---
[INFO] Deleting /home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/target
[INFO]
[INFO] --- jacoco-maven-plugin:0.8.1:prepare-agent (pre-unit-test) @ cnn-train ---
[INFO] argLine set to -javaagent:/home/julian/.m2/repository/org/jacoco/org.jacoco.agent/0.8.1/org.jacoco.agent-0.8.1-runtime.jar=destfile=/home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/target/jacoco.exec
[INFO]
[INFO] --- monticore-maven-plugin:5.0.1:generate (default) @ cnn-train ---
[INFO] Changes detected for /home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4. Regenerating...
...@@ -24,6 +24,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -24,6 +24,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*; ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
...@@ -147,7 +148,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -147,7 +148,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue; StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue;
EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue; EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue;
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement" | gan:"gan");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm"); RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm");
...@@ -230,4 +231,20 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -230,4 +231,20 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue; PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue;
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue; NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
ImgResizeEntry implements ConfigEntry = name:"img_resize" ":" value:IntegerTupelValue;
// Noise Distribution Creator
NoiseDistributionEntry implements MultiParamConfigEntry = name:"noise_distribution" ":" value:NoiseDistributionValue;
interface NoiseDistributionValue extends MultiParamValue;
interface NoiseDistributionGaussianEntry extends Entry;
NoiseDistributionGaussianValue implements NoiseDistributionValue = name:"gaussian" ("{" params:NoiseDistributionGaussianEntry* "}")?;
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
} }
\ No newline at end of file
/** /**
* (c) https://github.com/MontiCore/monticore
* *
* ****************************************************************************** * The license generally applicable for this project
* MontiCAR Modeling Family, www.se-rwth.de * can be found under https://github.com/MontiCore/monticore.
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnntrain._cocos; package de.monticore.lang.monticar.cnntrain._cocos;
......
...@@ -59,10 +59,9 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ...@@ -59,10 +59,9 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
if (supervisedLearningParameter && !reinforcementLearningParameter) { if (supervisedLearningParameter && !reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node); setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node);
} else if(!supervisedLearningParameter) { } else if(!supervisedLearningParameter && reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node); setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node);
} }
} }
private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) { private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) {
...@@ -91,8 +90,12 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ...@@ -91,8 +90,12 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private void evaluateLearningMethodEntry(ASTEntry node) { private void evaluateLearningMethodEntry(ASTEntry node) {
ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue(); ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue();
LearningMethod evaluatedLearningMethod = learningMethodValue.isPresentReinforcement() LearningMethod evaluatedLearningMethod;
? LearningMethod.REINFORCEMENT : LearningMethod.SUPERVISED; if(learningMethodValue.isPresentReinforcement()) {
evaluatedLearningMethod = LearningMethod.REINFORCEMENT;
} else {
evaluatedLearningMethod = LearningMethod.SUPERVISED;
}
if (isLearningMethodKnown()) { if (isLearningMethodKnown()) {
logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod); logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod);
......
...@@ -46,7 +46,8 @@ class ParameterAlgorithmMapping { ...@@ -46,7 +46,8 @@ class ParameterAlgorithmMapping {
ASTFromLogitsEntry.class, ASTFromLogitsEntry.class,
ASTMarginEntry.class, ASTMarginEntry.class,
ASTLabelFormatEntry.class, ASTLabelFormatEntry.class,
ASTRhoEntry.class ASTRhoEntry.class,
ASTPreprocessingEntry.class
); );
private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList( private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList(
...@@ -110,6 +111,12 @@ class ParameterAlgorithmMapping { ...@@ -110,6 +111,12 @@ class ParameterAlgorithmMapping {
ASTStrategyGaussianNoiseVarianceEntry.class ASTStrategyGaussianNoiseVarianceEntry.class
); );
private static final List<Class> GENERAL_GAN_PARAMETERS = Lists.newArrayList(
ASTDiscriminatorNetworkEntry.class,
ASTNoiseDistributionEntry.class,
ASTImgResizeEntry.class
);
ParameterAlgorithmMapping() { ParameterAlgorithmMapping() {
} }
...@@ -124,7 +131,9 @@ class ParameterAlgorithmMapping { ...@@ -124,7 +131,9 @@ class ParameterAlgorithmMapping {
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) { boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz) return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz); || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
|| GENERAL_GAN_PARAMETERS.contains(entryClazz);
} }
boolean isDqnParameter(Class<? extends ASTEntry> entryClazz) { boolean isDqnParameter(Class<? extends ASTEntry> entryClazz) {
...@@ -159,6 +168,7 @@ class ParameterAlgorithmMapping { ...@@ -159,6 +168,7 @@ class ParameterAlgorithmMapping {
return ImmutableList.<Class> builder() return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS) .addAll(GENERAL_PARAMETERS)
.addAll(EXCLUSIVE_SUPERVISED_PARAMETERS) .addAll(EXCLUSIVE_SUPERVISED_PARAMETERS)
.addAll(GENERAL_GAN_PARAMETERS)
.build(); .build();
} }
} }
...@@ -344,6 +344,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -344,6 +344,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
if (node.getValue().isPresentReinforcement()) { if (node.getValue().isPresentReinforcement()) {
value.setValue(LearningMethod.REINFORCEMENT); value.setValue(LearningMethod.REINFORCEMENT);
} if (node.getValue().isPresentGan()) {
value.setValue(LearningMethod.GAN);
} else if (node.getValue().isPresentSupervisedLearning()) { } else if (node.getValue().isPresentSupervisedLearning()) {
value.setValue(LearningMethod.SUPERVISED); value.setValue(LearningMethod.SUPERVISED);
} }
...@@ -459,7 +461,35 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -459,7 +461,35 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void visit(ASTDiscriminatorNetworkEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTPreprocessingEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTImgResizeEntry node) {
EntrySymbol width_entry = new EntrySymbol(node.getName());
EntrySymbol height_entry = new EntrySymbol(node.getName());
width_entry.setValue(getValueSymbolForInteger(node.getValue().getFirst()));
height_entry.setValue(getValueSymbolForInteger(node.getValue().getSecond()));
addToScopeAndLinkWithNode(width_entry, node);
addToScopeAndLinkWithNode(height_entry, node);
configuration.getEntryMap().put(node.getName() + "_width", width_entry);
configuration.getEntryMap().put(node.getName() + "_height", height_entry);
}
@Override @Override
public void visit(ASTReplayMemoryEntry node) { public void visit(ASTReplayMemoryEntry node) {
...@@ -493,6 +523,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -493,6 +523,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
processMultiParamConfigEndVisit(node); processMultiParamConfigEndVisit(node);
} }
@Override
public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution;
if(node.getValue().getName().equals("gaussian")) {
noiseDistribution = NoiseDistribution.GAUSSIAN;
} else {
noiseDistribution = NoiseDistribution.GAUSSIAN;
}
processMultiParamConfigVisit(node, noiseDistribution);
}
@Override
public void endVisit(ASTNoiseDistributionEntry node) {
processMultiParamConfigEndVisit(node);
}
@Override @Override
public void visit(ASTRewardFunctionEntry node) { public void visit(ASTRewardFunctionEntry node) {
RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName()); RewardFunctionSymbol symbol = new RewardFunctionSymbol(node.getName());
......
...@@ -22,6 +22,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -22,6 +22,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private RewardFunctionSymbol rlRewardFunctionSymbol; private RewardFunctionSymbol rlRewardFunctionSymbol;
private NNArchitectureSymbol trainedArchitecture; private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork; private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
...@@ -75,10 +76,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -75,10 +76,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(criticNetwork); return Optional.ofNullable(criticNetwork);
} }
public Optional<NNArchitectureSymbol> getDiscriminatorNetwork() {
return Optional.ofNullable(discriminatorNetwork);
}
public void setCriticNetwork(NNArchitectureSymbol criticNetwork) { public void setCriticNetwork(NNArchitectureSymbol criticNetwork) {
this.criticNetwork = criticNetwork; this.criticNetwork = criticNetwork;
} }
public void setDiscriminatorNetwork(NNArchitectureSymbol discriminatorNetwork) {
this.discriminatorNetwork = discriminatorNetwork;
}
public Map<String, EntrySymbol> getEntryMap() { public Map<String, EntrySymbol> getEntryMap() {
return entryMap; return entryMap;
} }
...@@ -96,10 +105,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -96,10 +105,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getLearningMethod().equals(LearningMethod.REINFORCEMENT); return getLearningMethod().equals(LearningMethod.REINFORCEMENT);
} }
public boolean isGAN() {
return getLearningMethod().equals(LearningMethod.GAN);
}
public boolean hasPreprocessor() {
return getEntryMap().containsKey(PREPROCESSING_NAME);
}
public boolean hasCritic() { public boolean hasCritic() {
return getEntryMap().containsKey(CRITIC); return getEntryMap().containsKey(CRITIC);
} }
public boolean hasDiscriminator() {
return getEntryMap().containsKey(DISCRIMINATOR_NAME);
}
public Optional<String> getCriticName() { public Optional<String> getCriticName() {
if (!hasCritic()) { if (!hasCritic()) {
return Optional.empty(); return Optional.empty();
...@@ -109,4 +130,24 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -109,4 +130,24 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
assert criticNameValue instanceof String; assert criticNameValue instanceof String;
return Optional.of((String)criticNameValue); return Optional.of((String)criticNameValue);
} }
public Optional<String> getPreprocessingName() {
if (!hasPreprocessor()) {
return Optional.empty();
}
final Object preprocessingNameValue = getEntry(PREPROCESSING_NAME).getValue().getValue();
assert preprocessingNameValue instanceof String;
return Optional.of((String)preprocessingNameValue);
}
public Optional<String> getDiscriminatorName() {
if (!hasDiscriminator()) {
return Optional.empty();
}
final Object discriminatorNameValue = getEntry(DISCRIMINATOR_NAME).getValue().getValue();
assert discriminatorNameValue instanceof String;
return Optional.of((String)discriminatorNameValue);
}
} }
\ No newline at end of file
...@@ -22,5 +22,11 @@ public enum LearningMethod { ...@@ -22,5 +22,11 @@ public enum LearningMethod {
public String toString() { public String toString() {
return "reinforcement"; return "reinforcement";
} }
},
GAN {
@Override
public String toString() {
return "gan";
}
} }
} }
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum NoiseDistribution {
GAUSSIAN{
@Override
public String toString() {
return "gaussian";
}
}
}
...@@ -44,4 +44,11 @@ public class ConfigEntryNameConstants { ...@@ -44,4 +44,11 @@ public class ConfigEntryNameConstants {
public static final String STRATEGY_EPSDECAY = "epsdecay"; public static final String STRATEGY_EPSDECAY = "epsdecay";
public static final String CRITIC = "critic"; public static final String CRITIC = "critic";
public static final String DISCRIMINATOR_NAME = "discriminator_name";
public static final String PREPROCESSING_NAME = "preprocessing_name";
public static final String NOISE_DISTRIBUTION = "noise_distribution";
public static final String IMG_RESIZE = "img_resize";
public static final String IMG_RESIZE_WIDTH = "img_resize_width";
public static final String IMG_RESIZE_HEIGHT = "img_resize_height";
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment