Commit dc20415c authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Fix dimension check

parent 18373198
Pipeline #189094 passed with stages
in 8 minutes and 50 seconds
......@@ -4,13 +4,11 @@
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
......@@ -31,21 +29,23 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
= (MultiParamValueSymbol)configurationSymbol.getEntry(STRATEGY).getValue();
final NNArchitectureSymbol architectureSymbol = configurationSymbol.getTrainedArchitecture().get();
final String outputNameOfTrainedArchitecture = architectureSymbol.getOutputs().get(0);
final int actionDimension = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture).size();
final List<Integer> actionDimensions = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture);
assert actionDimensions.size() == 1: "Invalid action: DDPG Actor model requires action to be a vector";
final int vectorSize = actionDimensions.get(0);
if (strategyParameters.hasParameter(STRATEGY_OU_MU)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_MU);
vectorSize, STRATEGY_OU_MU);
}
if (strategyParameters.hasParameter(STRATEGY_OU_SIGMA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_SIGMA);
vectorSize, STRATEGY_OU_SIGMA);
}
if (strategyParameters.hasParameter(STRATEGY_OU_THETA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_THETA);
vectorSize, STRATEGY_OU_THETA);
}
}
}
......@@ -53,13 +53,13 @@ public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainC
private void logIfDimensionIsUnequal(ConfigurationSymbol configurationSymbol,
MultiParamValueSymbol strategyParameters,
String outputNameOfTrainedArchitecture,
int actionDimension,
int actionVectorDimension,
String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionDimension) {
if (ouParameterDimension != actionVectorDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionDimension,
+ outputNameOfTrainedArchitecture + " which is " + actionVectorDimension,
configurationSymbol.getSourcePosition());
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment