Commit 2df512f2 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Prepare ddpg algorithm

parent 9de665e2
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.0-SNAPSHOT</version>
<version>0.3.1-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -98,6 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface MultiParamValue extends ConfigValue;
LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue;
RLAlgorithmEntry implements ConfigEntry = name:"rl_algorithm" ":" value:RLAlgorithmValue;
NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue;
DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue;
NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue;
......@@ -109,11 +110,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue;
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue;
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
ComponentNameValue implements ConfigValue = Name ("."Name)*;
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm");
interface MultiParamConfigEntry extends ConfigEntry;
// Replay Memory
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
......
......@@ -33,7 +33,8 @@ public class CNNTrainCocos {
.addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction());
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
.addCoCo(new CheckDdpgRequiresCriticNetwork());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTCriticNetworkEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
boolean isDdpg = node.getEntriesList().stream()
.anyMatch(e -> e instanceof ASTRLAlgorithmEntry
&& ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
boolean hasCriticEntry = node.getEntriesList().stream()
.anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
&& !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
if (isDdpg && !hasCriticEntry) {
ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream()
.filter(e -> e instanceof ASTRLAlgorithmEntry)
.map(e -> (ASTRLAlgorithmEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critc" +
" network entry", algorithmEntry.get_SourcePositionStart());
}
}
}
\ No newline at end of file
......@@ -61,6 +61,8 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
);
private final static List<Class> ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTRLAlgorithmEntry.class,
ASTCriticNetworkEntry.class,
ASTOptimizerEntry.class,
ASTRewardFunctionEntry.class,
ASTMinimumLearningRateEntry.class,
......
......@@ -277,6 +277,19 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
return value;
}
private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) {
ValueSymbol value = new ValueSymbol();
List<String> valueAsList = astComponentNameValue.getNameList();
value.setValue(valueAsList);
return value;
}
private ValueSymbol getValueSymbolForComponentNameAsString(ASTComponentNameValue astComponentNameValue) {
ValueSymbol value = new ValueSymbol();
value.setValue(String.join(".", astComponentNameValue.getNameList()));
return value;
}
private String getStringFromStringValue(ASTStringValue value) {
return value.getStringLiteral().getValue();
}
......@@ -310,6 +323,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTRLAlgorithmEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
ValueSymbol value = new ValueSymbol();
if (node.getValue().isPresentDdpg()) {
value.setValue(RLAlgorithm.DDPG);
} else {
value.setValue(RLAlgorithm.DQN);
}
entry.setValue(value);
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTNumEpisodesEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
......@@ -390,6 +419,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTCriticNetworkEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
entry.setValue(getValueSymbolForComponentNameAsString(node.getValue()));
addToScopeAndLinkWithNode(entry, node);
configuration.getEntryMap().put(node.getName(), entry);
}
@Override
public void visit(ASTReplayMemoryEntry node) {
processMultiParamConfigVisit(node, node.getValue().getName());
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._symboltable;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.*;
......@@ -30,12 +31,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
private OptimizerSymbol optimizer;
private RewardFunctionSymbol rlRewardFunctionSymbol;
private TrainedArchitecture trainedArchitecture;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
public ConfigurationSymbol() {
super("", KIND);
rlRewardFunctionSymbol = null;
trainedArchitecture = null;
}
public OptimizerSymbol getOptimizer() {
......@@ -54,6 +57,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(this.rlRewardFunctionSymbol);
}
public Optional<TrainedArchitecture> getTrainedArchitecture() {
return Optional.ofNullable(trainedArchitecture);
}
public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) {
this.trainedArchitecture = trainedArchitecture;
}
public Map<String, EntrySymbol> getEntryMap() {
return entryMap;
}
......@@ -66,4 +77,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return this.entryMap.containsKey("learning_method")
? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED;
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
public enum RLAlgorithm {
DQN {
@Override
public String toString() {
return "dqn";
}
},
DDPG {
@Override
public String toString() {
return "ddpg";
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain.annotations;
import java.util.Optional;
public class Range {
private final boolean lowerLimitIsInfinity;
private final boolean upperLimitIsInfinity;
private final Double lowerLimit;
private final Double upperLimit;
private Range(boolean lowerLimitIsInfinity, boolean upperLimitIsInfinity, Double lowerLimit, Double upperLimit) {
this.lowerLimitIsInfinity = lowerLimitIsInfinity;
this.upperLimitIsInfinity = upperLimitIsInfinity;
this.lowerLimit = lowerLimit;
this.upperLimit = upperLimit;
}
public Optional<Double> getLowerLimit() {
return Optional.ofNullable(lowerLimit);
}
public Optional<Double> getUpperLimit() {
return Optional.ofNullable(upperLimit);
}
public boolean isLowerLimitInfinity() {
return this.lowerLimitIsInfinity;
}
public boolean isUpperLimitInfinity() {
return this.upperLimitIsInfinity;
}
public static Range withLimits(double lowerLimit, double upperLimit) {
return new Range(false, false, lowerLimit, upperLimit);
}
public static Range withInfinityLimits() {
return new Range(true, true, null, null);
}
public static Range withUpperInfinityLimit(double lowerLimit) {
return new Range(false, true, lowerLimit, null);
}
public static Range withLowerInfinityLimit(double upperLimit) {
return new Range(true, false, null, upperLimit);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain.annotations;
import java.util.List;
import java.util.Map;
public interface TrainedArchitecture {
public List<String> getInputs();
public List<String> getOutputs();
public Map<String, List<Integer>> getDimensions();
public Map<String, Range> getRanges();
public Map<String, String> getTypes();
}
\ No newline at end of file
......@@ -41,6 +41,7 @@ public class AllCoCoTest extends AbstractCoCoTest{
checkValid("valid_tests","FullConfig2");
checkValid("valid_tests", "ReinforcementConfig");
checkValid("valid_tests", "ReinforcementConfig2");
checkValid("valid_tests", "DdpgConfig");
}
@Test
......
configuration DdpgConfig {
learning_method : reinforcement
rl_algorithm : ddpg-algorithm
critic : path.to.component
environment : gym { name:"CartPole-v1" }
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment