Commit 944c9ed8 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add Ornstein Uhlenbeck dimension check

parent d3a10207
Pipeline #163024 failed with stages
......@@ -51,7 +51,8 @@ public class CNNTrainCocos {
public static void checkTrainedArchitectureCoCos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput());
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
checker.checkAll(configurationSymbol);
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
import java.util.List;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
/**
*
*/
public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(final ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getTrainedArchitecture().isPresent()
&& configurationSymbol.isReinforcementLearningMethod()
&& configurationSymbol.getEntry(STRATEGY).getValue().getValue().equals(STRATEGY_OU)) {
final MultiParamValueSymbol strategyParameters
= (MultiParamValueSymbol)configurationSymbol.getEntry(STRATEGY).getValue();
final NNArchitectureSymbol architectureSymbol = configurationSymbol.getTrainedArchitecture().get();
final String outputNameOfTrainedArchitecture = architectureSymbol.getOutputs().get(0);
final int actionDimension = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture).size();
if (strategyParameters.hasParameter(STRATEGY_OU_MU)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_MU);
}
if (strategyParameters.hasParameter(STRATEGY_OU_SIGMA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_SIGMA);
}
if (strategyParameters.hasParameter(STRATEGY_OU_THETA)) {
logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
actionDimension, STRATEGY_OU_THETA);
}
}
}
private void logIfDimensionIsUnequal(ConfigurationSymbol configurationSymbol,
MultiParamValueSymbol strategyParameters,
String outputNameOfTrainedArchitecture,
int actionDimension,
String ouParameterName) {
final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
if (ouParameterDimension != actionDimension) {
Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+ " the same dimensions as the action dimension of output "
+ outputNameOfTrainedArchitecture + " which is " + actionDimension,
configurationSymbol.getSourcePosition());
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
......
......@@ -44,6 +44,10 @@ public class MultiParamValueSymbol extends ValueSymbol {
return parameters.get(parameterName);
}
public boolean hasParameter(final String parameterName) {
return parameters.containsKey(parameterName);
}
public void addParameter(final String parameterName, final Object value) {
parameters.put(parameterName, value);
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain.helper;
/**
......@@ -32,6 +52,9 @@ public class ConfigEntryNameConstants {
public static final String STRATEGY = "strategy";
public static final String STRATEGY_OU = "ornstein_uhlenbeck";
public static final String STRATEGY_OU_MU = "mu";
public static final String STRATEGY_OU_THETA = "theta";
public static final String STRATEGY_OU_SIGMA = "sigma";
public static final String STRATEGY_GAUSSIAN = "gaussian";
public static final String STRATEGY_EPSGREEDY = "epsgreedy";
public static final String STRATEGY_EPSDECAY = "epsdecay";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment