Commit fb24b807 authored by Nils Baumann

Updated Tests

parent 88c0c13c
Pipeline #683374 passed with stage in 21 seconds
@@ -1139,19 +1139,19 @@ class DqnAgent(Agent):
     def get_next_action(self, state, with_best=False):
         q_values = self.get_q_values(state, with_best=with_best)
-        action = q_values.asnumpy().argmax()
+        action = q_values[0][0].asnumpy().argmax()
         return action
 
     def __determine_target_q_values(
             self, states, actions, rewards, next_states, terminals
     ):
         if self._use_fix_target:
-            q_max_val = self._target_qnet(next_states)
+            q_max_val = self._target_qnet(next_states)[0][0]
         else:
-            q_max_val = self._qnet(next_states)
+            q_max_val = self._qnet(next_states)[0][0]
 
         if self._double_dqn:
-            q_values_next_states = self._qnet(next_states)
+            q_values_next_states = self._qnet(next_states)[0][0]
             target_rewards = rewards + nd.choose_element_0index(
                 q_max_val, nd.argmax_channel(q_values_next_states))\
                 * (1.0 - terminals) * self._discount_factor
@@ -1160,11 +1160,13 @@ class DqnAgent(Agent):
                 q_max_val, nd.argmax_channel(q_max_val))\
                 * (1.0 - terminals) * self._discount_factor
 
         actions = actions.astype(int)
         target_qval = self._qnet(states)
         for t in range(target_rewards.shape[0]):
-            target_qval[t][actions[t]] = target_rewards[t]
+            target_qval[0][0][t][actions.asnumpy()[t, 0]] = target_rewards[t]
 
-        return target_qval
+        return target_qval[0][0]
 
     def __train_q_net_step(self, trainer):
         states, actions, rewards, next_states, terminals =\
@@ -1173,7 +1175,7 @@ class DqnAgent(Agent):
             states, actions, rewards, next_states, terminals)
 
         with autograd.record():
             q_values = self._qnet(states)
-            loss = self._loss_function(q_values, target_qval)
+            loss = self._loss_function(q_values[0][0], target_qval)
 
         loss.backward()
         trainer.step(self._minibatch_size)
         return loss
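
The recurring change in this file is that the generated network no longer yields a bare NDArray: every call site now unwraps the output with [0][0] before using it. Below is a minimal sketch of that assumed output structure; WrappedQNet and its single Dense layer are hypothetical stand-ins for the generated network, not part of this commit.

import mxnet as mx
from mxnet import nd

class WrappedQNet(mx.gluon.Block):
    # Hypothetical stand-in: returns its outputs as a nested list,
    # net(x) -> [[NDArray]], which is why call sites index [0][0].
    def __init__(self, num_actions, **kwargs):
        super(WrappedQNet, self).__init__(**kwargs)
        with self.name_scope():
            self.fc = mx.gluon.nn.Dense(num_actions)

    def forward(self, x):
        return [[self.fc(x)]]

net = WrappedQNet(num_actions=4)
net.initialize()
state = nd.random.uniform(shape=(1, 8))    # batch with one 8-dimensional state
q_values = net(state)[0][0]                # unwrap to the raw (1, 4) NDArray
action = int(q_values.asnumpy().argmax())  # greedy action, as in get_next_action

Under this assumption, the [0][0] indexing in __determine_target_q_values and __train_q_net_step is the same mechanical unwrapping applied to the _qnet and _target_qnet outputs.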
# (c) https://github.com/MontiCore/monticore
import numpy as np
@@ -130,7 +130,7 @@ class EpsilonGreedyStrategy(BaseStrategy):
         if do_exploration:
             action = np.random.randint(low=0, high=self.__number_of_actions)
         else:
-            action = values.asnumpy().argmax()
+            action = values[0][0].asnumpy().argmax()
 
         return action
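
The strategy change mirrors the agent change: the Q-values handed to the greedy branch are unwrapped with [0][0] first. A minimal epsilon-greedy sketch under the same nested-output assumption; the select_action function and its signature are illustrative, not the generated API.

import numpy as np

def select_action(values, epsilon, number_of_actions):
    # Epsilon-greedy: explore with probability epsilon,
    # otherwise take the greedy action from the unwrapped Q-values.
    if np.random.random() < epsilon:
        return np.random.randint(low=0, high=number_of_actions)
    # values is assumed to be the nested [[NDArray]] the network now returns
    return int(values[0][0].asnumpy().argmax())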