Commit 687fd713 authored by Julian Dierkes's avatar Julian Dierkes

added new parameters for GAN training and CoCos

parent cb7e3ee5
...@@ -259,10 +259,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -259,10 +259,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions // GANs Extensions
KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue; KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:StringValue; GeneratorTargetNameEntry implements ConfigEntry = name:"generator_target_name" ":" value:StringValue;
ConditionalInputEntry implements ConfigEntry = name:"conditional_input" ":" value:StringValue; GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:GeneratorLossValue;
GeneratorLossValue implements ConfigValue = (l1: "l1" | l2: "l2");
NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue; NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue;
GeneratorLossWeightEntry implements ConfigEntry = name:"generator_loss_weight" ":" value:NumberValue;
DiscriminatorLossWeightEntry implements ConfigEntry = name:"discriminator_loss_weight" ":" value:NumberValue;
SpeedPeriodEntry implements ConfigEntry = name:"speed_period" ":" value:IntegerValue;
PrintImagesEntry implements ConfigEntry = name:"print_images" ":" value:BooleanValue;
interface MultiParamValueMapConfigEntry extends ConfigEntry; interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue; interface MultiParamValueMapParamValue extends ConfigValue;
...@@ -285,6 +292,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -285,6 +292,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue; MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue; SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
NoiseDistributionUniformValue implements NoiseDistributionValue = name:"uniform" ("{" "}")?;
// Constraint Distributions // Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue; ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?; ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
......
...@@ -51,4 +51,11 @@ public class CNNTrainCocos { ...@@ -51,4 +51,11 @@ public class CNNTrainCocos {
.addCoCo(new CheckCriticNetworkInputs()); .addCoCo(new CheckCriticNetworkInputs());
checker.checkAll(configurationSymbol); checker.checkAll(configurationSymbol);
} }
public static void checkGANCocos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckGANNetworkPorts())
.addCoCo(new CheckGANConfigurationDependencies());
checker.checkAll(configurationSymbol);
}
} }
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import sun.security.krb5.internal.ccache.CredentialsCache;
public class CheckGANConfigurationDependencies implements CNNTrainConfigurationSymbolCoCo{
public CheckGANConfigurationDependencies() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if(configurationSymbol.getLearningMethod() == LearningMethod.GAN) {
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Generator loss specified but conditional input is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_TARGET_NAME) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.GENERATOR_LOSS) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Conditional input specified but generator loss is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.LOSS) != null)
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER +
" Loss parameter not valid for GAN learning");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_DISTRIBUTION) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Noise input specified but noise distribution parameter is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_DISTRIBUTION) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint distributions are given but q-network is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.CONSTRAINT_LOSS) != null)
if (configurationSymbol.getEntry(ConfigEntryNameConstants.QNETWORK_NAME) == null)
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint losses are given but q-network is missing");
if (configurationSymbol.getEntry(ConfigEntryNameConstants.NOISE_INPUT) == null)
Log.warn(" No noise input specified. Are you sure this is correct?");
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANNetworkPorts implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get();
NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get();
Optional<NNArchitectureSymbol> qnet = configurationSymbol.getQNetwork();
if(gen.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " +
"but is supposed to only have one");
if(qnet.isPresent() && qnet.get().getInputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " +
"but is supposed to only have one");
if(qnet.isPresent() && dis.getOutputs().size() != 2)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 2 output " +
"ports when q-network is given");
if(!qnet.isPresent() && dis.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 1 output " +
"port when no q-network is given");
if(qnet.isPresent() && dis.getOutputs().size() == 2)
if(!dis.getOutputs().get(1).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Second output of discriminator network " +
"has to be named features when " +
"q-network is given");
if(qnet.isPresent() && !qnet.get().getInputs().get(0).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Input to q-network needs to be named features");
if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0)))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " +
"fit the first discriminators input name");
if(qnet.isPresent())
if(gen.getInputs().contains(qnet.get().getOutputs()))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " +
"latent-codes outputted by q-network");
}
}
...@@ -53,16 +53,18 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ...@@ -53,16 +53,18 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
= parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass()); = parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass());
final boolean reinforcementLearningParameter final boolean reinforcementLearningParameter
= parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass()); = parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass());
final boolean ganLearningParameter
= parameterAlgorithmMapping.isGANLearningParameter(node.getClass());
assert (supervisedLearningParameter || reinforcementLearningParameter) : assert (supervisedLearningParameter || reinforcementLearningParameter || ganLearningParameter) :
"Parameter " + node.getName() + " is not checkable, because it is unknown to Condition"; "Parameter " + node.getName() + " is not checkable, because it is unknown to Condition";
if (supervisedLearningParameter && !reinforcementLearningParameter) { if (supervisedLearningParameter && !reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node); setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node);
} else if(!supervisedLearningParameter && reinforcementLearningParameter) { } else if(!supervisedLearningParameter && reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node); setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node);
}
} }
}
private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) { private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) {
if (isLearningMethodKnown()) { if (isLearningMethodKnown()) {
...@@ -91,11 +93,10 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ...@@ -91,11 +93,10 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private void evaluateLearningMethodEntry(ASTEntry node) { private void evaluateLearningMethodEntry(ASTEntry node) {
ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue(); ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue();
LearningMethod evaluatedLearningMethod; LearningMethod evaluatedLearningMethod;
if(learningMethodValue.isPresentReinforcement()) { if(learningMethodValue.isPresentReinforcement())
evaluatedLearningMethod = LearningMethod.REINFORCEMENT; evaluatedLearningMethod = LearningMethod.REINFORCEMENT;
} else { else
evaluatedLearningMethod = LearningMethod.SUPERVISED; evaluatedLearningMethod = LearningMethod.SUPERVISED;
}
if (isLearningMethodKnown()) { if (isLearningMethodKnown()) {
logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod); logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod);
...@@ -127,16 +128,16 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ...@@ -127,16 +128,16 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) {
return parameterAlgorithmMapping.getAllReinforcementParameters(); return parameterAlgorithmMapping.getAllReinforcementParameters();
} }
return parameterAlgorithmMapping.getAllSupervisedParameters(); return parameterAlgorithmMapping.getAllSupervisedParameters();
} }
private void setLearningMethod(final LearningMethod learningMethod) { private void setLearningMethod(final LearningMethod learningMethod) {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { if (learningMethod.equals(LearningMethod.REINFORCEMENT))
setLearningMethodToReinforcement(); setLearningMethodToReinforcement();
} else { else
setLearningMethodToSupervised(); setLearningMethodToSupervised();
}
} }
private void setLearningMethodToSupervised() { private void setLearningMethodToSupervised() {
......
...@@ -129,9 +129,12 @@ class ParameterAlgorithmMapping { ...@@ -129,9 +129,12 @@ class ParameterAlgorithmMapping {
ASTDiscriminatorOptimizerEntry.class, ASTDiscriminatorOptimizerEntry.class,
ASTKValueEntry.class, ASTKValueEntry.class,
ASTGeneratorLossEntry.class, ASTGeneratorLossEntry.class,
ASTConditionalInputEntry.class, ASTGeneratorTargetNameEntry.class,
ASTNoiseInputEntry.class ASTNoiseInputEntry.class,
ASTGeneratorLossWeightEntry.class,
ASTDiscriminatorLossWeightEntry.class,
ASTSpeedPeriodEntry.class,
ASTPrintImagesEntry.class
); );
ParameterAlgorithmMapping() { ParameterAlgorithmMapping() {
...@@ -157,7 +160,13 @@ class ParameterAlgorithmMapping { ...@@ -157,7 +160,13 @@ class ParameterAlgorithmMapping {
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) { boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz) return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz) || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
}
boolean isGANLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
|| GENERAL_GAN_PARAMETERS.contains(entryClazz); || GENERAL_GAN_PARAMETERS.contains(entryClazz);
} }
...@@ -180,6 +189,14 @@ class ParameterAlgorithmMapping { ...@@ -180,6 +189,14 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz); || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
} }
List<Class> getAllGANParameters() {
return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS)
.addAll(EXCLUSIVE_SUPERVISED_PARAMETERS)
.addAll(GENERAL_GAN_PARAMETERS)
.build();
}
List<Class> getAllReinforcementParameters() { List<Class> getAllReinforcementParameters() {
return ImmutableList.<Class> builder() return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS) .addAll(GENERAL_PARAMETERS)
......
...@@ -145,7 +145,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -145,7 +145,7 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
@Override @Override
public void endVisit(ASTGeneratorLossEntry node) { public void endVisit(ASTNoiseInputEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue())); entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node); addToScopeAndLinkWithNode(entry, node);
...@@ -153,7 +153,39 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -153,7 +153,39 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
} }
@Override @Override
public void endVisit(ASTConditionalInputEntry node) { public void endVisit(ASTGeneratorLossWeightEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTDiscriminatorLossWeightEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForDouble(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTSpeedPeriodEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForInteger(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTPrintImagesEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void endVisit(ASTGeneratorTargetNameEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForString(node.getValue())); entry.setValue(getValueSymbolForString(node.getValue()));
addToScopeAndLinkWithNode(entry, node); addToScopeAndLinkWithNode(entry, node);
...@@ -442,6 +474,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -442,6 +474,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void visit(ASTGeneratorLossEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentL1()) {
value.setValue(GeneratorLoss.L1);
} else if (node.getValue().isPresentL2()) {
value.setValue(GeneratorLoss.L2);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTRLAlgorithmEntry node) { public void visit(ASTRLAlgorithmEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
...@@ -564,14 +612,6 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -564,14 +612,6 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry); configuration.getEntryMap().put(node.getName(), entry);
} }
@Override
public void visit(ASTPreprocessingEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTReplayMemoryEntry node) { public void visit(ASTReplayMemoryEntry node) {
...@@ -628,11 +668,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -628,11 +668,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override @Override
public void visit(ASTNoiseDistributionEntry node) { public void visit(ASTNoiseDistributionEntry node) {
NoiseDistribution noiseDistribution; NoiseDistribution noiseDistribution;
if(node.getValue().getName().equals("gaussian")) { if(node.getValue().getName().equals("gaussian"))
noiseDistribution = NoiseDistribution.GAUSSIAN; noiseDistribution = NoiseDistribution.GAUSSIAN;
} else { else if (node.getValue().getName().equals("uniform"))
noiseDistribution = NoiseDistribution.UNIFORM;
else
noiseDistribution = NoiseDistribution.GAUSSIAN; noiseDistribution = NoiseDistribution.GAUSSIAN;
}
processMultiParamConfigVisit(node, noiseDistribution); processMultiParamConfigVisit(node, noiseDistribution);
} }
...@@ -650,6 +692,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -650,6 +692,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
addToScopeAndLinkWithNode(symbol, node); addToScopeAndLinkWithNode(symbol, node);
} }
@Override
public void visit(ASTPreprocessingEntry node) {
PreprocessingComponentSymbol symbol = new PreprocessingComponentSymbol(node.getName());
symbol.setPreprocessingComponentName(node.getValue().getNameList());
configuration.setPreprocessingComponent(symbol);
addToScopeAndLinkWithNode(symbol, node);
}
@Override @Override
public void visit(ASTSoftTargetUpdateRateEntry node) { public void visit(ASTSoftTargetUpdateRateEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); EntrySymbol entry = new EntrySymbol(node.getName());
......
...@@ -20,6 +20,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -20,6 +20,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private OptimizerSymbol criticOptimizer; private OptimizerSymbol criticOptimizer;
private LossSymbol loss; private LossSymbol loss;
private RewardFunctionSymbol rlRewardFunctionSymbol; private RewardFunctionSymbol rlRewardFunctionSymbol;
private PreprocessingComponentSymbol preprocessingComponentSymbol;
private NNArchitectureSymbol trainedArchitecture; private NNArchitectureSymbol trainedArchitecture;
private NNArchitectureSymbol criticNetwork; private NNArchitectureSymbol criticNetwork;
private NNArchitectureSymbol discriminatorNetwork; private NNArchitectureSymbol discriminatorNetwork;
...@@ -30,6 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -30,6 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
public ConfigurationSymbol() { public ConfigurationSymbol() {
super("", KIND); super("", KIND);
rlRewardFunctionSymbol = null; rlRewardFunctionSymbol = null;
preprocessingComponentSymbol = null;
trainedArchitecture = null; trainedArchitecture = null;
} }
...@@ -65,6 +67,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -65,6 +67,18 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(this.rlRewardFunctionSymbol); return Optional.ofNullable(this.rlRewardFunctionSymbol);
} }
public void setPreprocessingComponent(PreprocessingComponentSymbol preprocessingComponentSymbol) {
this.preprocessingComponentSymbol = preprocessingComponentSymbol;
}
public Optional<PreprocessingComponentSymbol> getPreprocessingComponent() {
return Optional.ofNullable(this.preprocessingComponentSymbol);
}
public boolean hasPreprocessor() {
return this.preprocessingComponentSymbol != null;
}
public Optional<NNArchitectureSymbol> getTrainedArchitecture() { public Optional<NNArchitectureSymbol> getTrainedArchitecture() {
return Optional.ofNullable(trainedArchitecture); return Optional.ofNullable(trainedArchitecture);
} }
...@@ -118,10 +132,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -118,10 +132,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return getLearningMethod().equals(LearningMethod.GAN); return getLearningMethod().equals(LearningMethod.GAN);
} }
public boolean hasPreprocessor() {
return getEntryMap().containsKey(PREPROCESSING_NAME);
}
public boolean hasCritic() { public boolean hasCritic() {
return getEntryMap().containsKey(CRITIC); return getEntryMap().containsKey(CRITIC);
} }
...@@ -144,16 +154,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { ...@@ -144,16 +154,6 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.of((String)criticNameValue); return Optional.of((String)criticNameValue);
} }
public Optional<String> getPreprocessingName() {
if (!hasPreprocessor()) {
return Optional.empty();
}
final Object preprocessingNameValue = getEntry(PREPROCESSING_NAME).getValue().getValue();
assert preprocessingNameValue instanceof String;
return Optional.of((String)preprocessingNameValue);
}
public Optional<String> getDiscriminatorName() { public Optional<String> getDiscriminatorName() {
if (!hasDiscriminator()) { if (!hasDiscriminator()) {
return Optional.empty(); return Optional.empty();
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum GeneratorLoss {
L1{
@Override
public String toString() {
return "l1";
}
},
L2{
@Override