Commit cdc08ba5 authored by Nicola Gatto

Merge branch 'integrate-reinforcement-gluon' into annotate-architecture

parents cbf77848 62786f93
from reinforcement_learning.agent import DqnAgent
from reinforcement_learning.util import AgentSignalHandler
import reinforcement_learning.environment
import CNNCreator_cartpole_master_dqn
import os
import sys
import re
import logging
import mxnet as mx
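# Note: this trainer script targets Python 2 (raw_input and the list-returning
# filter built-in are used below).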
session_output_dir = 'session'
agent_name='cartpole_master_dqn'
session_param_output = os.path.join(session_output_dir, agent_name)
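# Look for an interrupted training session under 'session/<agent_name>/<timestamp>/'
# and ask the user whether to resume from its '.interrupted_session' directory.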
def resume_session():
    session_param_output = os.path.join(session_output_dir, agent_name)
    resume_session = False
    resume_directory = None
    if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
        dir_content = os.listdir(session_param_output)
        session_files = filter(regex.search, dir_content)
        session_files.sort(reverse=True)
        for d in session_files:
            interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
            if os.path.isdir(interrupted_session_dir):
                resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
                if resume == 'y':
                    resume_session = True
                    resume_directory = interrupted_session_dir
                break
    return resume_session, resume_directory
if __name__ == "__main__":
    env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
    context = mx.cpu()
......@@ -28,32 +55,42 @@ if __name__ == "__main__":
        'epsilon_decay': 0.01,
    }
    agent = DqnAgent(
        network = net_creator.net,
        environment=env,
        replay_memory_params=replay_memory_params,
        policy_params=policy_params,
        state_dim=net_creator.get_input_shapes()[0],
        ctx='cpu',
        discount_factor=0.999,
        loss_function='euclidean',
        optimizer='rmsprop',
        optimizer_params={
            'learning_rate': 0.001
        },
        training_episodes=160,
        train_interval=1,
        use_fix_target=True,
        target_update_interval=200,
        double_dqn = False,
        snapshot_interval=20,
        agent_name='cartpole_master_dqn',
        max_episode_step=250,
        output_directory='model',
        verbose=True,
        live_plot = True,
        make_logfile=True,
        target_score=185.5
    )
    train_successfull = agent.train()
    agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
    resume_session, resume_directory = resume_session()
    if resume_session:
        agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
    else:
        agent = DqnAgent(
            network = net_creator.net,
            environment=env,
            replay_memory_params=replay_memory_params,
            policy_params=policy_params,
            state_dim=net_creator.get_input_shapes()[0],
            ctx='cpu',
            discount_factor=0.999,
            loss_function='euclidean',
            optimizer='rmsprop',
            optimizer_params={
                'learning_rate': 0.001 },
            training_episodes=160,
            train_interval=1,
            use_fix_target=True,
            target_update_interval=200,
            double_dqn = False,
            snapshot_interval=20,
            agent_name=agent_name,
            max_episode_step=250,
            output_directory=session_output_dir,
            verbose=True,
            live_plot = True,
            make_logfile=True,
            target_score=185.5
        )
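    # Ctrl+C is routed through AgentSignalHandler: the agent's interrupt flag is set
    # and the current training state is stored so it can be resumed later.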
    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)
    train_successful = agent.train()
    if train_successful:
        agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
......@@ -100,7 +100,7 @@ class DqnAgent(object):
# Prepare output directory and logger
self.__output_directory = output_directory\
+ '/' + self.__agent_name\
+ '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
+ '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
self.__logger = self.__setup_logging()
self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
......@@ -113,9 +113,8 @@ class DqnAgent(object):
return cls(network, environment, ctx=ctx, **config_dict)
@classmethod
def resume_from_session(cls, session_dir, network_type):
def resume_from_session(cls, session_dir, net, environment):
import pickle
session_dir = os.path.join(session_dir, '.interrupted_session')
if not os.path.exists(session_dir):
raise ValueError('Session directory does not exist')
......@@ -132,13 +131,14 @@ class DqnAgent(object):
        with open(files['agent'], 'rb') as f:
            agent = pickle.load(f)
        agent.__qnet = network_type()
        agent.__environment = environment
        agent.__qnet = net
        agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
        agent.__qnet.hybridize()
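        # One dummy forward pass (re)binds the loaded parameters to concrete shapes
        # before the network is copied for the best/target networks below.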
        agent.__qnet(nd.ones((1,) + agent.__environment.state_dim))
        agent.__best_net = network_type()
        agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
        agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
        agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
        agent.__target_qnet = network_type()
        agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
        agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
        agent.__logger = agent.__setup_logging(append=True)
......@@ -157,6 +157,8 @@ class DqnAgent(object):
del self.__training_stats.logger
logger = self.__logger
self.__logger = None
self.__environment.close()
self.__environment = None
self.__save_net(self.__qnet, 'qnet', session_dir)
self.__qnet = None
......@@ -169,7 +171,7 @@ class DqnAgent(object):
with open(agent_session_file, 'wb') as f:
pickle.dump(self, f)
self.__logger = logger
logger.info('State successfully stored')
@property
......@@ -293,7 +295,7 @@ class DqnAgent(object):
return loss
def __do_snapshot_if_in_interval(self, episode):
do_snapshot = (episode % self.__snapshot_interval == 0)
do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
if do_snapshot:
self.save_parameters(episode=episode)
self.__evaluate()
......@@ -318,6 +320,7 @@ class DqnAgent(object):
# Implementation of Deep Q-Learning as described by Mnih et al. in "Playing Atari with Deep Reinforcement Learning"
while self.__current_episode < episodes:
# Check interrupt flag
if self.__interrupt_flag:
self.__interrupt_flag = False
self.__interrupt_training()
......
......@@ -17,6 +17,10 @@ class Environment:
    def step(self, action):
        pass
    @abc.abstractmethod
    def close(self):
        pass
import gym
class GymEnvironment(Environment):
def __init__(self, env_name, **kwargs):
......
......@@ -37,13 +37,19 @@ class AgentSignalHandler(object):
    def __init__(self):
        signal.signal(signal.SIGINT, self.interrupt_training)
        self.__agent = None
        self.__times_interrupted = 0
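    # Ctrl+C (SIGINT) only sets the agent's interrupt flag so training can stop
    # cleanly; repeated interrupts beyond the counter limit force-quit the process.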
    def register_agent(self, agent):
        self.__agent = agent
    def interrupt_training(self, sig, frame):
        if self.__agent:
            self.__agent.set_interrupt_flag(True)
        self.__times_interrupted = self.__times_interrupted + 1
        if self.__times_interrupted <= 3:
            if self.__agent:
                self.__agent.set_interrupt_flag(True)
        else:
            print('Interrupt called three times: Force quit')
            sys.exit(1)
style.use('fivethirtyeight')
class TrainingStats(object):
......
from reinforcement_learning.agent import DqnAgent
from reinforcement_learning.util import AgentSignalHandler
import reinforcement_learning.environment
import CNNCreator_torcs_agent_torcsAgent_dqn
import os
import sys
import re
import logging
import mxnet as mx
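# Note: this trainer script targets Python 2 (raw_input and the list-returning
# filter built-in are used below).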
session_output_dir = 'session'
agent_name='torcs_agent_torcsAgent_dqn'
session_param_output = os.path.join(session_output_dir, agent_name)
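# Look for an interrupted training session under 'session/<agent_name>/<timestamp>/'
# and ask the user whether to resume from its '.interrupted_session' directory.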
def resume_session():
    session_param_output = os.path.join(session_output_dir, agent_name)
    resume_session = False
    resume_directory = None
    if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
        regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
        dir_content = os.listdir(session_param_output)
        session_files = filter(regex.search, dir_content)
        session_files.sort(reverse=True)
        for d in session_files:
            interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
            if os.path.isdir(interrupted_session_dir):
                resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
                if resume == 'y':
                    resume_session = True
                    resume_directory = interrupted_session_dir
                break
    return resume_session, resume_directory
if __name__ == "__main__":
    env_params = {
        'ros_node_name' : 'torcs_agent_torcsAgent_dqnTrainerNode',
......@@ -35,32 +62,42 @@ if __name__ == "__main__":
        'epsilon_decay': 0.0001,
    }
    agent = DqnAgent(
        network = net_creator.net,
        environment=env,
        replay_memory_params=replay_memory_params,
        policy_params=policy_params,
        state_dim=net_creator.get_input_shapes()[0],
        ctx='cpu',
        discount_factor=0.999,
        loss_function='euclidean',
        optimizer='rmsprop',
        optimizer_params={
            'learning_rate': 0.001
        },
        training_episodes=20000,
        train_interval=1,
        use_fix_target=True,
        target_update_interval=500,
        double_dqn = True,
        snapshot_interval=1000,
        agent_name='torcs_agent_torcsAgent_dqn',
        max_episode_step=999999999,
        output_directory='model',
        verbose=True,
        live_plot = True,
        make_logfile=True,
        target_score=None
    )
    train_successfull = agent.train()
    agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
    resume_session, resume_directory = resume_session()
    if resume_session:
        agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
    else:
        agent = DqnAgent(
            network = net_creator.net,
            environment=env,
            replay_memory_params=replay_memory_params,
            policy_params=policy_params,
            state_dim=net_creator.get_input_shapes()[0],
            ctx='cpu',
            discount_factor=0.999,
            loss_function='euclidean',
            optimizer='rmsprop',
            optimizer_params={
                'learning_rate': 0.001 },
            training_episodes=20000,
            train_interval=1,
            use_fix_target=True,
            target_update_interval=500,
            double_dqn = True,
            snapshot_interval=1000,
            agent_name=agent_name,
            max_episode_step=999999999,
            output_directory=session_output_dir,
            verbose=True,
            live_plot = True,
            make_logfile=True,
            target_score=None
        )
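    # Ctrl+C is routed through AgentSignalHandler: the agent's interrupt flag is set
    # and the current training state is stored so it can be resumed later.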
    signal_handler = AgentSignalHandler()
    signal_handler.register_agent(agent)
    train_successful = agent.train()
    if train_successful:
        agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
......@@ -100,7 +100,7 @@ class DqnAgent(object):
# Prepare output directory and logger
self.__output_directory = output_directory\
+ '/' + self.__agent_name\
+ '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
+ '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
self.__logger = self.__setup_logging()
self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
......@@ -113,9 +113,8 @@ class DqnAgent(object):
return cls(network, environment, ctx=ctx, **config_dict)
@classmethod
def resume_from_session(cls, session_dir, network_type):
def resume_from_session(cls, session_dir, net, environment):
import pickle
session_dir = os.path.join(session_dir, '.interrupted_session')
if not os.path.exists(session_dir):
raise ValueError('Session directory does not exist')
......@@ -132,13 +131,14 @@ class DqnAgent(object):
        with open(files['agent'], 'rb') as f:
            agent = pickle.load(f)
        agent.__qnet = network_type()
        agent.__environment = environment
        agent.__qnet = net
        agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
        agent.__qnet.hybridize()
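        # One dummy forward pass (re)binds the loaded parameters to concrete shapes
        # before the network is copied for the best/target networks below.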
        agent.__qnet(nd.ones((1,) + agent.__environment.state_dim))
        agent.__best_net = network_type()
        agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
        agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
        agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
        agent.__target_qnet = network_type()
        agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
        agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
        agent.__logger = agent.__setup_logging(append=True)
......@@ -157,6 +157,8 @@ class DqnAgent(object):
del self.__training_stats.logger
logger = self.__logger
self.__logger = None
self.__environment.close()
self.__environment = None
self.__save_net(self.__qnet, 'qnet', session_dir)
self.__qnet = None
......@@ -169,7 +171,7 @@ class DqnAgent(object):
with open(agent_session_file, 'wb') as f:
pickle.dump(self, f)
self.__logger = logger
logger.info('State successfully stored')
@property
......@@ -293,7 +295,7 @@ class DqnAgent(object):
return loss
def __do_snapshot_if_in_interval(self, episode):
do_snapshot = (episode % self.__snapshot_interval == 0)
do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
if do_snapshot:
self.save_parameters(episode=episode)
self.__evaluate()
......@@ -318,6 +320,7 @@ class DqnAgent(object):
# Implementation of Deep Q-Learning as described by Mnih et al. in "Playing Atari with Deep Reinforcement Learning"
while self.__current_episode < episodes:
# Check interrupt flag
if self.__interrupt_flag:
self.__interrupt_flag = False
self.__interrupt_training()
......
......@@ -32,6 +32,10 @@ class Environment:
    def step(self, action):
        pass
    @abc.abstractmethod
    def close(self):
        pass
import rospy
import thread
import numpy as np
......@@ -83,7 +87,8 @@ class RosEnvironment(Environment):
        reset_message.data = True
        self.__waiting_for_state_update = True
        self.__reset_publisher.publish(reset_message)
        self.__wait_for_new_state(self.__reset_publisher, reset_message)
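        # Wait until the environment reports a non-terminal state after the reset.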
        while self.__last_received_terminal:
            self.__wait_for_new_state(self.__reset_publisher, reset_message)
        return self.__last_received_state

    def step(self, action):
......@@ -119,6 +124,9 @@ class RosEnvironment(Environment):
exit()
time.sleep(100/1000)
    def close(self):
        rospy.signal_shutdown('Program ended!')

    def __state_callback(self, data):
        self.__last_received_state = np.array(data.data, dtype='double')
......
# This file was automatically generated by SWIG (http://www.swig.org).
# Version 3.0.12
# Version 3.0.8
#
# Do not make changes to this file unless you know what you are doing--modify
# the SWIG interface file instead.
from sys import version_info as _swig_python_version_info
if _swig_python_version_info >= (2, 7, 0):
def swig_import_helper():
import importlib
pkg = __name__.rpartition('.')[0]
mname = '.'.join((pkg, '_torcs_agent_dqn_reward_executor')).lstrip('.')
try:
return importlib.import_module(mname)
except ImportError:
return importlib.import_module('_torcs_agent_dqn_reward_executor')
_torcs_agent_dqn_reward_executor = swig_import_helper()
del swig_import_helper
elif _swig_python_version_info >= (2, 6, 0):
from sys import version_info
if version_info >= (2, 6, 0):
def swig_import_helper():
from os.path import dirname
import imp
......@@ -26,27 +19,22 @@ elif _swig_python_version_info >= (2, 6, 0):
except ImportError:
import _torcs_agent_dqn_reward_executor
return _torcs_agent_dqn_reward_executor
try:
_mod = imp.load_module('_torcs_agent_dqn_reward_executor', fp, pathname, description)
finally:
if fp is not None:
if fp is not None:
try:
_mod = imp.load_module('_torcs_agent_dqn_reward_executor', fp, pathname, description)
finally:
fp.close()
return _mod
return _mod
_torcs_agent_dqn_reward_executor = swig_import_helper()
del swig_import_helper
else:
import _torcs_agent_dqn_reward_executor
del _swig_python_version_info
del version_info
try:
_swig_property = property
except NameError:
pass # Python < 2.2 doesn't have 'property'.
try:
import builtins as __builtin__
except ImportError:
import __builtin__
def _swig_setattr_nondynamic(self, class_type, name, value, static=1):
if (name == "thisown"):
......@@ -71,30 +59,37 @@ def _swig_setattr(self, class_type, name, value):
return _swig_setattr_nondynamic(self, class_type, name, value, 0)
def _swig_getattr(self, class_type, name):
def _swig_getattr_nondynamic(self, class_type, name, static=1):
if (name == "thisown"):
return self.this.own()
method = class_type.__swig_getmethods__.get(name, None)
if method:
return method(self)
raise AttributeError("'%s' object has no attribute '%s'" % (class_type.__name__, name))
if (not static):
return object.__getattr__(self, name)
else:
raise AttributeError(name)
def _swig_getattr(self, class_type, name):
return _swig_getattr_nondynamic(self, class_type, name, 0)
def _swig_repr(self):
try:
strthis = "proxy of " + self.this.__repr__()
except __builtin__.Exception:
except Exception:
strthis = ""
return "<%s.%s; %s >" % (self.__class__.__module__, self.__class__.__name__, strthis,)
try:
_object = object
_newclass = 1
except __builtin__.Exception:
except AttributeError:
class _object:
pass
_newclass = 0
class torcs_agent_dqn_reward_input(_object):
__swig_setmethods__ = {}
__setattr__ = lambda self, name, value: _swig_setattr(self, torcs_agent_dqn_reward_input, name, value)
......@@ -114,7 +109,7 @@ class torcs_agent_dqn_reward_input(_object):
this = _torcs_agent_dqn_reward_executor.new_torcs_agent_dqn_reward_input()
try:
self.this.append(this)
except __builtin__.Exception:
except Exception:
self.this = this
__swig_destroy__ = _torcs_agent_dqn_reward_executor.delete_torcs_agent_dqn_reward_input
__del__ = lambda self: None
......@@ -136,7 +131,7 @@ class torcs_agent_dqn_reward_output(_object):
this = _torcs_agent_dqn_reward_executor.new_torcs_agent_dqn_reward_output()
try:
self.this.append(this)
except __builtin__.Exception:
except Exception:
self.this = this
__swig_destroy__ = _torcs_agent_dqn_reward_executor.delete_torcs_agent_dqn_reward_output
__del__ = lambda self: None
......@@ -160,7 +155,7 @@ class torcs_agent_dqn_reward_executor(_object):
this = _torcs_agent_dqn_reward_executor.new_torcs_agent_dqn_reward_executor()
try:
self.this.append(this)
except __builtin__.Exception:
except Exception:
self.this = this
__swig_destroy__ = _torcs_agent_dqn_reward_executor.delete_torcs_agent_dqn_reward_executor
__del__ = lambda self: None
......
......@@ -37,13 +37,19 @@ class AgentSignalHandler(object):
    def __init__(self):
        signal.signal(signal.SIGINT, self.interrupt_training)
        self.__agent = None
        self.__times_interrupted = 0
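    # Ctrl+C (SIGINT) only sets the agent's interrupt flag so training can stop
    # cleanly; repeated interrupts beyond the counter limit force-quit the process.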
    def register_agent(self, agent):
        self.__agent = agent
    def interrupt_training(self, sig, frame):
        if self.__agent:
            self.__agent.set_interrupt_flag(True)
        self.__times_interrupted = self.__times_interrupted + 1
        if self.__times_interrupted <= 3:
            if self.__agent:
                self.__agent.set_interrupt_flag(True)
        else:
            print('Interrupt called three times: Force quit')
            sys.exit(1)
style.use('fivethirtyeight')
class TrainingStats(object):
......
#!/bin/bash
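# Build the SWIG-based reward function executor and move the generated Python
# module into the reinforcement_learning package before starting training.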
cd reward/pylib
mkdir build
cd build
cmake ..
make
mv torcs_agent_dqn_reward_executor.py ../../../reinforcement_learning
mv _torcs_agent_dqn_reward_executor.so ../../../reinforcement_learning
cd ../../../
python CNNTrainer_torcs_agent_torcsAgent_dqn.py
\ No newline at end of file