Commit 687fd713 authored by Julian Dierkes's avatar Julian Dierkes

added new parameters for GAN training and CoCos

parent cb7e3ee5
......@@ -259,10 +259,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions
KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue;
ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue;
GeneratorTargetNameEntry implements ConfigEntry = name:"generator_target_name" ":" value:StringValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:GeneratorLossValue;
GeneratorLossValue implements ConfigValue = (l1: "l1" | l2: "l2");
NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue;
GeneratorLossWeightEntry implements ConfigEntry = name:"generator_loss_weight" ":" value:NumberValue;
DiscriminatorLossWeightEntry implements ConfigEntry = name:"discriminator_loss_weight" ":" value:NumberValue;
SpeedPeriodEntry implements ConfigEntry = name:"speed_period" ":" value:IntegerValue;
PrintImagesEntry implements ConfigEntry = name:"print_images" ":" value:BooleanValue;
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
......@@ -285,6 +292,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
NoiseDistributionUniformValue implements NoiseDistributionValue = name:"uniform" ("{" "}")?;
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
......
......@@ -51,4 +51,11 @@ public class CNNTrainCocos {
.addCoCo(new CheckCriticNetworkInputs());
checker.checkAll(configurationSymbol);
}
public static void checkGANCocos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckGANNetworkPorts())
.addCoCo(new CheckGANConfigurationDependencies());
checker.checkAll(configurationSymbol);
}
}
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import sun.security.krb5.internal.ccache.CredentialsCache;
public class CheckGANConfigurationDependencies implements CNNTrainConfigurationSymbolCoCo{
public CheckGANConfigurationDependencies() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if(configurationSymbol.getLearningMethod() == LearningMethod.GAN) {
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Generator loss specified but conditional input is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Conditional input specified but generator loss is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.LOSS) != null)
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER +
" Loss parameter not valid for GAN learning");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_DISTRIBUTION) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Noise input specified but noise distribution parameter is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_DISTRIBUTION) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint distributions are given but q-network is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_LOSS) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint losses are given but q-network is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) == null)
Log.warn(" No noise input specified. Are you sure this is correct?");
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get();
NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get();
Optional<NNArchitectureSymbol> qnet = configurationSymbol.getQNetwork();
if(gen.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " +
"but is supposed to only have one");
if(qnet.isPresent() && qnet.get().getInputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " +
"but is supposed to only have one");
if(qnet.isPresent() && dis.getOutputs().size() != 2)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 2 output " +
"ports when q-network is given");
if(!qnet.isPresent() && dis.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 1 output " +
"port when no q-network is given");
if(qnet.isPresent() && dis.getOutputs().size() == 2)
if(!dis.getOutputs().get(1).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Second output of discriminator network " +
"has to be named features when " +
"q-network is given");
if(qnet.isPresent() && !qnet.get().getInputs().get(0).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Input to q-network needs to be named features");
if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0)))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " +
"fit the first discriminators input name");
if(qnet.isPresent())
if(gen.getInputs().contains(qnet.get().getOutputs()))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " +
"latent-codes outputted by q-network");
}
}
......@@ -53,16 +53,18 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
= parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass());
final boolean reinforcementLearningParameter
= parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass());
final boolean ganLearningParameter
= parameterAlgorithmMapping.isGANLearningParameter(node.getClass());
assert (supervisedLearningParameter || reinforcementLearningParameter) :
assert (supervisedLearningParameter || reinforcementLearningParameter || ganLearningParameter) :
"Parameter " + node.getName() + " is not checkable, because it is unknown to Condition";
if (supervisedLearningParameter && !reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node);
} else if(!supervisedLearningParameter && reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node);
}
}
}
private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) {
if (isLearningMethodKnown()) {
......@@ -91,11 +93,10 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private void evaluateLearningMethodEntry(ASTEntry node) {
ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue();
LearningMethod evaluatedLearningMethod;
if(learningMethodValue.isPresentReinforcement()) {
if(learningMethodValue.isPresentReinforcement())
evaluatedLearningMethod = LearningMethod.REINFORCEMENT;
} else {
else
evaluatedLearningMethod = LearningMethod.SUPERVISED;
}
if (isLearningMethodKnown()) {
logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod);
......@@ -127,16 +128,16 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) {
return parameterAlgorithmMapping.getAllReinforcementParameters();
}
return parameterAlgorithmMapping.getAllSupervisedParameters();
}
private void setLearningMethod(final LearningMethod learningMethod) {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) {
if (learningMethod.equals(LearningMethod.REINFORCEMENT))
setLearningMethodToReinforcement();
} else {
else
setLearningMethodToSupervised();
}
}
private void setLearningMethodToSupervised() {
......
......@@ -129,9 +129,12 @@ class ParameterAlgorithmMapping {
ASTDiscriminatorOptimizerEntry.class,
ASTKValueEntry.class,
ASTGeneratorLossEntry.class,
ASTConditionalInputEntry.class,
ASTNoiseInputEntry.class
ASTGeneratorTargetNameEntry.class,
ASTNoiseInputEntry.class,
ASTGeneratorLossWeightEntry.class,
ASTDiscriminatorLossWeightEntry.class,
ASTSpeedPeriodEntry.class,
ASTPrintImagesEntry.class
);
ParameterAlgorithmMapping() {
......@@ -157,7 +160,13 @@ class ParameterAlgorithmMapping {
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
}
boolean isGANLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
|| GENERAL_GAN_PARAMETERS.contains(entryClazz);
}
......@@ -180,6 +189,14 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
List<Class> getAllGANParameters() {
return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS)
.addAll(EXCLUSIVE_SUPERVISED_PARAMETERS)
.addAll(GENERAL_GAN_PARAMETERS)
.build();
}
List<Class> getAllReinforcementParameters() {
return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS)
......
......@@ -145,7 +145,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
@Override
public void endVisit(ASTGeneratorLossEntry node) {
public void endVisit(ASTNoiseInputEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
......@@ -153,7 +153,39 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
}
@Override
public void endVisit(ASTConditionalInputEntry node) {
public void endVisit(ASTGeneratorLossWeightEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTDiscriminatorLossWeightEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTSpeedPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTPrintImagesEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTGeneratorTargetNameEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
......@@ -442,6 +474,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTGeneratorLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentL1()) {
value.setValue(GeneratorLoss.L1);
} else if (node.getValue().isPresentL2()) {
value.setValue(GeneratorLoss.L2);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTRLAlgorithmEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -564,14 +612,6 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTPreprocessingEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTReplayMemoryEntry node) {
......@@ -628,11 +668,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution;
if(node.getValue().getName().equals("gaussian")) {
if(node.getValue().getName().equals("gaussian"))
noiseDistribution = NoiseDistribution.GAUSSIAN;
} else {
else if (node.getValue().getName().equals("uniform"))
noiseDistribution = NoiseDistribution.UNIFORM;
else
noiseDistribution = NoiseDistribution.GAUSSIAN;
}
processMultiParamConfigVisit(node, noiseDistribution);
}
......@@ -650,6 +692,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
addToScopeAndLinkWithNode(symbol, node);
}
@Override
public void visit(ASTPreprocessingEntry node) {
PreprocessingComponentSymbol symbol = new PreprocessingComponentSymbol(node.getName());
symbol.setPreprocessingComponentName(node.getValue().getNameList());
configuration.setPreprocessingComponent(symbol);
addToScopeAndLinkWithNode(symbol, node);
}
@Override
public void visit(ASTSoftTargetUpdateRateEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......
......@@ -20,6 +20,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
private RewardFunctionSymbol rlRewardFunctionSymbol;
private PreprocessingComponentSymbol preprocessingComponentSymbol;
private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork;
......@@ -30,6 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
public ConfigurationSymbol() {
super("", KIND);
rlRewardFunctionSymbol = null;
preprocessingComponentSymbol = null;
trainedArchitecture = null;
}
......@@ -65,6 +67,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(this.rlRewardFunctionSymbol);
}
public void setPreprocessingComponent(PreprocessingComponentSymbol preprocessingComponentSymbol) {
this.preprocessingComponentSymbol = preprocessingComponentSymbol;
}
public Optional<PreprocessingComponentSymbol> getPreprocessingComponent() {
return Optional.ofNullable(this.preprocessingComponentSymbol);
}
public boolean hasPreprocessor() {
return this.preprocessingComponentSymbol != null;
}
public Optional<NNArchitectureSymbol> getTrainedArchitecture() {
return Optional.ofNullable(trainedArchitecture);
}
......@@ -118,10 +132,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getLearningMethod().equals(LearningMethod.GAN);
}
public boolean hasPreprocessor() {
return getEntryMap().containsKey(PREPROCESSING_NAME);
}
public boolean hasCritic() {
return getEntryMap().containsKey(CRITIC);
}
......@@ -144,16 +154,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.of((String)criticNameValue);
}
public Optional<String> getPreprocessingName() {
if (!hasPreprocessor()) {
return Optional.empty();
}
final Object preprocessingNameValue = getEntry(PREPROCESSING_NAME).getValue().getValue();
assert preprocessingNameValue instanceof String;
return Optional.of((String)preprocessingNameValue);
}
public Optional<String> getDiscriminatorName() {
if (!hasDiscriminator()) {
return Optional.empty();
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum GeneratorLoss {
L1{
@Override
public String toString() {
return "l1";
}
},
L2{
@Override
public String toString() {
return "l2";
}
}
}
......@@ -12,5 +12,11 @@ public enum NoiseDistribution {
public String toString() {
return "gaussian";
}
},
UNIFORM{
@Override
public String toString() {
return "uniform";
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter;
import de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter;
import de.monticore.symboltable.CommonSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
*
*/
public class PreprocessingComponentSymbol extends CommonSymbol {
public static final PreprocessingComponentSymbolKind KIND = new PreprocessingComponentSymbolKind();
private List<String> preprocessingComponentName;
private PreprocessingComponentParameter preprocessingComponentParameter;
public PreprocessingComponentSymbol(String name) {
super(name, KIND);
preprocessingComponentName = new ArrayList<>();
}
protected void setPreprocessingComponentName(List<String> preprocessingComponentNamePath) {
this.preprocessingComponentName = Lists.newArrayList(preprocessingComponentNamePath);
}
public List<String> getPreprocessingComponentName() {
return Lists.newArrayList(preprocessingComponentName);
}
public void setPreprocessingComponentParameter(PreprocessingComponentParameter preprocessingComponentParameter) {
this.preprocessingComponentParameter = preprocessingComponentParameter;
}
public Optional<PreprocessingComponentParameter> getPreprocessingComponentParameter() {
return Optional.ofNullable(preprocessingComponentParameter);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
import de.monticore.symboltable.SymbolKind;
/**
*
*/
public class PreprocessingComponentSymbolKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.PreprocessingComponentSymbolKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain.annotations;
import java.util.List;
import java.util.Optional;
/**
*
*/
public interface PreprocessingComponentParameter {
List<String> getInputNames();
List<String> getOutputNames();
Optional<String> getTypeOfInputPort(String portName);
Optional<String> getTypeOfOutputPort(String portName);
Optional<List<Integer>> getInputPortDimensionOfPort(String portName);
Optional<List<Integer>> getOutputPortDimensionOfPort(String portName);