Commit cc1d2fcb authored by Nicola Gatto

Adapt tests to fix

parent d9d9cf5d
@@ -56,7 +56,7 @@ if __name__ == "__main__":
'memory_size': 10000,
'sample_size': 32,
'state_dtype': 'float32',
- 'action_dtype': 'float32',
+ 'action_dtype': 'uint8',
'rewards_dtype': 'float32'
},
'strategy_params': {
......
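The dtype change above reflects that a discrete action is just an index into the action space, so it can be stored exactly as an unsigned integer instead of a float. A minimal sketch, assuming a NumPy-backed buffer; the names below are illustrative and not the generated replay-memory API:

```python
import numpy as np

# Illustration only: a (1,)-shaped uint8 cell holds one discrete action index
# per transition, with no float conversion involved.
action = 3                              # index of the chosen action
slot = np.zeros((1,), dtype='uint8')    # matches action_dim=(1,), action_dtype='uint8'
slot[0] = action
assert int(slot[0]) == 3                # the index is recovered exactly
```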
@@ -1043,6 +1043,14 @@ class DqnAgent(Agent):
self._double_dqn = double_dqn
self._use_fix_target = use_fix_target
+ # Build memory buffer for discrete actions
+ replay_memory_params['state_dim'] = state_dim
+ replay_memory_params['action_dim'] = (1,)
+ self._replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self._memory = rm_builder.build_by_params(**replay_memory_params)
+ self._minibatch_size = self._memory.sample_size
# Initialize best network
self._best_net = copy_net(self._qnet, self._state_dim, self._ctx)
self._best_avg_score = -np.infty
@@ -1199,7 +1207,7 @@ class DqnAgent(Agent):
# 3. Store transition in replay memory
self._memory.append(
- state, action, reward, next_state, terminal)
+ state, [action], reward, next_state, terminal)
# 4. Train the network if in interval
if self._do_training():
......
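For context, here is a toy sketch of the buffer layout that the constructor change above sets up. It uses plain NumPy with made-up names rather than the generated ReplayMemoryBuilder API: with action_dim fixed to (1,), each transition reserves a length-1 uint8 row for the action index, which is why the agent now appends [action] instead of the bare scalar.

```python
import numpy as np

class MiniReplayMemory(object):
    """Toy ring buffer mirroring the shapes used above: one uint8 cell per
    discrete action, float32 rows for states, next states and rewards."""

    def __init__(self, size, state_dim, action_dim=(1,)):
        self._size = size
        self._pointer = 0
        self._states = np.zeros((size,) + state_dim, dtype='float32')
        self._actions = np.zeros((size,) + action_dim, dtype='uint8')
        self._rewards = np.zeros((size,), dtype='float32')
        self._next_states = np.zeros((size,) + state_dim, dtype='float32')
        self._terminals = np.zeros((size,), dtype='uint8')

    def append(self, state, action, reward, next_state, terminal):
        i = self._pointer % self._size
        self._states[i] = state
        self._actions[i] = action        # expects a length-1 sequence, e.g. [2]
        self._rewards[i] = reward
        self._next_states[i] = next_state
        self._terminals[i] = terminal
        self._pointer += 1


memory = MiniReplayMemory(size=4, state_dim=(4,))
memory.append(np.ones(4), [2], 1.0, np.zeros(4), False)   # action passed as [action]
print(memory._actions[0])                                  # -> [2]
```

In this NumPy sketch a bare scalar would also broadcast into the (1,) row; wrapping it as [action] simply makes the per-transition shape explicit, which is presumably what the generated memory expects.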
@@ -1043,6 +1043,14 @@ class DqnAgent(Agent):
self._double_dqn = double_dqn
self._use_fix_target = use_fix_target
+ # Build memory buffer for discrete actions
+ replay_memory_params['state_dim'] = state_dim
+ replay_memory_params['action_dim'] = (1,)
+ self._replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self._memory = rm_builder.build_by_params(**replay_memory_params)
+ self._minibatch_size = self._memory.sample_size
# Initialize best network
self._best_net = copy_net(self._qnet, self._state_dim, self._ctx)
self._best_avg_score = -np.infty
@@ -1199,7 +1207,7 @@ class DqnAgent(Agent):
# 3. Store transition in replay memory
self._memory.append(
- state, action, reward, next_state, terminal)
+ state, [action], reward, next_state, terminal)
# 4. Train the network if in interval
if self._do_training():
......
@@ -63,7 +63,7 @@ if __name__ == "__main__":
'memory_size': 1000000,
'sample_size': 32,
'state_dtype': 'float32',
- 'action_dtype': 'float32',
+ 'action_dtype': 'uint8',
'rewards_dtype': 'float32'
},
'strategy_params': {
......
@@ -1043,6 +1043,14 @@ class DqnAgent(Agent):
self._double_dqn = double_dqn
self._use_fix_target = use_fix_target
+ # Build memory buffer for discrete actions
+ replay_memory_params['state_dim'] = state_dim
+ replay_memory_params['action_dim'] = (1,)
+ self._replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self._memory = rm_builder.build_by_params(**replay_memory_params)
+ self._minibatch_size = self._memory.sample_size
# Initialize best network
self._best_net = copy_net(self._qnet, self._state_dim, self._ctx)
self._best_avg_score = -np.infty
@@ -1199,7 +1207,7 @@ class DqnAgent(Agent):
# 3. Store transition in replay memory
self._memory.append(
- state, action, reward, next_state, terminal)
+ state, [action], reward, next_state, terminal)
# 4. Train the network if in interval
if self._do_training():
......