Commit 5d9ba4fa authored by Nicola Gatto

Fix a bug which caused a crash after DQN evaluation

parent f8565806
Pipeline #149947 failed
@@ -906,4 +906,4 @@ class DqnAgent(Agent):
     def _save_current_as_best_net(self):
         self._best_net = copy_net(
-            self._qnet, (1,) + self._state_dim, ctx=self._ctx)
+            self._qnet, self._state_dim, ctx=self._ctx)
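The change removes the extra `(1,)` that was prepended to `self._state_dim` before handing it to `copy_net`. A plausible reading (the commit message only states that evaluation crashed) is that `copy_net` in `util.py` already adds the batch dimension itself when it runs the dummy forward pass that initializes the copied network, so the old call produced a doubly-batched input shape. The sketch below illustrates that assumption; the helper name, signature, and body are illustrative, not taken from the repository.

# Sketch only: an assumed stand-in for util.copy_net, not the project's
# actual implementation. It shows why the caller must pass the bare
# state shape: the batch dimension is prepended here.
import os

import mxnet as mx
from mxnet import gluon, nd


def copy_net_sketch(net, input_state_dim, ctx, tmp_filename='tmp.params'):
    assert isinstance(net, gluon.HybridBlock)
    # Round-trip the parameters through a file to get an independent copy.
    net.save_parameters(tmp_filename)
    net_copy = net.__class__()          # assumes a no-argument constructor
    net_copy.load_parameters(tmp_filename, ctx=ctx)
    os.remove(tmp_filename)
    net_copy.hybridize()
    # Dummy forward pass to build the cached graph; (1,) is the batch size.
    # Passing (1,) + state_dim from the caller would double this leading
    # dimension, so the copy would be initialized with the wrong input shape.
    net_copy(nd.ones((1,) + input_state_dim, ctx=ctx))
    return net_copy

Under that assumption, `copy_net(self._qnet, self._state_dim, ctx=self._ctx)` is the correct call, which is exactly what the hunk above changes.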
@@ -2,502 +2,908 @@ import mxnet as mx
 import numpy as np
 import time
 import os
-import logging
 import sys
 import util
 import matplotlib.pyplot as plt
+import pyprind
+from cnnarch_logger import ArchLogger
 from replay_memory import ReplayMemoryBuilder
-from action_policy import ActionPolicyBuilder
-from util import copy_net, get_loss_function
+from strategy import StrategyBuilder
+from util import copy_net, get_loss_function,\
+    copy_net_with_two_inputs, DdpgTrainingStats, DqnTrainingStats,\
+    make_directory_if_not_exist
 from mxnet import nd, gluon, autograd

-class DqnAgent(object):
-    def __init__(self,
-                 network,
-                 environment,
-                 replay_memory_params,
-                 policy_params,
-                 state_dim,
-                 ctx=None,
-                 discount_factor=.9,
-                 loss_function='euclidean',
-                 optimizer='rmsprop',
-                 optimizer_params = {'learning_rate':0.09},
-                 training_episodes=50,
-                 train_interval=1,
-                 use_fix_target=False,
-                 double_dqn = False,
-                 target_update_interval=10,
-                 snapshot_interval=200,
-                 agent_name='Dqn_agent',
-                 max_episode_step=99999,
-                 output_directory='model_parameters',
-                 verbose=True,
-                 live_plot = True,
-                 make_logfile=True,
-                 target_score=None):
-        assert 0 < discount_factor <= 1
-        assert train_interval > 0
-        assert target_update_interval > 0
-        assert snapshot_interval > 0
-        assert max_episode_step > 0
-        assert training_episodes > 0
-        assert replay_memory_params is not None
-        assert type(state_dim) is tuple
+class Agent(object):
+    def __init__(
+        self,
+        environment,
+        replay_memory_params,
+        strategy_params,
+        state_dim,
+        action_dim,
+        ctx=None,
+        discount_factor=.9,
+        training_episodes=50,
+        train_interval=1,
+        start_training=0,
+        snapshot_interval=200,
+        agent_name='Agent',
+        max_episode_step=99999,
+        evaluation_samples=1000,
+        output_directory='model_parameters',
+        verbose=True,
+        target_score=None
+    ):
+        assert 0 < discount_factor <= 1,\
+            'Discount factor must be between 0 and 1'
+        assert train_interval > 0, 'Train interval must be greater 0'
+        assert snapshot_interval > 0, 'Snapshot interval must be greater 0'
+        assert max_episode_step > 0,\
+            'Maximal steps per episode must be greater 0'
+        assert training_episodes > 0, 'Trainings episode must be greater 0'
+        assert replay_memory_params is not None,\
+            'Replay memory parameter not set'
+        assert type(state_dim) is tuple, 'State dimension is not a tuple'
+        assert type(action_dim) is tuple, 'Action dimension is not a tuple'

-        self.__ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
-        self.__qnet = network
-        self.__environment = environment
-        self.__discount_factor = discount_factor
-        self.__training_episodes = training_episodes
-        self.__train_interval = train_interval
-        self.__verbose = verbose
-        self.__state_dim = state_dim
-        self.__action_dim = self.__qnet(nd.random_normal(shape=((1,) + self.__state_dim), ctx=self.__ctx)).shape[1:]
+        self._logger = ArchLogger.get_logger()
+        self._ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
+        self._environment = environment
+        self._discount_factor = discount_factor
+        self._training_episodes = training_episodes
+        self._train_interval = train_interval
+        self._verbose = verbose
+        self._state_dim = state_dim

         replay_memory_params['state_dim'] = state_dim
-        self.__replay_memory_params = replay_memory_params
+        replay_memory_params['action_dim'] = action_dim
+        self._replay_memory_params = replay_memory_params
         rm_builder = ReplayMemoryBuilder()
-        self.__memory = rm_builder.build_by_params(**replay_memory_params)
-        self.__minibatch_size = self.__memory.sample_size
+        self._memory = rm_builder.build_by_params(**replay_memory_params)
+        self._minibatch_size = self._memory.sample_size
+        self._action_dim = action_dim

-        policy_params['action_dim'] = self.__action_dim
-        self.__policy_params = policy_params
-        p_builder = ActionPolicyBuilder()
-        self.__policy = p_builder.build_by_params(**policy_params)
+        strategy_params['action_dim'] = self._action_dim
+        self._strategy_params = strategy_params
+        strategy_builder = StrategyBuilder()
+        self._strategy = strategy_builder.build_by_params(**strategy_params)

-        self.__target_update_interval = target_update_interval
-        self.__target_qnet = copy_net(self.__qnet, self.__state_dim, ctx=self.__ctx)
-        self.__loss_function_str = loss_function
-        self.__loss_function = get_loss_function(loss_function)
-        self.__agent_name = agent_name
-        self.__snapshot_interval = snapshot_interval
-        self.__creation_time = time.time()
-        self.__max_episode_step = max_episode_step
-        self.__optimizer = optimizer
-        self.__optimizer_params = optimizer_params
-        self.__make_logfile = make_logfile
-        self.__double_dqn = double_dqn
-        self.__use_fix_target = use_fix_target
-        self.__live_plot = live_plot
-        self.__user_given_directory = output_directory
-        self.__target_score = target_score
-        self.__interrupt_flag = False
+        self._agent_name = agent_name
+        self._snapshot_interval = snapshot_interval
+        self._creation_time = time.time()
+        self._max_episode_step = max_episode_step
+        self._start_training = start_training
+        self._output_directory = output_directory
+        self._target_score = target_score
+
+        self._evaluation_samples = evaluation_samples
+        self._best_avg_score = -np.infty
+        self._best_net = None
+
+        self._interrupt_flag = False
+        self._training_stats = None

         # Training Context
-        self.__current_episode = 0
-        self.__total_steps = 0
-
-        # Initialize best network
-        self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
-        self.__best_avg_score = None
-
-        # Gluon Trainer definition
-        self.__training_stats = None
-
-        # Prepare output directory and logger
-        self.__output_directory = output_directory\
-            + '/' + self.__agent_name\
-            + '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
-        self.__logger = self.__setup_logging()
-        self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
+        self._current_episode = 0
+        self._total_steps = 0
+
+    @property
+    def current_episode(self):
+        return self._current_episode
+
+    @property
+    def environment(self):
+        return self._environment

-    @classmethod
-    def from_config_file(cls, network, environment, config_file_path, ctx=None):
+    def save_config_file(self):
         import json
-        # Load config
-        with open(config_file_path, 'r') as config_file:
-            config_dict = json.load(config_file)
-        return cls(network, environment, ctx=ctx, **config_dict)
+        make_directory_if_not_exist(self._output_directory)
+        filename = os.path.join(self._output_directory, 'config.json')
+        config = self._make_config_dict()
+        with open(filename, mode='w') as fp:
+            json.dump(config, fp, indent=4)
def set_interrupt_flag(self, interrupt):
self._interrupt_flag = interrupt
def _interrupt_training(self):
import pickle
self._logger.info('Training interrupted; Store state for resuming')
session_dir = self._get_session_dir()
agent_session_file = os.path.join(session_dir, 'agent.p')
logger = self._logger
self._training_stats.save_stats(self._output_directory, episode=self._current_episode)
self._make_pickle_ready(session_dir)
with open(agent_session_file, 'wb') as f:
pickle.dump(self, f, protocol=2)
logger.info('State successfully stored')
def _make_pickle_ready(self, session_dir):
del self._training_stats.logger
self._logger = None
self._environment.close()
self._environment = None
self._save_net(self._best_net, 'best_net', session_dir)
self._best_net = None
def _make_config_dict(self):
config = dict()
config['state_dim'] = self._state_dim
config['action_dim'] = self._action_dim
config['ctx'] = str(self._ctx)
config['discount_factor'] = self._discount_factor
config['strategy_params'] = self._strategy_params
config['replay_memory_params'] = self._replay_memory_params
config['training_episodes'] = self._training_episodes
config['start_training'] = self._start_training
config['evaluation_samples'] = self._evaluation_samples
config['train_interval'] = self._train_interval
config['snapshot_interval'] = self._snapshot_interval
config['agent_name'] = self._agent_name
config['max_episode_step'] = self._max_episode_step
config['output_directory'] = self._output_directory
config['verbose'] = self._verbose
config['target_score'] = self._target_score
return config
def _adjust_optimizer_params(self, optimizer_params):
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
del optimizer_params['weight_decay']
if 'learning_rate_decay' in optimizer_params:
min_learning_rate = 1e-8
if 'learning_rate_minimum' in optimizer_params:
min_learning_rate = optimizer_params['learning_rate_minimum']
del optimizer_params['learning_rate_minimum']
optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
optimizer_params['step_size'],
factor=optimizer_params['learning_rate_decay'],
stop_factor_lr=min_learning_rate)
del optimizer_params['step_size']
del optimizer_params['learning_rate_decay']
return optimizer_params
def _sample_from_memory(self):
states, actions, rewards, next_states, terminals\
= self._memory.sample(batch_size=self._minibatch_size)
states = nd.array(states, ctx=self._ctx)
actions = nd.array(actions, ctx=self._ctx)
rewards = nd.array(rewards, ctx=self._ctx)
next_states = nd.array(next_states, ctx=self._ctx)
terminals = nd.array(terminals, ctx=self._ctx)
return states, actions, rewards, next_states, terminals
def evaluate(self, target=None, sample_games=100, verbose=True):
if sample_games <= 0:
return 0
target = self._target_score if target is None else target
if target:
target_achieved = 0
total_reward = 0
self._logger.info('Sampling from {} games...'.format(sample_games))
for g in pyprind.prog_bar(range(sample_games)):
state = self._environment.reset()
step = 0
game_reward = 0
terminal = False
while not terminal and (step < self._max_episode_step):
action = self.get_next_action(state)
state, reward, terminal, _ = self._environment.step(action)
game_reward += reward
step += 1
if verbose:
info = 'Game %d: Reward %f' % (g, game_reward)
self._logger.debug(info)
if target:
if game_reward >= target:
target_achieved += 1
total_reward += game_reward
avg_reward = float(total_reward)/float(sample_games)
info = 'Avg. Reward: %f' % avg_reward
if target:
target_achieved_ratio = int(
(float(target_achieved)/float(sample_games))*100)
info += '; Target Achieved in %d%% of games'\
% (target_achieved_ratio)
if verbose:
self._logger.info(info)
return avg_reward
def _do_snapshot_if_in_interval(self, episode):
do_snapshot =\
(episode != 0 and (episode % self._snapshot_interval == 0))
if do_snapshot:
self.save_parameters(episode=episode)
self._evaluate()
def _evaluate(self, verbose=True):
avg_reward = self.evaluate(
sample_games=self._evaluation_samples, verbose=False)
info = 'Evaluation -> Average Reward in {} games: {}'.format(
self._evaluation_samples, avg_reward)
if self._best_avg_score is None or self._best_avg_score <= avg_reward:
self._save_current_as_best_net()
self._best_avg_score = avg_reward
if verbose:
self._logger.info(info)
def _is_target_reached(self, avg_reward):
return self._target_score is not None\
and avg_reward > self._target_score
def _do_training(self):
return (self._total_steps % self._train_interval == 0) and\
(self._memory.is_sample_possible(self._minibatch_size)) and\
(self._current_episode >= self._start_training)
def _check_interrupt_routine(self):
if self._interrupt_flag:
self._interrupt_flag = False
self._interrupt_training()
return True
return False
def _save_parameters(self, net, episode=None, filename='dqn-agent-params'):
assert self._output_directory
assert isinstance(net, gluon.HybridBlock)
make_directory_if_not_exist(self._output_directory)
if(episode is not None):
self._logger.info(
'Saving model parameters after episode %d' % episode)
filename = filename + '-ep{}'.format(episode)
else:
self._logger.info('Saving model parameters')
self._save_net(net, filename)
def _save_net(self, net, filename, filedir=None):
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename)
net.save_parameters(filename + '.params')
net.export(filename, epoch=0)
def save_best_network(self, path, epoch=0):
self._logger.info(
'Saving best network with average reward of {}'.format(
self._best_avg_score))
self._best_net.export(path, epoch=epoch)
def _get_session_dir(self):
session_dir = os.path.join(
self._output_directory, '.interrupted_session')
make_directory_if_not_exist(session_dir)
return session_dir
def _save_current_as_best_net(self):
raise NotImplementedError
def get_next_action(self, state):
raise NotImplementedError
def save_parameters(self, episode):
raise NotImplementedError
def train(self, episodes=None):
raise NotImplementedError
class DdpgAgent(Agent):
def __init__(
self,
actor,
critic,
environment,
replay_memory_params,
strategy_params,
state_dim,
action_dim,
soft_target_update_rate=.001,
actor_optimizer='adam',
actor_optimizer_params={'learning_rate': 0.0001},
critic_optimizer='adam',
critic_optimizer_params={'learning_rate': 0.001},
ctx=None,
discount_factor=.9,
training_episodes=50,
start_training=20,
train_interval=1,
snapshot_interval=200,
agent_name='DdpgAgent',
max_episode_step=9999,
evaluation_samples=100,
output_directory='model_parameters',
verbose=True,
target_score=None
):
super(DdpgAgent, self).__init__(
environment=environment, replay_memory_params=replay_memory_params,
strategy_params=strategy_params, state_dim=state_dim,
action_dim=action_dim, ctx=ctx, discount_factor=discount_factor,
training_episodes=training_episodes, start_training=start_training,
train_interval=train_interval,
snapshot_interval=snapshot_interval, agent_name=agent_name,
max_episode_step=max_episode_step,
output_directory=output_directory, verbose=verbose,
target_score=target_score, evaluation_samples=evaluation_samples)
assert critic is not None, 'Critic not set'
assert actor is not None, 'Actor is not set'
assert soft_target_update_rate > 0,\
'Target update must be greater zero'
assert actor_optimizer is not None, 'No actor optimizer set'
assert critic_optimizer is not None, 'No critic optimizer set'
self._actor = actor
self._critic = critic
self._actor_target = self._copy_actor()
self._critic_target = self._copy_critic()
self._actor_optimizer = actor_optimizer
self._actor_optimizer_params = self._adjust_optimizer_params(
actor_optimizer_params)
self._critic_optimizer = critic_optimizer
self._critic_optimizer_params = self._adjust_optimizer_params(
critic_optimizer_params)
self._soft_target_update_rate = soft_target_update_rate
self._logger.info(
'Agent created with following parameters: {}'.format(
self._make_config_dict()))
self._best_net = self._copy_actor()
self._training_stats = DdpgTrainingStats(self._training_episodes)
def _make_pickle_ready(self, session_dir):
super(DdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._actor, 'current_actor')
self._save_net(self._actor, 'actor', session_dir)
self._actor = None
self._save_net(self._critic, 'critic', session_dir)
self._critic = None
self._save_net(self._actor_target, 'actor_target', session_dir)
self._actor_target = None
self._save_net(self._critic_target, 'critic_target', session_dir)
self._critic_target = None
     @classmethod
-    def resume_from_session(cls, session_dir, network_type):
+    def resume_from_session(cls, session_dir, actor, critic, environment):
         import pickle
-        session_dir = os.path.join(session_dir, '.interrupted_session')
         if not os.path.exists(session_dir):
             raise ValueError('Session directory does not exist')

         files = dict()
         files['agent'] = os.path.join(session_dir, 'agent.p')
         files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
-        files['q_net_params'] = os.path.join(session_dir, 'qnet.params')
-        files['target_net_params'] = os.path.join(session_dir, 'target_net.params')
+        files['actor_net_params'] = os.path.join(session_dir, 'actor.params')
+        files['actor_target_net_params'] = os.path.join(
+            session_dir, 'actor_target.params')
+        files['critic_net_params'] = os.path.join(session_dir, 'critic.params')
+        files['critic_target_net_params'] = os.path.join(
+            session_dir, 'critic_target.params')

         for file in files.values():
             if not os.path.exists(file):
-                raise ValueError('Session directory is not complete: {} is missing'.format(file))
+                raise ValueError(
+                    'Session directory is not complete: {} is missing'
+                    .format(file))

         with open(files['agent'], 'rb') as f: