Commit 3889c267 authored by Julian Dierkes's avatar Julian Dierkes
Browse files

Merge local branch to develop

parents 300c11a6 a8f6e66a
Pipeline #205125 passed with stages
in 13 minutes and 47 seconds
......@@ -27,7 +27,7 @@ masterJobWindows:
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
- Windows10_OS
BranchJobLinux:
stage: linux
......
......@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.6-SNAPSHOT</version>
<version>0.3.8-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -132,6 +132,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
......
......@@ -25,6 +25,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
IntegerTupelValue implements ConfigValue = "(" first:IntegerValue "," second:IntegerValue ")";
IntegerListValue implements ConfigValue = "[" number:NumberWithUnit ("," number:NumberWithUnit)* "]";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
......@@ -32,17 +33,24 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
LossWeightsEntry implements ConfigEntry = name:"loss_weights" ":" value:DoubleVectorValue;
EvalMetricValue implements ConfigValue =(accuracy:"accuracy"
| crossEntropy:"cross_entropy"
| f1:"f1"
| mae:"mae"
| mse:"mse"
| rmse:"rmse"
| topKAccuracy:"top_k_accuracy");
EvalMetricEntry implements MultiParamConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
interface EvalMetricValue extends MultiParamValue;
AccuracyEvalMetric implements EvalMetricValue = name:"accuracy";
BleuMetric implements EvalMetricValue = name:"bleu" ("{" params:BleuEntry* "}")?;
CrossEntropyEvalMetric implements EvalMetricValue = name:"cross_entropy";
F1EvalMetric implements EvalMetricValue = name:"f1";
MAEEvalMetric implements EvalMetricValue = name:"mae";
MSEEvalMetric implements EvalMetricValue = name:"mse";
PerplexityEvalMetric implements EvalMetricValue = name:"perplexity";
RMSEEvalMetric implements EvalMetricValue = name:"rmse";
TopKAccuracyEvalMetric implements EvalMetricValue = name:"top_k_accuracy";
interface BleuEntry extends Entry;
ExcludeBleuEntry implements BleuEntry = name:"exclude" ":" value:IntegerListValue;
LRPolicyValue implements ConfigValue =(fixed:"fixed"
| step:"step"
......@@ -132,6 +140,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
ClipWeightsEntry implements RmsPropEntry = name:"clip_weights" ":" value:NumberValue;
RhoEntry implements AdaDeltaEntry,RmsPropEntry,HuberEntry = name:"rho" ":" value:NumberValue;
// Visual attention Extension
SaveAttentionImage implements ConfigEntry = name:"save_attention_image" ":" value:BooleanValue;
// Reinforcement Extensions
interface MultiParamValue extends ConfigValue;
......@@ -209,7 +220,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue;
interface RosEnvironmentEntry extends Entry;
RosEnvironmentValue implements EnvironmentValue = | name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentValue implements EnvironmentValue = name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
......
......@@ -13,6 +13,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
import java.util.Optional;
class ASTConfigurationUtils {
final static ParameterAlgorithmMapping parameterChecker = new ParameterAlgorithmMapping();
static boolean isReinforcementLearning(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e ->
(e instanceof ASTLearningMethodEntry)
......@@ -94,4 +96,16 @@ class ASTConfigurationUtils {
public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
public static boolean hasRLEntry(ASTConfiguration node) {
//return node.getEntriesList().stream()
// .anyMatch(e -> parameterChecker.isReinforcementLearningParameterOnly(e.getClass()));
for (ASTConfigEntry e : node.getEntriesList()) {
boolean b = parameterChecker.isReinforcementLearningParameterOnly(e.getClass());
if (b) {
return true;
}
}
return false;
}
}
......@@ -18,6 +18,7 @@ public class CNNTrainCocos {
return new CNNTrainCoCoChecker()
.addCoCo(new CheckEntryRepetition())
.addCoCo(new CheckInteger())
.addCoCo(new CheckRLParameterOnlyWithLearningMethodSet())
.addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
......@@ -39,7 +40,8 @@ public class CNNTrainCocos {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension())
.addCoCo(new CheckTrainedArchitectureHasVectorAction());
checker.checkAll(configurationSymbol);
}
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTIntegerListValue;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.monticore.numberunit._ast.ASTNumberWithUnit;
import de.se_rwth.commons.logging.Log;
public class CheckIntegerList implements CNNTrainASTIntegerListValueCoCo {
@Override
public void check(ASTIntegerListValue node) {
for (ASTNumberWithUnit element : node.getNumberList()) {
Double unitNumber = element.getNumber().get();
if ((unitNumber % 1)!= 0) {
Log.error("0" + ErrorCodes.NOT_INTEGER_CODE +" Value has to be an integer."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
......@@ -57,7 +58,8 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionVectorDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionVectorDimension,
configurationSymbol.getSourcePosition());
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckRLParameterOnlyWithLearningMethodSet implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
if (!ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasRLEntry(node)) {
Log.error(ErrorCodes.REQUIRED_PARAMETER_MISSING + " Reinforcement parameter used although learning " +
"method not set", node.get_SourcePositionStart());
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
public class CheckTrainedArchitectureHasVectorAction implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getTrainedArchitecture().isPresent()
&& configurationSymbol.isReinforcementLearningMethod()) {
final NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
if (trainedArchitecture.getOutputs().size() == 1) {
final String actionName = trainedArchitecture.getOutputs().get(0);
final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionName);
if (actionDimensions.size() != 1) {
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Output of actor network must be a vector", configurationSymbol.getSourcePosition());
}
}
}
}
}
......@@ -38,6 +38,7 @@ class ParameterAlgorithmMapping {
ASTBatchSizeEntry.class,
ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class,
ASTExcludeBleuEntry.class,
ASTNormalizeEntry.class,
ASTNumEpochEntry.class,
ASTLossEntry.class,
......@@ -47,7 +48,8 @@ class ParameterAlgorithmMapping {
ASTMarginEntry.class,
ASTLabelFormatEntry.class,
ASTRhoEntry.class,
ASTPreprocessingEntry.class
ASTPreprocessingEntry.class,
ASTSaveAttentionImage.class
);
private static final List<Class> GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList(
......@@ -129,6 +131,15 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
boolean isReinforcementLearningParameterOnly(Class<? extends ASTEntry> entryClazz) {
return (GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz))
&& !GENERAL_PARAMETERS.contains(entryClazz)
&& !EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
}
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz)
......
......@@ -156,32 +156,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
@Override
public void visit(ASTEvalMetricEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentAccuracy()){
value.setValue(EvalMetric.ACCURACY);
}
else if (node.getValue().isPresentCrossEntropy()){
value.setValue(EvalMetric.CROSS_ENTROPY);
}
else if (node.getValue().isPresentF1()){
value.setValue(EvalMetric.F1);
}
else if (node.getValue().isPresentMae()){
value.setValue(EvalMetric.MAE);
}
else if (node.getValue().isPresentMse()){
value.setValue(EvalMetric.MSE);
}
else if (node.getValue().isPresentRmse()){
value.setValue(EvalMetric.RMSE);
}
else if (node.getValue().isPresentTopKAccuracy()){
value.setValue(EvalMetric.TOP_K_ACCURACY);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
processMultiParamConfigVisit(node, node.getValue().getName());
}
@Override
public void endVisit(ASTEvalMetricEntry node) {
processMultiParamConfigEndVisit(node);
}
@Override
......@@ -335,7 +315,21 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
}
private List<Integer> getIntegerListFromValue(ASTIntegerListValue value) {
return value.getNumberList().stream()
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get().intValue())
.collect(Collectors.toList());
}
@Override
public void endVisit(ASTSaveAttentionImage node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForBoolean(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTLearningMethodEntry node) {
......@@ -644,6 +638,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
.filter(n -> n.getNumber().isPresent())
.map(n -> n.getNumber().get())
.collect(Collectors.toList());
} else if (configValue instanceof ASTIntegerListValue) {
return getIntegerListFromValue((ASTIntegerListValue)configValue);
}
throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass());
}
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum EvalMetric {
ACCURACY{
@Override
public String toString() {
return "accuracy";
}
},
CROSS_ENTROPY{
@Override
public String toString() {
return "crossEntropy";
}
},
F1{
@Override
public String toString() {
return "f1";
}
},
MAE{
@Override
public String toString() {
return "mae";
}
},
MSE{
@Override
public String toString() {
return "mse";
}
},
RMSE{
@Override
public String toString() {
return "rmse";
}
},
TOP_K_ACCURACY{
@Override
public String toString() {
return "topKAccuracy";
}
}
}
......@@ -9,6 +9,7 @@ package de.monticore.lang.monticar.cnntrain.helper;
public class ConfigEntryNameConstants {
public static final String LEARNING_METHOD = "learning_method";
public static final String EVAL_METRIC = "eval_metric";
public static final String NUM_EPISODES = "num_episodes";
public static final String DISCOUNT_FACTOR = "discount_factor";
public static final String NUM_MAX_STEPS = "num_max_steps";
......
......@@ -188,4 +188,11 @@ public class AllCoCoTest extends AbstractCoCoTest{
"invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification",
new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS));
}
@Test
public void testInvalidCheckReinforcementLearningEntryIsSet () {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRLParameterOnlyWithLearningMethodSet()),
"invalid_cocos_tests", "CheckRLParameterOnlyWithLearningMethodSet",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain.cocos;
import de.monticore.lang.monticar.cnntrain._cocos.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.junit.Test;
public class InterCocoTest extends AbstractCoCoTest {
NNArchitecturerBuilder NNBuilder = new NNArchitecturerBuilder();
@Test
public void testValidTD3ActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testValidDDPGActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "DdpgConfig",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testInvalidTrainingArchitectureWithTwoInputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneInput();
NNArchitectureSymbol nnWithTwoInputs = NNBuilder.getTrainedArchitectureWithTwoInputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoInputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidTrainingArchitectureWithTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneOutput();
NNArchitectureSymbol nnWithTwoOutputs = NNBuilder.getTrainedArchitectureWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoOutputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension1() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim1", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(3, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension2() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim2", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(2, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension3() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim3", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidCriticHasNotOneDimensionalOutput() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithThreeDimensionalOutput();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateDimensions() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentDimensions = NNBuilder.getCriticWithDifferentStateDimensions();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",