diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl index d58d42a64d6a8062bf3d4a28a2084a0425d45cc9..21ecbdaa604a6f2b9b4ee5ee14305646cd3ff1f8 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl @@ -504,6 +504,7 @@ class DdpgAgent(Agent): # actor and exploration noise N according to strategy action = self._strategy.select_action( self.get_next_action(state)) + self._strategy.decay(self._current_episode) # Execute action a and observe reward r and next state ns next_state, reward, terminal, _ = \ @@ -545,10 +546,10 @@ class DdpgAgent(Agent): # with critic parameters tmp_critic = self._copy_critic() with autograd.record(): - actor_qvalues = tmp_critic(states, self._actor(states)) # For maximizing qvalues we have to multiply with -1 # as we use a minimizer - actor_loss = -1 * actor_qvalues.mean() + actor_loss = -tmp_critic( + states, self._actor(states)).mean() actor_loss.backward() trainer_actor.step(self._minibatch_size) @@ -594,7 +595,6 @@ class DdpgAgent(Agent): self._strategy.cur_eps, episode_reward) self._do_snapshot_if_in_interval(self._current_episode) - self._strategy.decay(self._current_episode) if self._is_target_reached(avg_reward): self._logger.info( @@ -1190,6 +1190,7 @@ class DqnAgent(Agent): # 1. Choose an action based on current game state and policy q_values = self._qnet(nd.array([state], ctx=self._ctx)) action = self._strategy.select_action(q_values[0]) + self._strategy.decay(self._current_episode) # 2. 
Play the game for a single step next_state, reward, terminal, _ =\ @@ -1226,8 +1227,6 @@ class DqnAgent(Agent): self._current_episode, start, training_steps, episode_loss, self._strategy.cur_eps, episode_reward) - self._strategy.decay(self._current_episode) - if self._is_target_reached(avg_reward): self._logger.info( 'Target score is reached in average; Training is stopped') diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl index a823466fd0a74542436b50bccaf80cc935a0e172..38abc42bdd48cea2598baecf68dc8e6c7bd45f43 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl @@ -13,6 +13,7 @@ class StrategyBuilder(object): epsilon_decay_method='no', epsilon_decay=0.0, epsilon_decay_start=0, + epsilon_decay_per_step=False, action_dim=None, action_low=None, action_high=None, @@ -24,7 +25,8 @@ class StrategyBuilder(object): if epsilon_decay_method == 'linear': decay = LinearDecay( eps_decay=epsilon_decay, min_eps=min_epsilon, - decay_start=epsilon_decay_start) + decay_start=epsilon_decay_start, + decay_per_step=epsilon_decay_per_step) else: decay = NoDecay() @@ -76,17 +78,27 @@ class NoDecay(BaseDecay): class LinearDecay(BaseDecay): - def __init__(self, eps_decay, min_eps=0, decay_start=0): + def __init__(self, eps_decay, min_eps=0, decay_start=0, decay_per_step=False): super(LinearDecay, self).__init__() self.eps_decay = eps_decay self.min_eps = min_eps self.decay_start = decay_start + self.decay_per_step = decay_per_step + self.last_episode = -1 - def decay(self, cur_eps, episode): - if episode < self.decay_start: - return cur_eps + def do_decay(self, episode): + if self.decay_per_step: + do = (episode >= self.decay_start) else: + do = ((self.last_episode != episode) and (episode >= self.decay_start)) + self.last_episode = episode + return do + + def decay(self, cur_eps, 
episode): + if self.do_decay(episode): return max(cur_eps - self.eps_decay, self.min_eps) + else: + return cur_eps class BaseStrategy(object): diff --git a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl index 836a95111705a0f327f2e275ce64c650f39dd34f..402049c6283a6f4ab896f76fff652a8ffb6b6f12 100644 --- a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl +++ b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl @@ -16,6 +16,9 @@ <#if (config.strategy.epsilon_decay_start)??> 'epsilon_decay_start': ${config.strategy.epsilon_decay_start}, </#if> +<#if (config.strategy.epsilon_decay_per_step)??> + 'epsilon_decay_per_step': ${config.strategy.epsilon_decay_per_step?string('True', 'False')}, +</#if> <#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck")> <#if (config.strategy.action_low)?? > 'action_low': ${config.strategy.action_low},