Commit 11efa54b authored by Nicola Gatto

Add td3 parameters policy_noise, noise_clip, policy_delay

parent 73e80dc0
Pipeline #158873 passed with stages
in 7 minutes and 37 seconds
......@@ -215,8 +215,13 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
// DDPG exclusive parameters
// DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
// TD3 exclusive parameters
// policy_noise and noise_clip accept a NumberValue; policy_delay only accepts an IntegerValue.
// NOTE(review): presumably these correspond to TD3's target-policy smoothing noise, the bound
// that noise is clipped to, and the interval between actor updates — confirm against the trainer.
PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue;
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
}
\ No newline at end of file
......@@ -108,7 +108,17 @@ class ParameterAlgorithmMapping {
ASTStrategyOUSigma.class
);
private static final List<Class> EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(EXCLUSIVE_DDPG_PARAMETERS);
/**
 * AST entry classes listed as TD3 parameters: critic network, soft target update rate,
 * critic optimizer, the three OU strategy parameters (mu, theta, sigma), and the
 * policy noise, noise clip and policy delay entries.
 *
 * NOTE(review): the first six classes look like a hand copy of EXCLUSIVE_DDPG_PARAMETERS;
 * confirm the two lists are intended to be kept in sync manually.
 */
private static final List<Class> EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(
ASTCriticNetworkEntry.class,
ASTSoftTargetUpdateRateEntry.class,
ASTCriticOptimizerEntry.class,
ASTStrategyOUMu.class,
ASTStrategyOUTheta.class,
ASTStrategyOUSigma.class,
ASTPolicyNoiseEntry.class,
ASTNoiseClipEntry.class,
ASTPolicyDelayEntry.class
);
ParameterAlgorithmMapping() {
......@@ -118,7 +128,8 @@ class ParameterAlgorithmMapping {
return GENERAL_PARAMETERS.contains(entryClazz)
|| GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz);
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
boolean isSupervisedLearningParameter(Class<? extends ASTEntry> entryClazz) {
......
......@@ -516,6 +516,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
/**
 * Records the {@code policy_noise} entry in the training configuration.
 * The node's value is converted with {@code getValueSymbolForDouble} before it is
 * attached to a new {@link EntrySymbol} carrying the node's name.
 *
 * @param node the parsed policy noise entry
 */
@Override
public void visit(ASTPolicyNoiseEntry node) {
    final String key = node.getName();
    final EntrySymbol policyNoise = new EntrySymbol(key);
    policyNoise.setValue(getValueSymbolForDouble(node.getValue()));
    // Register the symbol via addToScopeAndLinkWithNode, then store it under the entry name.
    addToScopeAndLinkWithNode(policyNoise, node);
    configuration.getEntryMap().put(key, policyNoise);
}
/**
 * Records the {@code noise_clip} entry in the training configuration.
 * The node's value is converted with {@code getValueSymbolForDouble} before it is
 * attached to a new {@link EntrySymbol} carrying the node's name.
 *
 * @param node the parsed noise clip entry
 */
@Override
public void visit(ASTNoiseClipEntry node) {
    final String key = node.getName();
    final EntrySymbol noiseClip = new EntrySymbol(key);
    noiseClip.setValue(getValueSymbolForDouble(node.getValue()));
    // Register the symbol via addToScopeAndLinkWithNode, then store it under the entry name.
    addToScopeAndLinkWithNode(noiseClip, node);
    configuration.getEntryMap().put(key, noiseClip);
}
/**
 * Records the {@code policy_delay} entry in the training configuration.
 * Unlike the two noise entries, the value is converted with
 * {@code getValueSymbolForInteger} before it is attached to the new {@link EntrySymbol}.
 *
 * @param node the parsed policy delay entry
 */
@Override
public void visit(ASTPolicyDelayEntry node) {
    final String key = node.getName();
    final EntrySymbol policyDelay = new EntrySymbol(key);
    policyDelay.setValue(getValueSymbolForInteger(node.getValue()));
    // Register the symbol via addToScopeAndLinkWithNode, then store it under the entry name.
    addToScopeAndLinkWithNode(policyDelay, node);
    configuration.getEntryMap().put(key, policyDelay);
}
private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) {
EntrySymbol entry = new EntrySymbol(node.getName());
MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol();
......
......@@ -4,6 +4,10 @@ configuration TD3Config {
critic : path.to.component
environment : gym { name:"CartPole-v1" }
soft_target_update_rate: 0.001
policy_noise: 0.2
noise_clip: 0.5
policy_delay: 2
actor_optimizer : adam{
learning_rate : 0.0001
learning_rate_minimum : 0.00005
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment