Commit 3889c267 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

Merge local branch to develop

parents 300c11a6 a8f6e66a
Pipeline #205125 passed with stages
in 13 minutes and 47 seconds
...@@ -27,7 +27,7 @@ masterJobWindows: ...@@ -27,7 +27,7 @@ masterJobWindows:
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags: tags:
- Windows10 - Windows10_OS
BranchJobLinux: BranchJobLinux:
stage: linux stage: linux
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId> <artifactId>cnn-train</artifactId>
<version>0.3.6-SNAPSHOT</version> <version>0.3.8-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -132,6 +132,13 @@ ...@@ -132,6 +132,13 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
......
...@@ -25,6 +25,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -25,6 +25,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ComponentNameValue implements ConfigValue = Name ("."Name)*; ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")"; IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
...@@ -32,17 +33,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -32,17 +33,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue; LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
LossWeightsEntry implements ConfigEntry = name:"loss_weights" ":" value:DoubleVectorValue; LossWeightsEntry implements ConfigEntry = name:"loss_weights" ":" value:DoubleVectorValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy" EvalMetricEntry implements MultiParamConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
| crossEntropy:"cross_entropy"
| f1:"f1" interface EvalMetricValue extends MultiParamValue;
| mae:"mae" AccuracyEvalMetric implements EvalMetricValue = name:"accuracy";
| mse:"mse" BleuMetric implements EvalMetricValue = name:"bleu" ("{" params:BleuEntry* "}")?;
| rmse:"rmse" CrossEntropyEvalMetric implements EvalMetricValue = name:"cross_entropy";
| topKAccuracy:"top_k_accuracy"); F1EvalMetric implements EvalMetricValue = name:"f1";
MAEEvalMetric implements EvalMetricValue = name:"mae";
MSEEvalMetric implements EvalMetricValue = name:"mse";
PerplexityEvalMetric implements EvalMetricValue = name:"perplexity";
RMSEEvalMetric implements EvalMetricValue = name:"rmse";
TopKAccuracyEvalMetric implements EvalMetricValue = name:"top_k_accuracy";
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed" LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step" | step:"step"
...@@ -132,6 +140,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -132,6 +140,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue; ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue; RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
// Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
// Reinforcement Extensions // Reinforcement Extensions
interface MultiParamValue extends ConfigValue; interface MultiParamValue extends ConfigValue;
...@@ -209,7 +220,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number ...@@ -209,7 +220,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue; GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue;
interface RosEnvironmentEntry extends Entry; interface RosEnvironmentEntry extends Entry;
RosEnvironmentValue implements EnvironmentValue = | name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?; RosEnvironmentValue implements EnvironmentValue = name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue; RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue; RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue; RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
......
...@@ -13,6 +13,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant ...@@ -13,6 +13,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
import java.util.Optional; import java.util.Optional;
class ASTConfigurationUtils { class ASTConfigurationUtils {
final static ParameterAlgorithmMapping parameterChecker = new ParameterAlgorithmMapping();
static boolean isReinforcementLearning(final ASTConfiguration configuration) { static boolean isReinforcementLearning(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e -> return configuration.getEntriesList().stream().anyMatch(e ->
(e instanceof ASTLearningMethodEntry) (e instanceof ASTLearningMethodEntry)
...@@ -94,4 +96,16 @@ class ASTConfigurationUtils { ...@@ -94,4 +96,16 @@ class ASTConfigurationUtils {
public static boolean isContinuousAlgorithm(final ASTConfiguration node) { public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node); return isDdpgAlgorithm(node) || isTd3Algorithm(node);
} }
public static boolean hasRLEntry(ASTConfiguration node) {
//return node.getEntriesList().stream()
// .anyMatch(e -> parameterChecker.isReinforcementLearningParameterOnly(e.getClass()));
for (ASTConfigEntry e : node.getEntriesList()) {
boolean b = parameterChecker.isReinforcementLearningParameterOnly(e.getClass());
if (b) {
return true;
}
}
return false;
}
} }
...@@ -18,6 +18,7 @@ public class CNNTrainCocos { ...@@ -18,6 +18,7 @@ public class CNNTrainCocos {
return new CNNTrainCoCoChecker() return new CNNTrainCoCoChecker()
.addCoCo(new CheckEntryRepetition()) .addCoCo(new CheckEntryRepetition())
.addCoCo(new CheckInteger()) .addCoCo(new CheckInteger())
.addCoCo(new CheckRLParameterOnlyWithLearningMethodSet())
.addCoCo(new CheckFixTargetNetworkRequiresInterval()) .addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination()) .addCoCo(new CheckLearningParameterCombination())
...@@ -39,7 +40,8 @@ public class CNNTrainCocos { ...@@ -39,7 +40,8 @@ public class CNNTrainCocos {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput()) .addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput()) .addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension()); .addCoCo(new CheckOUParameterDimensionEqualsActionDimension())
.addCoCo(new CheckTrainedArchitectureHasVectorAction());
checker.checkAll(configurationSymbol); checker.checkAll(configurationSymbol);
} }
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTIntegerListValue;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.monticore.numberunit._ast.ASTNumberWithUnit;
import de.se_rwth.commons.logging.Log;
public class CheckIntegerList implements CNNTrainASTIntegerListValueCoCo {
@Override
public void check(ASTIntegerListValue node) {
for (ASTNumberWithUnit element : node.getNumberList()) {
Double unitNumber = element.getNumber().get();
if ((unitNumber % 1)!= 0) {
Log.error("0" + ErrorCodes.NOT_INTEGER_CODE +" Value has to be an integer."
, node.get_SourcePositionStart());
}
}
}
}
...@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; ...@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import java.util.Collection; import java.util.Collection;
...@@ -57,7 +58,8 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC ...@@ -57,7 +58,8 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
String ouParameterName) { String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size(); final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionVectorDimension) { if (ouParameterDimension != actionVectorDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have" Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output " + " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionVectorDimension, + outputNameOfTrainedArchitecture + " which is " + actionVectorDimension,
configurationSymbol.getSourcePosition()); configurationSymbol.getSourcePosition());
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckRLParameterOnlyWithLearningMethodSet implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
if (!ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasRLEntry(node)) {
Log.error(ErrorCodes.REQUIRED_PARAMETER_MISSING + " Reinforcement parameter used although learning " +
"method not set", node.get_SourcePositionStart());
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
public class CheckTrainedArchitectureHasVectorAction implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getTrainedArchitecture().isPresent()
&& configurationSymbol.isReinforcementLearningMethod()) {
final NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
if (trainedArchitecture.getOutputs().size() == 1) {
final String actionName = trainedArchitecture.getOutputs().get(0);
final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionName);
if (actionDimensions.size() != 1) {
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Output of actor network must be a vector", configurationSymbol.getSourcePosition());
}
}
}
}
}
...@@ -38,6 +38,7 @@ class ParameterAlgorithmMapping { ...@@ -38,6 +38,7 @@ class ParameterAlgorithmMapping {
ASTBatchSizeEntry.class, ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class, ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class, ASTEvalMetricEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class, ASTNormalizeEntry.class,
ASTNumEpochEntry.class, ASTNumEpochEntry.class,
ASTLossEntry.class, ASTLossEntry.class,
...@@ -47,7 +48,8 @@ class ParameterAlgorithmMapping { ...@@ -47,7 +48,8 @@ class ParameterAlgorithmMapping {
ASTMarginEntry.class, ASTMarginEntry.class,
ASTLabelFormatEntry.class, ASTLabelFormatEntry.class,
ASTRhoEntry.class, ASTRhoEntry.class,
ASTPreprocessingEntry.class ASTPreprocessingEntry.class,
ASTSaveAttentionImage.class
); );
private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList( private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList(
...@@ -129,6 +131,15 @@ class ParameterAlgorithmMapping { ...@@ -129,6 +131,15 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz); || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
} }
boolean isReinforcementLearningParameterOnly(Class<? extends ASTEntry> entryClazz) {
return (GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz))
&& !GENERAL_PARAMETERS.contains(entryClazz)
&& !EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
}
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) { boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz) return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz) || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
......
...@@ -156,32 +156,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -156,32 +156,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override @Override
public void visit(ASTEvalMetricEntry node) { public void visit(ASTEvalMetricEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName()); processMultiParamConfigVisit(node, node.getValue().getName());
ValueSymbol value = new ValueSymbol(); }
if (node.getValue().isPresentAccuracy()){
value.setValue(EvalMetric.ACCURACY); @Override
} public void endVisit(ASTEvalMetricEntry node) {
else if (node.getValue().isPresentCrossEntropy()){ processMultiParamConfigEndVisit(node);
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getValue().isPresentF1()){
value.setValue(EvalMetric.F1);
}
else if (node.getValue().isPresentMae()){
value.setValue(EvalMetric.MAE);
}
else if (node.getValue().isPresentMse()){
value.setValue(EvalMetric.MSE);
}
else if (node.getValue().isPresentRmse()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getValue().isPresentTopKAccuracy()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
} }
@Override @Override
...@@ -335,7 +315,21 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -335,7 +315,21 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.map(n -> n.getNumber().get()) .map(n -> n.getNumber().get())
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private List<Integer> getIntegerListFromValue(ASTIntegerListValue value) {
return value.getNumberList().stream()
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get().intValue())
.collect(Collectors.toList());
}
@Override
public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override @Override
public void visit(ASTLearningMethodEntry node) { public void visit(ASTLearningMethodEntry node) {
...@@ -644,6 +638,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { ...@@ -644,6 +638,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.filter(n -> n.getNumber().isPresent()) .filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get()) .map(n -> n.getNumber().get())
.collect(Collectors.toList()); .collect(Collectors.toList());
} else if (configValue instanceof ASTIntegerListValue) {
return getIntegerListFromValue((ASTIntegerListValue)configValue);
} }
throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass()); throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass());
} }
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
}
...@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain.helper; ...@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain.helper;
public class ConfigEntryNameConstants { public class ConfigEntryNameConstants {
public static final String LEARNING_METHOD = "learning_method"; public static final String LEARNING_METHOD = "learning_method";
public static final String EVAL_METRIC = "eval_metric";
public static final String NUM_EPISODES = "num_episodes"; public static final String NUM_EPISODES = "num_episodes";
public static final String DISCOUNT_FACTOR = "discount_factor"; public static final String DISCOUNT_FACTOR = "discount_factor";
public static final String NUM_MAX_STEPS = "num_max_steps"; public static final String NUM_MAX_STEPS = "num_max_steps";
......
...@@ -188,4 +188,11 @@ public class AllCoCoTest extends AbstractCoCoTest{ ...@@ -188,4 +188,11 @@ public class AllCoCoTest extends AbstractCoCoTest{
"invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification", "invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification",
new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS)); new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS));
} }
@Test
public void testInvalidCheckReinforcementLearningEntryIsSet () {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRLParameterOnlyWithLearningMethodSet()),
"invalid_cocos_tests", "CheckRLParameterOnlyWithLearningMethodSet",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
}
} }
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain.cocos;
import de.monticore.lang.monticar.cnntrain._cocos.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.junit.Test;
public class InterCocoTest extends AbstractCoCoTest {
NNArchitecturerBuilder NNBuilder = new NNArchitecturerBuilder();
@Test
public void testValidTD3ActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testValidDDPGActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "DdpgConfig",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);