Commit ea8e2486 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !27
parents 9b8028f6 a108ad1d
Pipeline #300997 passed with stage
in 4 minutes and 11 seconds
......@@ -8,7 +8,7 @@
# (c) https://github.com/MontiCore/monticore
stages:
- windows
#- windows
- linux
masterJobLinux:
......@@ -22,12 +22,12 @@ masterJobLinux:
only:
- master
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10_OS
#masterJobWindows:
# stage: windows
# script:
# - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
# tags:
# - Windows10_OS
BranchJobLinux:
stage: linux
......
[INFO] Scanning for projects...
[INFO]
[INFO] ----------------< de.monticore.lang.monticar:cnn-train >----------------
[INFO] Building cnn-train 0.3.6-SNAPSHOT
[INFO] --------------------------------[ jar ]---------------------------------
[INFO]
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ cnn-train ---
[INFO] Deleting /home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/target
[INFO]
[INFO] --- jacoco-maven-plugin:0.8.1:prepare-agent (pre-unit-test) @ cnn-train ---
[INFO] argLine set to -javaagent:/home/julian/.m2/repository/org/jacoco/org.jacoco.agent/0.8.1/org.jacoco.agent-0.8.1-runtime.jar=destfile=/home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/target/jacoco.exec
[INFO]
[INFO] --- monticore-maven-plugin:5.0.1:generate (default) @ cnn-train ---
[INFO] Changes detected for /home/julian/Dropbox/Dokumente/bachelorarbeit/pipeline/CNNTrainLang/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4. Regenerating...
......@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.9-SNAPSHOT</version>
<version>0.3.10-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -30,6 +30,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
CheckpointPeriodEntry implements ConfigEntry = name:"checkpoint_period" ":" value:IntegerValue;
LoadPretrainedEntry implements ConfigEntry = name:"load_pretrained" ":" value:BooleanValue;
LogPeriodEntry implements ConfigEntry = name:"log_period" ":" value:IntegerValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
......@@ -46,6 +47,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface EvalMetricValue extends MultiParamValue;
AccuracyEvalMetric implements EvalMetricValue = name:"accuracy";
BleuMetric implements EvalMetricValue = name:"bleu" ("{" params:BleuEntry* "}")?;
AccIgnoreLabelMetric implements EvalMetricValue = name:"accuracy_ignore_label" ("{" params:AccIgnoreLabelEntry* "}")?;
CrossEntropyEvalMetric implements EvalMetricValue = name:"cross_entropy";
F1EvalMetric implements EvalMetricValue = name:"f1";
MAEEvalMetric implements EvalMetricValue = name:"mae";
......@@ -57,6 +59,10 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
interface AccIgnoreLabelEntry extends Entry;
AxisAccIgnoreLabelEntry implements AccIgnoreLabelEntry = name:"axis" ":" value:IntegerValue;
IgnoreLabelAccIgnoreLabelEntry implements AccIgnoreLabelEntry = name:"metric_ignore_label" ":" value:IntegerValue;
EvalTrainEntry implements ConfigEntry = name:"eval_train" ":" value:BooleanValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
......@@ -91,6 +97,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface SoftmaxCrossEntropyIgnoreIndicesEntry extends Entry;
SoftmaxCrossEntropyIgnoreIndicesLoss implements LossValue = name:"softmax_cross_entropy_ignore_indices" ("{" params:SoftmaxCrossEntropyIgnoreIndicesEntry* "}")?;
interface DiceEntry extends Entry;
DiceLoss implements LossValue = name:"dice_loss" ("{" params:DiceEntry* "}")?;
interface SoftmaxCrossEntropyIgnoreLabelEntry extends Entry;
SoftmaxCrossEntropyIgnoreLabelLoss implements LossValue = name:"softmax_cross_entropy_ignore_label" ("{" params:SoftmaxCrossEntropyIgnoreLabelEntry* "}")?;
SigmoidBinaryCrossEntropyLoss implements LossValue = name:"sigmoid_binary_cross_entropy" ("{" params:Entry* "}")?;
interface HingeEntry extends Entry;
......@@ -105,9 +117,12 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface KullbackLeiblerEntry extends Entry;
KullbackLeiblerLoss implements LossValue = name:"kullback_leibler" ("{" params:KullbackLeiblerEntry* "}")?;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry = name:"from_logits" ":" value:BooleanValue;
SparseLabelEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, DiceEntry, SoftmaxCrossEntropyIgnoreLabelEntry = name:"sparse_label" ":" value:BooleanValue;
FromLogitsEntry implements SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, KullbackLeiblerEntry, DiceEntry, SoftmaxCrossEntropyIgnoreLabelEntry = name:"from_logits" ":" value:BooleanValue;
LossAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, DiceEntry,SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_axis" ":" value:IntegerValue;
BatchAxisEntry implements CrossEntropyEntry, SoftmaxCrossEntropyEntry, SoftmaxCrossEntropyIgnoreIndicesEntry, DiceEntry, SoftmaxCrossEntropyIgnoreLabelEntry = name:"batch_axis" ":" value:IntegerValue;
IgnoreIndicesEntry implements SoftmaxCrossEntropyIgnoreIndicesEntry = name:"ignore_indices" ":" value:IntegerValue;
IgnoreLabelEntry implements SoftmaxCrossEntropyIgnoreLabelEntry = name:"loss_ignore_label" ":" value:IntegerValue;
MarginEntry implements HingeEntry, SquaredHingeEntry = name:"margin" ":" value:NumberValue;
LabelFormatEntry implements LogisticEntry = name:"label_format" ":" value:StringValue;
......@@ -256,13 +271,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
// GANs Extensions
KValueEntry implements ConfigEntry = name:"k_value" ":" value:IntegerValue;
GeneratorTargetNameEntry implements ConfigEntry = name:"generator_target_name" ":" value:StringValue;
GeneratorLossEntry implements ConfigEntry = name:"generator_loss" ":" value:GeneratorLossValue;
GeneratorLossValue implements ConfigValue = (l1: "l1" | l2: "l2");
NoiseInputEntry implements ConfigEntry = name:"noise_input" ":" value:StringValue;
GeneratorLossWeightEntry implements ConfigEntry = name:"generator_loss_weight" ":" value:NumberValue;
DiscriminatorLossWeightEntry implements ConfigEntry = name:"discriminator_loss_weight" ":" value:NumberValue;
PrintImagesEntry implements ConfigEntry = name:"print_images" ":" value:BooleanValue;
interface MultiParamValueMapConfigEntry extends ConfigEntry;
interface MultiParamValueMapParamValue extends ConfigValue;
interface MultiParamValueMapTupleValue extends ConfigValue;
DiscriminatorNetworkEntry implements ConfigEntry = name:"discriminator_name" ":" value:ComponentNameValue;
DiscriminatorOptimizerEntry implements ConfigEntry = name:"discriminator_optimizer" ":" value:OptimizerValue;
QNetworkEntry implements ConfigEntry = name:"qnet_name" ":" value:ComponentNameValue;
PreprocessingEntry implements ConfigEntry = name:"preprocessing_name" ":" value:ComponentNameValue;
......@@ -278,6 +304,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
MeanValueEntry implements NoiseDistributionGaussianEntry = name:"mean_value" ":" value:IntegerValue;
SpreadValueEntry implements NoiseDistributionGaussianEntry = name:"spread_value" ":" value:IntegerValue;
NoiseDistributionUniformValue implements NoiseDistributionValue = name:"uniform" ("{" "}")?;
// Constraint Distributions
ConstraintDistributionEntry implements MultiParamValueMapConfigEntry = name:"constraint_distributions" ":" value:ConstraintDistributionValue;
ConstraintDistributionValue implements MultiParamValueMapParamValue = ("{" params:ConstraintDistributionParam* "}")?;
......
......@@ -21,6 +21,12 @@ class ASTConfigurationUtils {
&& ((ASTLearningMethodEntry)e).getValue().isPresentReinforcement());
}
static boolean isGANLearning(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e ->
(e instanceof ASTLearningMethodEntry)
&& ((ASTLearningMethodEntry)e).getValue().isPresentGan());
}
static boolean hasEnvironment(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTEnvironmentEntry);
}
......@@ -108,4 +114,32 @@ class ASTConfigurationUtils {
}
return false;
}
static boolean hasGeneratorLoss(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTGeneratorLossEntry);
}
static boolean hasGeneratorTargetName(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTGeneratorTargetNameEntry);
}
static boolean hasNoiseName(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTNoiseInputEntry);
}
static boolean hasNoiseDistribution(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTNoiseDistributionEntry);
}
static boolean hasConstraintDistribution(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTConstraintDistributionEntry);
}
static boolean hasConstraintLosses(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTConstraintLossEntry);
}
static boolean hasQNetwork(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTQNetworkEntry);
}
}
......@@ -27,7 +27,12 @@ public class CNNTrainCocos {
.addCoCo(new CheckRlAlgorithmParameter())
.addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy())
.addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy())
.addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification());
.addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification())
.addCoCo(new CheckConstraintDistributionQNetworkDependency())
.addCoCo(new CheckConstraintLossesQNetworkDependency())
.addCoCo(new CheckGeneratorLossTargetNameDependency())
.addCoCo(new CheckNoiseInputDistributionDependency())
.addCoCo(new CheckNoiseInputMissing());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
......@@ -51,4 +56,14 @@ public class CNNTrainCocos {
.addCoCo(new CheckCriticNetworkInputs());
checker.checkAll(configurationSymbol);
}
public static void checkGANCocos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckGANDiscriminatorQNetworkDependency())
.addCoCo(new CheckGANGeneratorDiscriminatorDependency())
.addCoCo(new CheckGANGeneratorHasOneOutput())
.addCoCo(new CheckGANGeneratorQNetworkDependency())
.addCoCo(new CheckGANQNetworkhasOneInput());
checker.checkAll(configurationSymbol);
}
}
\ No newline at end of file
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*;
public class CheckConstraintDistributionQNetworkDependency implements CNNTrainASTConfigurationCoCo {
@Override
public void check(final ASTConfiguration node) {
if (isGANLearning(node) && hasConstraintDistribution(node)) {
if (!hasQNetwork(node)) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint distributions are given but q-network is missing");
}
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*;
public class CheckConstraintLossesQNetworkDependency implements CNNTrainASTConfigurationCoCo {
@Override
public void check(final ASTConfiguration node) {
if (isGANLearning(node) && hasConstraintLosses(node)) {
if (!hasQNetwork(node)) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Constraint losses are given but q-network is missing");
}
}
}
}
......@@ -32,4 +32,5 @@ public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainAST
}
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANDiscriminatorQNetworkDependency implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get();
Optional<NNArchitectureSymbol> qnet = configurationSymbol.getQNetwork();
if(qnet.isPresent() && dis.getOutputs().size() != 2)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 2 output " +
"ports when q-network is given");
if(!qnet.isPresent() && dis.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Discriminator needs exactly 1 output " +
"port when no q-network is given");
if(qnet.isPresent() && dis.getOutputs().size() == 2)
if(!dis.getOutputs().get(1).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Second output of discriminator network " +
"has to be named features when " +
"q-network is given");
if(qnet.isPresent() && !qnet.get().getInputs().get(0).equals("features"))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Input to q-network needs to be named features");
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANGeneratorDiscriminatorDependency implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get();
NNArchitectureSymbol dis = configurationSymbol.getDiscriminatorNetwork().get();
if(!gen.getOutputs().get(0).equals(dis.getInputs().get(0)))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " The generator networks output name does not " +
"fit the first discriminators input name");
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANGeneratorHasOneOutput implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get();
if(gen.getOutputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator network has more then one output, " +
"but is supposed to only have one");
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANGeneratorQNetworkDependency implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
NNArchitectureSymbol gen = configurationSymbol.getTrainedArchitecture().get();
Optional<NNArchitectureSymbol> qnet = configurationSymbol.getQNetwork();
if(qnet.isPresent())
if(!gen.getInputs().containsAll(qnet.get().getOutputs()))
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Generator input does not contain all " +
"latent-codes outputted by q-network");
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckGANQNetworkhasOneInput implements CNNTrainConfigurationSymbolCoCo {
public void CheckGANNetworkPorts() { }
@Override
public void check(ConfigurationSymbol configurationSymbol) {
Optional<NNArchitectureSymbol> qnet = configurationSymbol.getQNetwork();
if(qnet.isPresent() && qnet.get().getInputs().size() != 1)
Log.error("0" + ErrorCodes.GAN_ARCHITECTURE_ERROR + " Q-Network has more then one input, " +
"but is supposed to only have one");
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*;
public class CheckGeneratorLossTargetNameDependency implements CNNTrainASTConfigurationCoCo {
@Override
public void check(final ASTConfiguration node) {
if (isGANLearning(node) && hasGeneratorLoss(node)) {
if (!hasGeneratorTargetName(node)) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Generator loss specified but conditional input is missing");
}
}
else if (isGANLearning(node) && hasGeneratorTargetName(node)) {
if (!hasGeneratorLoss(node)) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING +
" Conditional input specified but generator loss is missing");
}
}
}
}
......@@ -7,7 +7,6 @@
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
......@@ -53,16 +52,26 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
= parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass());
final boolean reinforcementLearningParameter
= parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass());
final boolean ganLearningParameter
= parameterAlgorithmMapping.isGANLearningParameter(node.getClass());
assert (supervisedLearningParameter || reinforcementLearningParameter) :
assert (supervisedLearningParameter || reinforcementLearningParameter || ganLearningParameter) :
"Parameter " + node.getName() + " is not checkable, because it is unknown to Condition";
if (supervisedLearningParameter && !reinforcementLearningParameter) {
if (supervisedLearningParameter && !reinforcementLearningParameter && !ganLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node);
} else if(!supervisedLearningParameter && reinforcementLearningParameter) {
} else if(!supervisedLearningParameter && reinforcementLearningParameter && !ganLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node);
}
} else if(!supervisedLearningParameter && !reinforcementLearningParameter && ganLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotGAN(node);
} else if(learningMethodKnown && learningMethod.equals(LearningMethod.REINFORCEMENT)
&& supervisedLearningParameter && !reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node);
} else if(learningMethodKnown && learningMethod.equals(LearningMethod.REINFORCEMENT)
&& ganLearningParameter && !reinforcementLearningParameter) {
setLearningMethodOrLogErrorIfActualLearningMethodIsNotGAN(node);
}
}
private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) {
if (isLearningMethodKnown()) {
......@@ -88,14 +97,27 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
}
}
private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotGAN(ASTEntry node) {
if (isLearningMethodKnown()) {
if (!learningMethod.equals(LearningMethod.GAN)) {
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + " Parameter "
+ node.getName() + " is not supported for " + this.learningMethod + " learning.",
node.get_SourcePositionStart());
}
} else {
setLearningMethodToGAN();
}
}
private void evaluateLearningMethodEntry(ASTEntry node) {
ASTLearningMethodValue learningMethodValue = (ASTLearningMethodValue)node.getValue();
LearningMethod evaluatedLearningMethod;
if(learningMethodValue.isPresentReinforcement()) {
if(learningMethodValue.isPresentReinforcement())
evaluatedLearningMethod = LearningMethod.REINFORCEMENT;
} else {
else if(learningMethodValue.isPresentGan())
evaluatedLearningMethod = LearningMethod.GAN;
else
evaluatedLearningMethod = LearningMethod.SUPERVISED;
}
if (isLearningMethodKnown()) {
logErrorIfEvaluatedLearningMethoNotEqualToActual(node, evaluatedLearningMethod);
......@@ -127,16 +149,21 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) {
return parameterAlgorithmMapping.getAllReinforcementParameters();
}
if (learningMethod.equals(LearningMethod.GAN)) {
return parameterAlgorithmMapping.getAllGANParameters();
}
return parameterAlgorithmMapping.getAllSupervisedParameters();
}
private void setLearningMethod(final LearningMethod learningMethod) {
if (learningMethod.equals(LearningMethod.REINFORCEMENT)) {
if (learningMethod.equals(LearningMethod.REINFORCEMENT))
setLearningMethodToReinforcement();
} else {
else if (learningMethod.equals(LearningMethod.GAN))
setLearningMethodToGAN();
else
setLearningMethodToSupervised();
}