From 7ed155b345bdefb904af10fa765dcbad158e93a5 Mon Sep 17 00:00:00 2001 From: Nicola Gatto <nicola.gatto@rwth-aachen.de> Date: Wed, 17 Jul 2019 00:25:40 +0200 Subject: [PATCH] Implement noise variance Add a configurable noise_variance for GaussianNoiseStrategy. np.random.normal expects a standard deviation as scale, so the variance is converted with np.sqrt before sampling. --- .../templates/gluon/reinforcement/agent/Strategy.ftl | 11 ++++++++--- .../gluon/reinforcement/params/StrategyParams.ftl | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl index 38abc42b..d2d23866 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Strategy.ftl @@ -19,7 +19,8 @@ class StrategyBuilder(object): action_high=None, mu=0.0, theta=0.5, - sigma=0.3 + sigma=0.3, + noise_variance=0.1 ): if epsilon_decay_method == 'linear': @@ -50,8 +51,9 @@ class StrategyBuilder(object): assert action_dim is not None assert action_low is not None assert action_high is not None + assert noise_variance is not None return GaussianNoiseStrategy(action_dim, action_low, action_high, - epsilon, decay) + epsilon, noise_variance, decay) else: assert action_dim is not None assert len(action_dim) == 1 @@ -197,6 +199,7 @@ class GaussianNoiseStrategy(BaseStrategy): action_low, action_high, eps, + noise_variance, decay=NoDecay() ): super(GaussianNoiseStrategy, self).__init__(decay) @@ -207,7 +210,9 @@ class GaussianNoiseStrategy(BaseStrategy): self._action_low = action_low self._action_high = action_high + self._noise_variance = noise_variance + def select_action(self, values): - noise = np.random.normal(size=self._action_dim) + noise = np.random.normal(loc=0.0, scale=np.sqrt(self._noise_variance), size=self._action_dim) action = values + self.cur_eps * noise return np.clip(action, self._action_low, self._action_high) diff --git a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl 
b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl index e5ef1ee5..b8ffa903 100644 --- a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl +++ b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl @@ -19,6 +19,11 @@ <#if (config.strategy.epsilon_decay_per_step)??> 'epsilon_decay_per_step': ${config.strategy.epsilon_decay_per_step?string('True', 'False')}, </#if> +<#if (config.strategy.method)?? && config.strategy.method=="gaussian"><#-- null-guard method before comparing, as the action_low block does --> +<#if (config.strategy.noise_variance)??> + 'noise_variance': ${config.strategy.noise_variance}, +</#if> +</#if> <#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck" || config.strategy.method=="gaussian")> <#if (config.strategy.action_low)?? > 'action_low': ${config.strategy.action_low}, -- GitLab