Commit 04a67e21 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Cocos fix

parent 33adc56c
......@@ -18,7 +18,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.6-SNAPSHOT</version>
<version>0.3.7-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -132,6 +132,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
......
......@@ -208,7 +208,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
GymEnvironmentNameEntry implements GymEnvironmentEntry = name:"name" ":" value:StringValue;
interface RosEnvironmentEntry extends Entry;
RosEnvironmentValue implements EnvironmentValue = | name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentValue implements EnvironmentValue = name:"ros_interface" ("{" params:RosEnvironmentEntry* "}")?;
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
......@@ -230,4 +230,4 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue;
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
}
}
\ No newline at end of file
......@@ -13,6 +13,8 @@ import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstant
import java.util.Optional;
class ASTConfigurationUtils {
final static ParameterAlgorithmMapping parameterChecker = new ParameterAlgorithmMapping();
static boolean isReinforcementLearning(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e ->
(e instanceof ASTLearningMethodEntry)
......@@ -94,4 +96,16 @@ class ASTConfigurationUtils {
public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
public static boolean hasRLEntry(ASTConfiguration node) {
//return node.getEntriesList().stream()
// .anyMatch(e -> parameterChecker.isReinforcementLearningParameterOnly(e.getClass()));
for (ASTConfigEntry e : node.getEntriesList()) {
boolean b = parameterChecker.isReinforcementLearningParameterOnly(e.getClass());
if (b) {
return true;
}
}
return false;
}
}
......@@ -18,6 +18,7 @@ public class CNNTrainCocos {
return new CNNTrainCoCoChecker()
.addCoCo(new CheckEntryRepetition())
.addCoCo(new CheckInteger())
.addCoCo(new CheckRLParameterOnlyWithLearningMethodSet())
.addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
......@@ -39,7 +40,8 @@ public class CNNTrainCocos {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension())
.addCoCo(new CheckTrainedArchitectureHasVectorAction());
checker.checkAll(configurationSymbol);
}
......
......@@ -9,9 +9,6 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
/**
*
*/
public interface CNNTrainConfigurationSymbolCoCo {
void check(ConfigurationSymbol configurationSymbol);
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
/**
*
*/
public interface CNNTrainConfigurationSymbolCoCo {
void check(ConfigurationSymbol configurationSymbol);
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
/**
*
*/
public interface CNNTrainConfigurationSymbolCoCo {
void check(ConfigurationSymbol configurationSymbol);
}
\ No newline at end of file
......@@ -4,7 +4,6 @@
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
......@@ -33,4 +32,4 @@ public class CheckActorCriticRequiresCriticNetwork implements CNNTrainASTConfigu
" network entry", algorithmEntry.get_SourcePositionStart());
}
}
}
}
\ No newline at end of file
......@@ -102,4 +102,4 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
}
}
}
}
\ No newline at end of file
}
......@@ -4,13 +4,12 @@
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
......@@ -31,21 +30,23 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
= (MultiParamValueSymbol)configurationSymbol.getEntry(STRATEGY).getValue();
final NNArchitectureSymbol architectureSymbol = configurationSymbol.getTrainedArchitecture().get();
final String outputNameOfTrainedArchitecture = architectureSymbol.getOutputs().get(0);
final int actionDimension = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture).size();
final List<Integer> actionDimensions = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture);
assert actionDimensions.size() == 1: "Invalid action: DDPG Actor model requires action to be a vector";
final int vectorSize = actionDimensions.get(0);
if (strategyParameters.hasParameter(STRATEGY_OU_MU)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_MU);
vectorSize, STRATEGY_OU_MU);
}
if (strategyParameters.hasParameter(STRATEGY_OU_SIGMA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_SIGMA);
vectorSize, STRATEGY_OU_SIGMA);
}
if (strategyParameters.hasParameter(STRATEGY_OU_THETA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_THETA);
vectorSize, STRATEGY_OU_THETA);
}
}
}
......@@ -53,13 +54,14 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
private void logIfDimensionIsUnequal(ConfigurationSymbol configurationSymbol,
MultiParamValueSymbol strategyParameters,
String outputNameOfTrainedArchitecture,
int actionDimension,
int actionVectorDimension,
String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
if (ouParameterDimension != actionVectorDimension) {
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionDimension,
+ outputNameOfTrainedArchitecture + " which is " + actionVectorDimension,
configurationSymbol.getSourcePosition());
}
}
......
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckRLParameterOnlyWithLearningMethodSet implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
if (!ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasRLEntry(node)) {
Log.error(ErrorCodes.REQUIRED_PARAMETER_MISSING + " Reinforcement parameter used although learning " +
"method not set", node.get_SourcePositionStart());
}
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
public class CheckTrainedArchitectureHasVectorAction implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getTrainedArchitecture().isPresent()
&& configurationSymbol.isReinforcementLearningMethod()) {
final NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
if (trainedArchitecture.getOutputs().size() == 1) {
final String actionName = trainedArchitecture.getOutputs().get(0);
final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionName);
if (actionDimensions.size() != 1) {
Log.error("0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ " Output of actor network must be a vector", configurationSymbol.getSourcePosition());
}
}
}
}
}
......@@ -32,4 +32,4 @@ public class CheckTrainedRlNetworkHasExactlyOneInput implements CNNTrainConfigur
}
}
}
}
}
\ No newline at end of file
......@@ -122,6 +122,15 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
boolean isReinforcementLearningParameterOnly(Class<? extends ASTEntry> entryClazz) {
return (GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz))
&& !GENERAL_PARAMETERS.contains(entryClazz)
&& !EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
}
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz);
......
......@@ -7,9 +7,6 @@
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain.helper;
/**
*
*/
public class ConfigEntryNameConstants {
public static final String LEARNING_METHOD = "learning_method";
public static final String NUM_EPISODES = "num_episodes";
......
......@@ -188,4 +188,11 @@ public class AllCoCoTest extends AbstractCoCoTest{
"invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification",
new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS));
}
@Test
public void testInvalidCheckReinforcementLearningEntryIsSet () {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRLParameterOnlyWithLearningMethodSet()),
"invalid_cocos_tests", "CheckRLParameterOnlyWithLearningMethodSet",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
}
}
/**
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnntrain.cocos;
import de.monticore.lang.monticar.cnntrain._cocos.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.junit.Test;
public class InterCocoTest extends AbstractCoCoTest {
NNArchitecturerBuilder NNBuilder = new NNArchitecturerBuilder();
@Test
public void testValidTD3ActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testValidDDPGActorCritic() {
// given
final NNArchitectureSymbol validActor = NNBuilder.getValidTrainedArchitecture();
final NNArchitectureSymbol validCritic = NNBuilder.getValidCriticArchitecture();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "DdpgConfig",
validActor, validCritic);
// when
checkValidTrainedArchitecture(configurationSymbol);
checkValidCriticArchitecture(configurationSymbol);
}
@Test
public void testInvalidTrainingArchitectureWithTwoInputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneInput();
NNArchitectureSymbol nnWithTwoInputs = NNBuilder.getTrainedArchitectureWithTwoInputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoInputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidTrainingArchitectureWithTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckTrainedRlNetworkHasExactlyOneOutput();
NNArchitectureSymbol nnWithTwoOutputs = NNBuilder.getTrainedArchitectureWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests",
"TD3Config", nnWithTwoOutputs, NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension1() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim1", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(3, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension2() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim2", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(2, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidActionDimensionUnequalToOUParameterDimension3() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckOUParameterDimensionEqualsActionDimension();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("invalid_cocos_tests",
"UnequalOUDim3", NNBuilder.getValidTrainedArchitecture(), NNBuilder.getValidCriticArchitecture());
// when
checkInvalidTrainedArchitecture(
configurationSymbol, cocoUUT, new ExpectedErrorInfo(1, ErrorCodes.TRAINED_ARCHITECTURE_ERROR));
}
@Test
public void testInvalidCriticHasNotOneDimensionalOutput() {
//given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithThreeDimensionalOutput();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasTwoOutputs() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkHasExactlyAOneDimensionalOutput();
NNArchitectureSymbol criticWithThreeDimensionalOutput = NNBuilder.getCriticWithTwoOutputs();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithThreeDimensionalOutput);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateDimensions() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentDimensions = NNBuilder.getCriticWithDifferentStateDimensions();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentDimensions);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionDimensions() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentDimensions = NNBuilder.getCriticWithDifferentActionDimensions();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentDimensions);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateTypes() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentStateTypes = NNBuilder.getCriticWithDifferentStateTypes();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentStateTypes);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionTypes() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentActionTypes = NNBuilder.getCriticWithDifferentActionTypes();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentActionTypes);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentStateRanges() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentStateRanges = NNBuilder.getCriticWithDifferentStateRanges();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentStateRanges);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));
}
@Test
public void testInvalidCriticHasDifferentActionRanges() {
// given
CNNTrainConfigurationSymbolCoCo cocoUUT = new CheckCriticNetworkInputs();
NNArchitectureSymbol criticWithDifferentActionRanges = NNBuilder.getCriticWithDifferentActionRanges();
ConfigurationSymbol configurationSymbol = getConfigurationSymbolFrom("valid_tests", "TD3Config",
NNBuilder.getValidTrainedArchitecture(), criticWithDifferentActionRanges);
// when
checkInvalidCriticArchitecture(configurationSymbol, cocoUUT,
new ExpectedErrorInfo(1, ErrorCodes.CRITIC_NETWORK_ERROR));