Commit e550ff1e authored by Nicola Gatto
Browse files

Use replay memory with discrete actions for DQN

parent 60e7c61d
......@@ -1043,6 +1043,14 @@ class DqnAgent(Agent):
self._double_dqn = double_dqn
self._use_fix_target = use_fix_target
+# Build memory buffer for discrete actions
+replay_memory_params['state_dim'] = state_dim
+replay_memory_params['action_dim'] = (1,)
+self._replay_memory_params = replay_memory_params
+rm_builder = ReplayMemoryBuilder()
+self._memory = rm_builder.build_by_params(**replay_memory_params)
+self._minibatch_size = self._memory.sample_size
# Initialize best network
self._best_net = copy_net(self._qnet, self._state_dim, self._ctx)
self._best_avg_score = -np.infty
......@@ -1199,7 +1207,7 @@ class DqnAgent(Agent):
# 3. Store transition in replay memory
self._memory.append(
-state, action, reward, next_state, terminal)
+state, [action], reward, next_state, terminal)
# 4. Train the network if in interval
if self._do_training():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment