Commit cdc5657a authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add TD3 reinforcement learning parameter

parent b5b9cdaa
Pipeline #158410 passed with stages
in 8 minutes and 46 seconds
......@@ -145,7 +145,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm");
interface MultiParamConfigEntry extends ConfigEntry;
......
......@@ -41,8 +41,16 @@ class ASTConfigurationUtils {
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
}
static boolean isTd3Algorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration)
&& configuration.getEntriesList().stream().anyMatch(
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentTdThree());
}
static boolean isDqnAlgorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration);
return isReinforcementLearning(configuration)
&& !isDdpgAlgorithm(configuration)
&& !isTd3Algorithm(configuration);
}
static boolean hasEntry(final ASTConfiguration configuration, final Class<? extends ASTConfigEntry> entryClazz) {
......@@ -84,4 +92,18 @@ class ASTConfigurationUtils {
}
return false;
}
static boolean isActorCriticAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
static boolean hasCriticEntry(final ASTConfiguration node) {
return node.getEntriesList().stream()
.anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
&& !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
}
public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
}
......@@ -34,7 +34,7 @@ public class CNNTrainCocos {
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
.addCoCo(new CheckDdpgRequiresCriticNetwork())
.addCoCo(new CheckActorCriticRequiresCriticNetwork())
.addCoCo(new CheckRlAlgorithmParameter())
.addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy())
.addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy())
......
......@@ -35,7 +35,7 @@ public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrai
@Override
public void check(ASTConfiguration node) {
if (ASTConfigurationUtils.isDdpgAlgorithm(node)
if (ASTConfigurationUtils.isContinuousAlgorithm(node)
&& ASTConfigurationUtils.hasStrategy(node)
&& ASTConfigurationUtils.getStrategyMethod(node).isPresent()) {
final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get();
......
......@@ -20,6 +20,7 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
......@@ -29,61 +30,58 @@ import de.se_rwth.commons.logging.Log;
public class CheckRlAlgorithmParameter implements CNNTrainASTEntryCoCo {
private final ParameterAlgorithmMapping parameterAlgorithmMapping;
boolean algorithmKnown;
private boolean isDqn = true;
private boolean isDdpg = true;
private boolean isTd3 = true;
RLAlgorithm algorithm;
public CheckRlAlgorithmParameter() {
parameterAlgorithmMapping = new ParameterAlgorithmMapping();
algorithmKnown = false;
algorithm = null;
}
@Override
public void check(ASTEntry node) {
final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass());
final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass());
if (!parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass())) {
return;
}
if (node instanceof ASTRLAlgorithmEntry) {
ASTRLAlgorithmEntry algorithmEntry = (ASTRLAlgorithmEntry)node;
if (algorithmEntry.getValue().isPresentDdpg()) {
setAlgorithmToDdpg(node);
logWrongParameterIfCheckFails(isDdpg, node);
isTd3 = false;
isDqn = false;
} else if(algorithmEntry.getValue().isPresentTdThree()) {
logWrongParameterIfCheckFails(isTd3, node);
isDdpg = false;
isDqn = false;
} else {
setAlgorithmToDqn(node);
logWrongParameterIfCheckFails(isDqn, node);
isDdpg = false;
isTd3 = false;
}
} else {
if (isDdpgParameter && !isDqnParameter) {
setAlgorithmToDdpg(node);
} else if (!isDdpgParameter && isDqnParameter) {
setAlgorithmToDqn(node);
final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass());
final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass());
final boolean isTd3Parameter = parameterAlgorithmMapping.isTd3Parameter(node.getClass());
if (!isDdpgParameter) {
isDdpg = false;
}
if (!isTd3Parameter) {
isTd3 = false;
}
if (!isDqnParameter) {
isDqn = false;
}
}
logWrongParameterIfCheckFails(isDqn || isTd3 || isDdpg, node);
}
private void logErrorIfAlgorithmIsDqn(final ASTEntry node) {
if (algorithmKnown && algorithm.equals(RLAlgorithm.DQN)) {
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER
+ " DDPG Parameter " + node.getName() + " used but algorithm is " + algorithm + ".",
node.get_SourcePositionStart());
}
}
private void setAlgorithmToDdpg(final ASTEntry node) {
logErrorIfAlgorithmIsDqn(node);
algorithmKnown = true;
algorithm = RLAlgorithm.DDPG;
}
private void setAlgorithmToDqn(final ASTEntry node) {
logErrorIfAlgorithmIsDdpg(node);
algorithmKnown = true;
algorithm = RLAlgorithm.DQN;
}
private void logErrorIfAlgorithmIsDdpg(final ASTEntry node) {
if (algorithmKnown && algorithm.equals(RLAlgorithm.DDPG)) {
private void logWrongParameterIfCheckFails(final boolean condition, final ASTEntry node) {
if (!condition) {
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER
+ " DQN Parameter " + node.getName() + " used but algorithm is " + algorithm + ".",
+ "Parameter " + node.getName() + " used but parameter is not for chosen algorithm.",
node.get_SourcePositionStart());
}
}
......
......@@ -108,6 +108,8 @@ class ParameterAlgorithmMapping {
ASTStrategyOUSigma.class
);
private static final List<Class> EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(EXCLUSIVE_DDPG_PARAMETERS);
ParameterAlgorithmMapping() {
}
......@@ -136,12 +138,19 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz);
}
boolean isTd3Parameter(Class<? extends ASTEntry> entryClazz) {
return GENERAL_PARAMETERS.contains(entryClazz)
|| GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
List<Class> getAllReinforcementParameters() {
return ImmutableList.<Class> builder()
.addAll(GENERAL_PARAMETERS)
.addAll(GENERAL_REINFORCEMENT_PARAMETERS)
.addAll(EXCLUSIVE_DQN_PARAMETERS)
.addAll(EXCLUSIVE_DDPG_PARAMETERS)
.addAll(EXCLUSIVE_TD3_PARAMETERS)
.build();
}
......
......@@ -351,6 +351,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
if (node.getValue().isPresentDdpg()) {
value.setValue(RLAlgorithm.DDPG);
} else if(node.getValue().isPresentTdThree()) {
value.setValue(RLAlgorithm.TD3);
} else {
value.setValue(RLAlgorithm.DQN);
}
......
......@@ -32,5 +32,11 @@ public enum RLAlgorithm {
public String toString() {
return "ddpg";
}
},
TD3 {
@Override
public String toString() {
return "td3";
}
}
}
}
\ No newline at end of file
configuration TD3Config {
learning_method : reinforcement
rl_algorithm : ddpg-algorithm
rl_algorithm : td3-algorithm
critic : path.to.component
environment : gym { name:"CartPole-v1" }
soft_target_update_rate: 0.001
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment