Commit cd2644d1 authored by Nicola Gatto

Implement the TD3 (Twin Delayed DDPG) algorithm

The new TwinDelayedDdpgAgent extends DdpgAgent with a second critic trained on clipped double-Q targets, target policy smoothing via clipped Gaussian noise, and delayed actor and target-network updates. The generator reads the corresponding CNNTrain configuration entries policy_noise, noise_clip and policy_delay.

parent 7184ebe9
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
-<version>0.2.2-SNAPSHOT</version>
+<version>0.2.3-SNAPSHOT-NG</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
-<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
+<CNNTrain.version>0.3.5-SNAPSHOT-NG</CNNTrain.version>
<CNNArch2X.version>0.0.2-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
......
......@@ -119,7 +119,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
-if (rlAlgorithm.equals(RLAlgorithm.DDPG)) {
+if (rlAlgorithm.equals(RLAlgorithm.DDPG)
+        || rlAlgorithm.equals(RLAlgorithm.TD3)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
......
......@@ -29,6 +29,9 @@ public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_START_TRAINING_AT = "start_training_at";
private static final String AST_SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
private static final String AST_EVALUATION_SAMPLES = "evaluation_samples";
private static final String AST_ENTRY_POLICY_NOISE = "policy_noise";
private static final String AST_ENTRY_NOISE_CLIP = "noise_clip";
private static final String AST_ENTRY_POLICY_DELAY = "policy_delay";
private static final String ENVIRONMENT_PARAM_REWARD_TOPIC = "reward_topic";
private static final String ENVIRONMENT_ROS = "ros_interface";
......@@ -118,7 +121,20 @@ public class ReinforcementConfigurationData extends ConfigurationData {
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_EVALUATION_SAMPLES);
}
public Double getPolicyNoise() {
return !configurationContainsKey(AST_ENTRY_POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_NOISE);
}
public Double getNoiseClip() {
return !configurationContainsKey(AST_ENTRY_NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_NOISE_CLIP);
}
public Integer getPolicyDelay() {
return !configurationContainsKey(AST_ENTRY_POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_DELAY);
}
public RLAlgorithm getRlAlgorithm() {
if (!isReinforcementLearning()) {
......
<#setting number_format="computer">
<#assign config = configurations[0]>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent")>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
from ${rlFrameworkModule}.agent import ${rlAgentType}
from ${rlFrameworkModule}.util import AgentSignalHandler
from ${rlFrameworkModule}.cnnarch_logger import ArchLogger
-<#if config.rlAlgorithm=="ddpg">
+<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
from ${rlFrameworkModule}.CNNCreator_${criticInstanceName} import CNNCreator_${criticInstanceName}
</#if>
import ${rlFrameworkModule}.environment
......@@ -137,8 +137,10 @@ if __name__ == "__main__":
</#if>
<#if (config.rlAlgorithm == "dqn")>
<#include "params/DqnAgentParams.ftl">
-<#else>
+<#elseif config.rlAlgorithm == "ddpg">
<#include "params/DdpgAgentParams.ftl">
+<#else>
+<#include "params/Td3AgentParams.ftl">
</#if>
}
......
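For orientation only, the trainer rendered from the hunk above for a td3 configuration would roughly take the shape sketched below. This is a hedged sketch: the module name reinforcement_learning, the critic instance name CriticNetwork, and the final agent construction are assumptions for illustration and are not taken from this commit.

# Hypothetical rendering for config.rlAlgorithm == "td3" (names assumed)
from reinforcement_learning.agent import TwinDelayedDdpgAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
from reinforcement_learning.CNNCreator_CriticNetwork import CNNCreator_CriticNetwork
import reinforcement_learning.environment

if __name__ == "__main__":
    agent_params = {
        # ... common DDPG entries from params/DdpgAgentParams.ftl, followed by
        # the TD3 entries from params/Td3AgentParams.ftl (shown at the end of
        # this diff); entries elided here ...
    }
    agent = TwinDelayedDdpgAgent(**agent_params)  # construction assumed
    train_successful = agent.train()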
......@@ -641,6 +641,357 @@ class DdpgAgent(Agent):
self._critic, self._state_dim, self._action_dim, ctx=self._ctx)
class TwinDelayedDdpgAgent(DdpgAgent):
def __init__(
self,
actor,
critic,
environment,
replay_memory_params,
strategy_params,
state_dim,
action_dim,
soft_target_update_rate=.001,
actor_optimizer='adam',
actor_optimizer_params={'learning_rate': 0.0001},
critic_optimizer='adam',
critic_optimizer_params={'learning_rate': 0.001},
ctx=None,
discount_factor=.9,
training_episodes=50,
start_training=20,
train_interval=1,
snapshot_interval=200,
agent_name='TwinDelayedDdpgAgent',
max_episode_step=9999,
evaluation_samples=100,
output_directory='model_parameters',
verbose=True,
target_score=None,
policy_noise=0.2,
noise_clip=0.5,
policy_delay=2
):
super(TwinDelayedDdpgAgent, self).__init__(
environment=environment, replay_memory_params=replay_memory_params,
strategy_params=strategy_params, state_dim=state_dim,
action_dim=action_dim, ctx=ctx, discount_factor=discount_factor,
training_episodes=training_episodes, start_training=start_training,
train_interval=train_interval,
snapshot_interval=snapshot_interval, agent_name=agent_name,
max_episode_step=max_episode_step,
output_directory=output_directory, verbose=verbose,
target_score=target_score, evaluation_samples=evaluation_samples,
critic=critic, soft_target_update_rate=soft_target_update_rate,
actor=actor, actor_optimizer=actor_optimizer,
actor_optimizer_params=actor_optimizer_params,
critic_optimizer=critic_optimizer,
critic_optimizer_params=critic_optimizer_params)
self._policy_noise = policy_noise
self._noise_clip = noise_clip
self._policy_delay = policy_delay
self._critic2 = self._critic.__class__()
self._critic2.collect_params().initialize(
mx.init.Normal(), ctx=self._ctx)
self._critic2.hybridize()
self._critic2(nd.ones((1,) + state_dim, ctx=self._ctx),
nd.ones((1,) + action_dim, ctx=self._ctx))
self._critic2_target = self._copy_critic2()
self._critic2_optimizer = critic_optimizer
self._critic2_optimizer_params = self._adjust_optimizer_params(
critic_optimizer_params)
def _make_pickle_ready(self, session_dir):
super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._critic2, 'critic2', session_dir)
self._critic2 = None
self._save_net(self._critic2_target, 'critic2_target', session_dir)
self._critic2_target = None
@classmethod
def resume_from_session(cls, session_dir, actor, critic, environment):
import pickle
if not os.path.exists(session_dir):
raise ValueError('Session directory does not exist')
files = dict()
files['agent'] = os.path.join(session_dir, 'agent.p')
files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
files['actor_net_params'] = os.path.join(session_dir, 'actor.params')
files['actor_target_net_params'] = os.path.join(
session_dir, 'actor_target.params')
files['critic_net_params'] = os.path.join(session_dir, 'critic.params')
files['critic_target_net_params'] = os.path.join(
session_dir, 'critic_target.params')
files['critic2_net_params'] = os.path.join(
session_dir, 'critic2.params')
files['critic2_target_net_params'] = os.path.join(
session_dir, 'critic2_target.params')
for file in files.values():
if not os.path.exists(file):
raise ValueError(
'Session directory is not complete: {} is missing'
.format(file))
with open(files['agent'], 'rb') as f:
agent = pickle.load(f)
agent._environment = environment
agent._actor = actor
agent._actor.load_parameters(files['actor_net_params'], agent._ctx)
agent._actor.hybridize()
agent._actor(nd.random_normal(
shape=((1,) + agent._state_dim), ctx=agent._ctx))
agent._best_net = copy_net(agent._actor, agent._state_dim, agent._ctx)
agent._best_net.load_parameters(files['best_net_params'], agent._ctx)
agent._actor_target = copy_net(
agent._actor, agent._state_dim, agent._ctx)
agent._actor_target.load_parameters(files['actor_target_net_params'])
agent._critic = critic
agent._critic.load_parameters(files['critic_net_params'], agent._ctx)
agent._critic.hybridize()
agent._critic(
nd.random_normal(shape=((1,) + agent._state_dim), ctx=agent._ctx),
nd.random_normal(shape=((1,) + agent._action_dim), ctx=agent._ctx))
agent._critic_target = copy_net_with_two_inputs(
agent._critic, agent._state_dim, agent._action_dim, agent._ctx)
agent._critic_target.load_parameters(files['critic_target_net_params'])
agent._critic2 = copy_net_with_two_inputs(
agent._critic, agent._state_dim, agent._action_dim, agent._ctx)
agent._critic2.load_parameters(files['critic2_net_params'], agent._ctx)
agent._critic2.hybridize()
agent._critic2(
nd.random_normal(shape=((1,) + agent._state_dim), ctx=agent._ctx),
nd.random_normal(shape=((1,) + agent._action_dim), ctx=agent._ctx))
agent._critic2_target = copy_net_with_two_inputs(
agent._critic2, agent._state_dim, agent._action_dim, agent._ctx)
agent._critic2_target.load_parameters(
files['critic2_target_net_params'])
agent._logger = ArchLogger.get_logger()
agent._training_stats.logger = ArchLogger.get_logger()
agent._logger.info('Agent was retrieved; Training can be continued')
return agent
def _copy_critic2(self):
assert self._critic2 is not None
assert self._ctx is not None
assert type(self._state_dim) is tuple
assert type(self._action_dim) is tuple
return copy_net_with_two_inputs(
self._critic2, self._state_dim, self._action_dim, ctx=self._ctx)
def train(self, episodes=None):
self.save_config_file()
self._logger.info("--- Start TwinDelayedDDPG training ---")
episodes = \
episodes if episodes is not None else self._training_episodes
resume = (self._current_episode > 0)
if resume:
self._logger.info("Training session resumed")
self._logger.info(
"Starting from episode {}".format(self._current_episode))
else:
self._training_stats = DdpgTrainingStats(episodes)
# Initialize target Q1' and Q2' and mu'
self._actor_target = self._copy_actor()
self._critic_target = self._copy_critic()
self._critic2_target = self._copy_critic2()
# Initialize l2 loss for critic network
l2_loss = gluon.loss.L2Loss()
# Initialize critic and actor trainer
trainer_actor = gluon.Trainer(
self._actor.collect_params(), self._actor_optimizer,
self._actor_optimizer_params)
trainer_critic = gluon.Trainer(
self._critic.collect_params(), self._critic_optimizer,
self._critic_optimizer_params)
trainer_critic2 = gluon.Trainer(
self._critic2.collect_params(), self._critic2_optimizer,
self._critic2_optimizer_params)
# For episode=1..n
while self._current_episode < episodes:
# Check interrupt flag
if self._check_interrupt_routine():
return False
# Initialize new episode
step = 0
episode_reward = 0
start = time.time()
episode_critic_loss = 0
episode_actor_loss = 0
episode_avg_q_value = 0
training_steps = 0
actor_updates = 0
# Get initial observation state s
state = self._environment.reset()
# For step=1..T
while step < self._max_episode_step:
# Select an action a = mu(s) + N, where mu is the current
# actor and N is the exploration noise provided by the strategy
action = self._strategy.select_action(
self.get_next_action(state))
self._strategy.decay(self._current_episode)
# Execute action a and observe reward r and next state ns
next_state, reward, terminal, _ = \
self._environment.step(action)
self._logger.debug(
'Applied action {} with reward {}'.format(action, reward))
# Store transition (s,a,r,ns) in replay buffer
self._memory.append(
state, action, reward, next_state, terminal)
if self._do_training():
# Sample random minibatch of b transitions
# (s_i, a_i, r_i, s_(i+1)) from replay buffer
states, actions, rewards, next_states, terminals =\
self._sample_from_memory()
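# Target policy smoothing (TD3): perturb the target actor's
# action with clipped Gaussian noise before evaluating the
# target critics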
clipped_noise = nd.array(
np.clip(
np.random.normal(
loc=0, scale=self._policy_noise,
size=self._minibatch_size
).reshape(self._minibatch_size, 1),
-self._noise_clip,
self._noise_clip
),
ctx=self._ctx
)
target_action = np.clip(
self._actor_target(next_states) + clipped_noise,
self._strategy._action_low,
self._strategy._action_high)
rewards = rewards.reshape(self._minibatch_size, 1)
terminals = terminals.reshape(self._minibatch_size, 1)
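# Clipped double-Q learning (TD3): evaluate both target
# critics and take the element-wise minimum to curb
# Q-value overestimation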
target_qvalues1 = self._critic_target(next_states,
target_action)
target_qvalues2 = self._critic2_target(next_states,
target_action)
target_qvalues = nd.minimum(target_qvalues1,
target_qvalues2)
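# TD3 target: y = r + (1 - terminal) * gamma * min(Q1'(s', a'), Q2'(s', a'))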
y = rewards + (1 - terminals) * self._discount_factor\
* target_qvalues
with autograd.record():
qvalues1 = self._critic(states, actions)
critic1_loss = l2_loss(qvalues1, y)
critic1_loss.backward()
trainer_critic.step(self._minibatch_size)
with autograd.record():
qvalues2 = self._critic2(states, actions)
critic2_loss = l2_loss(qvalues2, y)
critic2_loss.backward()
trainer_critic2.step(self._minibatch_size)
critic_loss = (critic1_loss.mean() + critic2_loss.mean())/2
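# Delayed policy update (TD3): train the actor and soft-update
# all target networks only every policy_delay training steps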
if self._total_steps % self._policy_delay == 0:
tmp_critic = self._copy_critic()
with autograd.record():
actor_loss = -tmp_critic(
states, self._actor(states)).mean()
actor_loss.backward()
trainer_actor.step(self._minibatch_size)
# Update target networks:
self._actor_target = self._soft_update(
self._actor, self._actor_target,
self._soft_target_update_rate)
self._critic_target = self._soft_update(
self._critic, self._critic_target,
self._soft_target_update_rate)
self._critic2_target = self._soft_update(
self._critic2, self._critic2_target,
self._soft_target_update_rate)
actor_updates = actor_updates + 1
else:
actor_loss = nd.array([0], ctx=self._ctx)
# Update statistics
episode_critic_loss +=\
np.sum(critic_loss.asnumpy()) / self._minibatch_size
episode_actor_loss += 0 if actor_updates == 0 else\
np.sum(actor_loss.asnumpy()[0])
episode_avg_q_value = 0
training_steps += 1
episode_reward += reward
step += 1
self._total_steps += 1
state = next_state
if terminal:
# Reset the strategy
self._strategy.reset()
break
# Log the episode results
episode_actor_loss = 0 if actor_updates == 0\
else (episode_actor_loss / actor_updates)
episode_critic_loss = 0 if training_steps == 0\
else (episode_critic_loss / training_steps)
episode_avg_q_value = 0 if training_steps == 0\
else (episode_avg_q_value / training_steps)
avg_reward = self._training_stats.log_episode(
self._current_episode, start, training_steps,
episode_actor_loss, episode_critic_loss, episode_avg_q_value,
self._strategy.cur_eps, episode_reward)
self._do_snapshot_if_in_interval(self._current_episode)
if self._is_target_reached(avg_reward):
self._logger.info(
'Target score is reached on average; training is stopped')
break
self._current_episode += 1
self._evaluate()
self.save_parameters(episode=self._current_episode)
self.save_best_network(os.path.join(self._output_directory, 'best'))
self._training_stats.save_stats(self._output_directory)
self._logger.info('--------- Training finished ---------')
return True
def _make_config_dict(self):
config = super(TwinDelayedDdpgAgent, self)._make_config_dict()
config['policy_noise'] = self._policy_noise
config['noise_clip'] = self._noise_clip
config['policy_delay'] = self._policy_delay
return config
class DqnAgent(Agent):
def __init__(
self,
......
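As a usage illustration of the new class added above: constructing a TwinDelayedDdpgAgent by hand might look like the following sketch. The names actor_net, critic_net, env, replay_memory_params and strategy_params are placeholders (in generated code they come from the CNNCreator classes, the environment module and the rendered agent_params), and the dimensions are invented example values.

import mxnet as mx

# All right-hand-side names below are placeholders, see the note above.
agent = TwinDelayedDdpgAgent(
    actor=actor_net,
    critic=critic_net,
    environment=env,
    replay_memory_params=replay_memory_params,
    strategy_params=strategy_params,
    state_dim=(8,),
    action_dim=(3,),
    ctx=mx.cpu(),
    policy_noise=0.2,   # std. dev. of the target policy smoothing noise
    noise_clip=0.5,     # clipping range applied to that noise
    policy_delay=2)     # actor/target updates every 2nd training step
train_successful = agent.train()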
<#include "DdpgAgentParams.ftl">
<#if (config.policyNoise)??>
'policy_noise': ${config.policyNoise},
</#if>
<#if (config.noiseClip)??>
'noise_clip': ${config.noiseClip},
</#if>
<#if (config.policyDelay)??>
'policy_delay': ${config.policyDelay},
</#if>
\ No newline at end of file
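Rendered for a configuration that sets all three new entries, this template contributes lines such as the following to the agent_params dictionary of the generated trainer (example values only; an entry whose key is absent in the configuration is skipped by the (...)?? guard, and the corresponding Python default of TwinDelayedDdpgAgent, 0.2, 0.5 or 2, applies):

'policy_noise': 0.2,
'noise_clip': 0.5,
'policy_delay': 2,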