Aufgrund einer Wartung wird GitLab am 19.10. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 19.10. between 8:00 and 9:00 am.

Commit 53fe8e9c authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Fix cocos for critic inputs

parent 944c9ed8
......@@ -48,12 +48,6 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
Log.error("Malformed trained architecture");
}
if (trainedArchitecture.getInputs().size() != 2) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ "Number of critic network inputs is wrong. Critic network has two inputs," +
"first needs to be a state input and second needs to be the action input.");
}
final String stateInput = trainedArchitecture.getInputs().get(0);
final String actionOutput = trainedArchitecture.getOutputs().get(0);
final List<Integer> stateDimensions = trainedArchitecture.getDimensions().get(stateInput);
......@@ -66,23 +60,29 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
String criticInput1 = criticNetwork.getInputs().get(0);
String criticInput2 = criticNetwork.getInputs().get(1);
if (criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) {
if (criticNetwork.getInputs().size() != 2) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ "Number of critic network inputs is wrong. Critic network has two inputs," +
"first needs to be a state input and second needs to be the action input.");
}
if (!criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Dimensions of first input of critic architecture must be" +
" equal to state's dimensions "
+ stateDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}"))
+ stateDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"))
+ ".", configurationSymbol.getSourcePosition());
}
if (criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) {
if (!criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Dimensions of second input of critic architecture must be" +
" equal to action's dimensions "
+ actionDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}"))
+ actionDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"))
+ ".", configurationSymbol.getSourcePosition());
}
if (criticNetwork.getRanges().get(criticInput1).equals(stateRange)) {
if (!criticNetwork.getRanges().get(criticInput1).equals(stateRange)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Ranges of first input of critic architecture must be" +
" equal to state's ranges "
......@@ -90,7 +90,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+ ".", configurationSymbol.getSourcePosition());
}
if (criticNetwork.getRanges().get(criticInput2).equals(actionRange)) {
if (!criticNetwork.getRanges().get(criticInput2).equals(actionRange)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Ranges of second input of critic architecture must be" +
" equal to action's ranges "
......@@ -98,7 +98,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+ ".", configurationSymbol.getSourcePosition());
}
if (criticNetwork.getTypes().get(criticInput1).equals(stateType)) {
if (!criticNetwork.getTypes().get(criticInput1).equals(stateType)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Type of first input of critic architecture must be" +
" equal to state's types "
......@@ -106,7 +106,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo
+ ".", configurationSymbol.getSourcePosition());
}
if (criticNetwork.getTypes().get(criticInput2).equals(actionType)) {
if (!criticNetwork.getTypes().get(criticInput2).equals(actionType)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Type of second input of critic architecture must be" +
" equal to action's types "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment