Commit 62786f93 authored by Nicola Gatto

Adjust tests

parent 14c29770
Pipeline #142451 failed
 from reinforcement_learning.agent import DqnAgent
+from reinforcement_learning.util import AgentSignalHandler
 import reinforcement_learning.environment
 import CNNCreator_cartpole_master_dqn
+import os
+import sys
+import re
 import logging
 import mxnet as mx
+
+session_output_dir = 'session'
+agent_name='cartpole_master_dqn'
+session_param_output = os.path.join(session_output_dir, agent_name)
+
+def resume_session():
+    session_param_output = os.path.join(session_output_dir, agent_name)
+    resume_session = False
+    resume_directory = None
+    if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
+        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
+        dir_content = os.listdir(session_param_output)
+        session_files = filter(regex.search, dir_content)
+        session_files.sort(reverse=True)
+        for d in session_files:
+            interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
+            if os.path.isdir(interrupted_session_dir):
+                resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
+                if resume == 'y':
+                    resume_session = True
+                    resume_directory = interrupted_session_dir
+                break
+    return resume_session, resume_directory
+
 if __name__ == "__main__":
     env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
     context = mx.cpu()
@@ -28,32 +55,42 @@ if __name__ == "__main__":
         'epsilon_decay': 0.01,
     }

-    agent = DqnAgent(
-        network = net_creator.net,
-        environment=env,
-        replay_memory_params=replay_memory_params,
-        policy_params=policy_params,
-        state_dim=net_creator.get_input_shapes()[0],
-        ctx='cpu',
-        discount_factor=0.999,
-        loss_function='euclidean',
-        optimizer='rmsprop',
-        optimizer_params={
-            'learning_rate': 0.001
-        },
-        training_episodes=160,
-        train_interval=1,
-        use_fix_target=True,
-        target_update_interval=200,
-        double_dqn = False,
-        snapshot_interval=20,
-        agent_name='cartpole_master_dqn',
-        max_episode_step=250,
-        output_directory='model',
-        verbose=True,
-        live_plot = True,
-        make_logfile=True,
-        target_score=185.5
-    )
-    train_successfull = agent.train()
-    agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
+    resume_session, resume_directory = resume_session()
+
+    if resume_session:
+        agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
+    else:
+        agent = DqnAgent(
+            network = net_creator.net,
+            environment=env,
+            replay_memory_params=replay_memory_params,
+            policy_params=policy_params,
+            state_dim=net_creator.get_input_shapes()[0],
+            ctx='cpu',
+            discount_factor=0.999,
+            loss_function='euclidean',
+            optimizer='rmsprop',
+            optimizer_params={
+                'learning_rate': 0.001},
+            training_episodes=160,
+            train_interval=1,
+            use_fix_target=True,
+            target_update_interval=200,
+            double_dqn = False,
+            snapshot_interval=20,
+            agent_name=agent_name,
+            max_episode_step=250,
+            output_directory=session_output_dir,
+            verbose=True,
+            live_plot = True,
+            make_logfile=True,
+            target_score=185.5
+        )
+
+    signal_handler = AgentSignalHandler()
+    signal_handler.register_agent(agent)
+
+    train_successful = agent.train()
+
+    if train_successful:
+        agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
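Side note on the new resume logic: resume_session() only recognizes session directories whose names match a year-first timestamp pattern, which is presumably why the agent's output-directory timestamp switches from '%d-%m-%Y-%H-%M-%S' to '%Y-%m-%d-%H-%M-%S' in the DqnAgent changes below. A tiny standalone illustration, not part of the generated trainer:

# Illustration only: the session-directory regex from resume_session()
# matches the new year-first timestamp format but not the old day-first one.
import re
import time

regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
now = time.localtime()

old_style = time.strftime('%d-%m-%Y-%H-%M-%S', now)   # e.g. '17-05-2019-10-30-00'
new_style = time.strftime('%Y-%m-%d-%H-%M-%S', now)   # e.g. '2019-05-17-10-30-00'

print(regex.search(old_style))  # None: old-style directory names are never offered for resuming
print(regex.search(new_style))  # match object: new-style directories are found by resume_session()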
@@ -100,7 +100,7 @@ class DqnAgent(object):
         # Prepare output directory and logger
         self.__output_directory = output_directory\
             + '/' + self.__agent_name\
-            + '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
+            + '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
         self.__logger = self.__setup_logging()
         self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))

@@ -113,9 +113,8 @@ class DqnAgent(object):
         return cls(network, environment, ctx=ctx, **config_dict)

     @classmethod
-    def resume_from_session(cls, session_dir, network_type):
+    def resume_from_session(cls, session_dir, net, environment):
         import pickle

-        session_dir = os.path.join(session_dir, '.interrupted_session')
         if not os.path.exists(session_dir):
             raise ValueError('Session directory does not exist')

@@ -132,13 +131,14 @@ class DqnAgent(object):
         with open(files['agent'], 'rb') as f:
             agent = pickle.load(f)

-        agent.__qnet = network_type()
+        agent.__environment = environment
+        agent.__qnet = net
         agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
         agent.__qnet.hybridize()
-        agent.__qnet(nd.ones((1,) + agent.__environment.state_dim))
-        agent.__best_net = network_type()
+        agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
+        agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
         agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
-        agent.__target_qnet = network_type()
+        agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
         agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)

         agent.__logger = agent.__setup_logging(append=True)

@@ -157,6 +157,8 @@ class DqnAgent(object):
         del self.__training_stats.logger
         logger = self.__logger
         self.__logger = None
+        self.__environment.close()
+        self.__environment = None

         self.__save_net(self.__qnet, 'qnet', session_dir)
         self.__qnet = None

@@ -169,7 +171,7 @@ class DqnAgent(object):
         with open(agent_session_file, 'wb') as f:
             pickle.dump(self, f)

+        self.__logger = logger
         logger.info('State successfully stored')

     @property

@@ -293,7 +295,7 @@ class DqnAgent(object):
         return loss

     def __do_snapshot_if_in_interval(self, episode):
-        do_snapshot = (episode % self.__snapshot_interval == 0)
+        do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
         if do_snapshot:
             self.save_parameters(episode=episode)
             self.__evaluate()

@@ -318,6 +320,7 @@ class DqnAgent(object):
         # Implementation Deep Q Learning described by Mnih et. al. in Playing Atari with Deep Reinforcement Learning
         while self.__current_episode < episodes:
+            # Check interrupt flag
             if self.__interrupt_flag:
                 self.__interrupt_flag = False
                 self.__interrupt_training()
...
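The resume path above now builds the best and target networks with copy_net(...) instead of instantiating network_type(). That helper is not shown in this diff; the sketch below is an assumption of what such a Gluon weight-copying helper could look like (the temp-file approach and the no-argument constructor are assumptions, not the commit's code):

# Hypothetical sketch of a copy_net helper as referenced by resume_from_session.
# Assumes the network class has a no-argument constructor, as the generated
# CNNCreator networks do. Not the actual code from this commit.
import os
from mxnet import nd, gluon

def copy_net(net, input_state_dim, ctx, tmp_filename='tmp_copy.params'):
    assert isinstance(net, gluon.HybridBlock)
    net.save_parameters(tmp_filename)                    # serialize the source weights
    net_copy = net.__class__()                           # fresh instance of the same network class
    net_copy.load_parameters(tmp_filename, ctx=ctx)      # load the weights on the target context
    os.remove(tmp_filename)
    net_copy.hybridize()
    net_copy(nd.ones((1,) + input_state_dim, ctx=ctx))   # one forward pass to build the cached graph
    return net_copy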
@@ -17,6 +17,10 @@ class Environment:
     def step(self, action):
         pass

+    @abc.abstractmethod
+    def close(self):
+        pass
+
 import gym

 class GymEnvironment(Environment):
     def __init__(self, env_name, **kwargs):
...
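The GymEnvironment body is collapsed here; its implementation of the newly added abstract close() would typically just delegate to the wrapped Gym environment. A rough sketch, not the commit's code:

# Assumption: GymEnvironment.close() simply closes the underlying gym env,
# releasing the renderer and other resources when training is interrupted.
import gym

class GymEnvironmentSketch(object):
    def __init__(self, env_name):
        self._env = gym.make(env_name)   # underlying OpenAI Gym environment

    def close(self):
        self._env.close()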
@@ -37,13 +37,19 @@ class AgentSignalHandler(object):
     def __init__(self):
         signal.signal(signal.SIGINT, self.interrupt_training)
         self.__agent = None
+        self.__times_interrupted = 0

     def register_agent(self, agent):
         self.__agent = agent

     def interrupt_training(self, sig, frame):
-        if self.__agent:
-            self.__agent.set_interrupt_flag(True)
+        self.__times_interrupted = self.__times_interrupted + 1
+        if self.__times_interrupted <= 3:
+            if self.__agent:
+                self.__agent.set_interrupt_flag(True)
+        else:
+            print('Interrupt called three times: Force quit')
+            sys.exit(1)

 style.use('fivethirtyeight')

 class TrainingStats(object):
...
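To see the new interrupt behaviour in isolation: the first few Ctrl+C presses only raise a flag so the training loop can store its state, while further presses force-quit the process. A self-contained sketch of the same pattern, with illustrative names rather than the generated code:

# Standalone sketch of the SIGINT pattern used by AgentSignalHandler above.
import signal
import sys
import time

class InterruptCounter(object):
    def __init__(self, max_interrupts=3):
        self.flag = False
        self._count = 0
        self._max = max_interrupts
        signal.signal(signal.SIGINT, self._handle)

    def _handle(self, sig, frame):
        self._count += 1
        if self._count <= self._max:
            self.flag = True                       # let the loop stop gracefully and save state
        else:
            print('Interrupt called repeatedly: force quit')
            sys.exit(1)

if __name__ == '__main__':
    counter = InterruptCounter()
    while not counter.flag:                        # stands in for the training loop
        time.sleep(0.1)
    print('Interrupted: training state would be stored here')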
 from reinforcement_learning.agent import DqnAgent
+from reinforcement_learning.util import AgentSignalHandler
 import reinforcement_learning.environment
 import CNNCreator_torcs_agent_torcsAgent_dqn
+import os
+import sys
+import re
 import logging
 import mxnet as mx
+
+session_output_dir = 'session'
+agent_name='torcs_agent_torcsAgent_dqn'
+session_param_output = os.path.join(session_output_dir, agent_name)
+
+def resume_session():
+    session_param_output = os.path.join(session_output_dir, agent_name)
+    resume_session = False
+    resume_directory = None
+    if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
+        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
+        dir_content = os.listdir(session_param_output)
+        session_files = filter(regex.search, dir_content)
+        session_files.sort(reverse=True)
+        for d in session_files:
+            interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
+            if os.path.isdir(interrupted_session_dir):
+                resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
+                if resume == 'y':
+                    resume_session = True
+                    resume_directory = interrupted_session_dir
+                break
+    return resume_session, resume_directory
+
 if __name__ == "__main__":
     env_params = {
         'ros_node_name' : 'torcs_agent_torcsAgent_dqnTrainerNode',
@@ -35,32 +62,42 @@ if __name__ == "__main__":
         'epsilon_decay': 0.0001,
     }

-    agent = DqnAgent(
-        network = net_creator.net,
-        environment=env,
-        replay_memory_params=replay_memory_params,
-        policy_params=policy_params,
-        state_dim=net_creator.get_input_shapes()[0],
-        ctx='cpu',
-        discount_factor=0.999,
-        loss_function='euclidean',
-        optimizer='rmsprop',
-        optimizer_params={
-            'learning_rate': 0.001
-        },
-        training_episodes=20000,
-        train_interval=1,
-        use_fix_target=True,
-        target_update_interval=500,
-        double_dqn = True,
-        snapshot_interval=1000,
-        agent_name='torcs_agent_torcsAgent_dqn',
-        max_episode_step=999999999,
-        output_directory='model',
-        verbose=True,
-        live_plot = True,
-        make_logfile=True,
-        target_score=None
-    )
-    train_successfull = agent.train()
-    agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
+    resume_session, resume_directory = resume_session()
+
+    if resume_session:
+        agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
+    else:
+        agent = DqnAgent(
+            network = net_creator.net,
+            environment=env,
+            replay_memory_params=replay_memory_params,
+            policy_params=policy_params,
+            state_dim=net_creator.get_input_shapes()[0],
+            ctx='cpu',
+            discount_factor=0.999,
+            loss_function='euclidean',
+            optimizer='rmsprop',
+            optimizer_params={
+                'learning_rate': 0.001},
+            training_episodes=20000,
+            train_interval=1,
+            use_fix_target=True,
+            target_update_interval=500,
+            double_dqn = True,
+            snapshot_interval=1000,
+            agent_name=agent_name,
+            max_episode_step=999999999,
+            output_directory=session_output_dir,
+            verbose=True,
+            live_plot = True,
+            make_logfile=True,
+            target_score=None
+        )
+
+    signal_handler = AgentSignalHandler()
+    signal_handler.register_agent(agent)
+
+    train_successful = agent.train()
+
+    if train_successful:
+        agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
@@ -100,7 +100,7 @@ class DqnAgent(object):
         # Prepare output directory and logger
         self.__output_directory = output_directory\
             + '/' + self.__agent_name\
-            + '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
+            + '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
         self.__logger = self.__setup_logging()
         self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))

@@ -113,9 +113,8 @@ class DqnAgent(object):
         return cls(network, environment, ctx=ctx, **config_dict)

     @classmethod
-    def resume_from_session(cls, session_dir, network_type):
+    def resume_from_session(cls, session_dir, net, environment):
         import pickle

-        session_dir = os.path.join(session_dir, '.interrupted_session')
         if not os.path.exists(session_dir):
             raise ValueError('Session directory does not exist')

@@ -132,13 +131,14 @@ class DqnAgent(object):
         with open(files['agent'], 'rb') as f:
             agent = pickle.load(f)

-        agent.__qnet = network_type()
+        agent.__environment = environment
+        agent.__qnet = net
         agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
         agent.__qnet.hybridize()
-        agent.__qnet(nd.ones((1,) + agent.__environment.state_dim))
-        agent.__best_net = network_type()
+        agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
+        agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
         agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
-        agent.__target_qnet = network_type()
+        agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
         agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)

         agent.__logger = agent.__setup_logging(append=True)

@@ -157,6 +157,8 @@ class DqnAgent(object):
         del self.__training_stats.logger
         logger = self.__logger
         self.__logger = None
+        self.__environment.close()
+        self.__environment = None

         self.__save_net(self.__qnet, 'qnet', session_dir)
         self.__qnet = None

@@ -169,7 +171,7 @@ class DqnAgent(object):
         with open(agent_session_file, 'wb') as f:
             pickle.dump(self, f)

+        self.__logger = logger
         logger.info('State successfully stored')

     @property

@@ -293,7 +295,7 @@ class DqnAgent(object):
         return loss

     def __do_snapshot_if_in_interval(self, episode):
-        do_snapshot = (episode % self.__snapshot_interval == 0)
+        do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
         if do_snapshot:
             self.save_parameters(episode=episode)
             self.__evaluate()

@@ -318,6 +320,7 @@ class DqnAgent(object):
         # Implementation Deep Q Learning described by Mnih et. al. in Playing Atari with Deep Reinforcement Learning
         while self.__current_episode < episodes:
+            # Check interrupt flag
             if self.__interrupt_flag:
                 self.__interrupt_flag = False
                 self.__interrupt_training()
...
@@ -32,6 +32,10 @@ class Environment:
     def step(self, action):
         pass

+    @abc.abstractmethod
+    def close(self):
+        pass
+
 import rospy
 import thread
 import numpy as np

@@ -83,7 +87,8 @@ class RosEnvironment(Environment):
         reset_message.data = True
         self.__waiting_for_state_update = True
         self.__reset_publisher.publish(reset_message)
-        self.__wait_for_new_state(self.__reset_publisher, reset_message)
+        while self.__last_received_terminal:
+            self.__wait_for_new_state(self.__reset_publisher, reset_message)
         return self.__last_received_state

     def step(self, action):

@@ -119,6 +124,9 @@ class RosEnvironment(Environment):
                 exit()
             time.sleep(100/1000)

+    def close(self):
+        rospy.signal_shutdown('Program ended!')
+
     def __state_callback(self, data):
         self.__last_received_state = np.array(data.data, dtype='double')
...
 # This file was automatically generated by SWIG (http://www.swig.org).
-# Version 3.0.12
+# Version 3.0.8
 #
 # Do not make changes to this file unless you know what you are doing--modify
 # the SWIG interface file instead.

-from sys import version_info as _swig_python_version_info
-if _swig_python_version_info >= (2, 7, 0):
-    def swig_import_helper():
-        import importlib
-        pkg = __name__.rpartition('.')[0]
-        mname = '.'.join((pkg, '_torcs_agent_dqn_reward_executor')).lstrip('.')
-        try:
-            return importlib.import_module(mname)
-        except ImportError:
-            return importlib.import_module('_torcs_agent_dqn_reward_executor')
-    _torcs_agent_dqn_reward_executor = swig_import_helper()
-    del swig_import_helper
-elif _swig_python_version_info >= (2, 6, 0):
-    def swig_import_helper():
+from sys import version_info
+if version_info >= (2, 6, 0):