Commit 02e3446f authored by Nicola Gatto

Add test with reward topic

parent 776dffd5
Pipeline #147044 passed with stages in 3 minutes and 9 seconds
@@ -222,7 +222,6 @@ public class GenerationTest extends AbstractSymtabTest {
         trainGenerator.generate(modelPath, "ReinforcementConfig3", trainedArchitecture);
-        /*
         assertTrue(Log.getFindings().isEmpty());
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch"),
@@ -238,7 +237,7 @@ public class GenerationTest extends AbstractSymtabTest {
                         "reinforcement_learning/util.py",
                         "reinforcement_learning/cnnarch_logger.py"
                 )
-        );*/
+        );
     }
...
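
# File: generated trainer script for ReinforcementConfig3 (exact filename not shown on this page; presumably CNNTrainer_reinforcementConfig3.py)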
from reinforcement_learning.agent import DqnAgent
from reinforcement_learning.util import AgentSignalHandler
from reinforcement_learning.cnnarch_logger import ArchLogger
import reinforcement_learning.environment
import CNNCreator_reinforcementConfig3
import os
import sys
import re
import time
import numpy as np
import mxnet as mx
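# Scan the sessions directory for an '.interrupted_session' folder from a previous run
# and ask the user whether training should resume from it.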
def resume_session(sessions_dir):
resume_session = False
resume_directory = None
if os.path.isdir(sessions_dir):
regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
dir_content = os.listdir(sessions_dir)
session_files = filter(regex.search, dir_content)
session_files.sort(reverse=True)
for d in session_files:
interrupted_session_dir = os.path.join(sessions_dir, d, '.interrupted_session')
if os.path.isdir(interrupted_session_dir):
resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
if resume == 'y':
resume_session = True
resume_directory = interrupted_session_dir
break
return resume_session, resume_directory
if __name__ == "__main__":
agent_name = 'reinforcement_agent'
# Prepare output directory and logger
all_output_dir = os.path.join('model', agent_name)
output_directory = os.path.join(
all_output_dir,
time.strftime('%Y-%m-%d-%H-%M-%S',
time.localtime(time.time())))
ArchLogger.set_output_directory(output_directory)
ArchLogger.set_logger_name(agent_name)
ArchLogger.set_output_level(ArchLogger.INFO)
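    # ROS topics used by the environment; this configuration also passes an explicit
    # reward topic (the addition exercised by this test).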
env_params = {
'ros_node_name': 'reinforcementConfig3TrainerNode',
'state_topic': '/environment/state',
'action_topic': '/environment/action',
'reset_topic': '/environment/reset',
'reward_topic': '/environment/reward',
}
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu()
qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3()
qnet_creator.construct(context)
agent_params = {
'environment': env,
'replay_memory_params': {
'method': 'buffer',
'memory_size': 1000000,
'sample_size': 64,
'state_dtype': 'float32',
'action_dtype': 'float32',
'rewards_dtype': 'float32'
},
'strategy_params': {
'method':'epsgreedy',
'epsilon': 1,
'min_epsilon': 0.02,
'epsilon_decay_method': 'linear',
'epsilon_decay': 0.0001,
},
'agent_name': agent_name,
'verbose': True,
'output_directory': output_directory,
'state_dim': (8,),
'action_dim': (3,),
'discount_factor': 0.99999,
'training_episodes': 1000,
'train_interval': 1,
'snapshot_interval': 500,
'max_episode_step': 10000,
'target_score': 35000,
'qnet':qnet_creator.net,
'use_fix_target': True,
'target_update_interval': 500,
'loss_function': 'huber_loss',
'optimizer': 'adam',
'optimizer_params': {
'learning_rate': 0.001 },
'double_dqn': True,
}
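    # Resume an interrupted session if one exists and the user confirms; otherwise create a fresh agent.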
resume, resume_directory = resume_session(all_output_dir)
if resume:
output_directory, _ = os.path.split(resume_directory)
ArchLogger.set_output_directory(output_directory)
resume_agent_params = {
'session_dir': resume_directory,
'environment': env,
'net': qnet_creator.net,
}
agent = DqnAgent.resume_from_session(**resume_agent_params)
else:
agent = DqnAgent(**agent_params)
signal_handler = AgentSignalHandler()
signal_handler.register_agent(agent)
train_successful = agent.train()
if train_successful:
agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_newest', epoch=0)
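
# File: reinforcement_learning/cnnarch_logger.py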
import logging
import sys
import os
import util
class ArchLogger(object):
_logger = None
__output_level = logging.INFO
__logger_name = 'agent'
__output_directory = '.'
__append = True
__logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
__dateformat = '%d-%b-%y %H:%M:%S'
INFO = logging.INFO
DEBUG = logging.DEBUG
@staticmethod
def set_output_level(output_level):
assert output_level is not None
ArchLogger.__output_level = output_level
@staticmethod
def set_logger_name(logger_name):
assert logger_name is not None
ArchLogger.__logger_name = logger_name
@staticmethod
def set_output_directory(output_directory):
assert output_directory is not None
ArchLogger.__output_directory = output_directory
@staticmethod
def set_append(append):
assert append is not None
ArchLogger.__append = append
@staticmethod
def set_log_format(logformat, dateformat):
assert logformat is not None
assert dateformat is not None
ArchLogger.__logformat = logformat
ArchLogger.__dateformat = dateformat
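    # Create the singleton logger: it always logs to stdout and, if make_log_file is True,
    # also to '<output_directory>/<logger_name>.log'.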
@staticmethod
def init_logger(make_log_file=True):
assert ArchLogger._logger is None, 'Logger init already called'
filemode = 'a' if ArchLogger.__append else 'w'
formatter = logging.Formatter(
fmt=ArchLogger.__logformat, datefmt=ArchLogger.__dateformat)
logger = logging.getLogger(ArchLogger.__logger_name)
logger.propagate = False
if not logger.handlers:
logger.setLevel(ArchLogger.__output_level)
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(ArchLogger.__output_level)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if make_log_file:
util.make_directory_if_not_exist(ArchLogger.__output_directory)
log_file = os.path.join(
ArchLogger.__output_directory,
ArchLogger.__logger_name + '.log')
file_handler = logging.FileHandler(log_file, mode=filemode)
file_handler.setLevel(ArchLogger.__output_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
ArchLogger._logger = logger
@staticmethod
def get_logger():
if ArchLogger._logger is None:
ArchLogger.init_logger()
assert ArchLogger._logger is not None
return ArchLogger._logger
if __name__ == "__main__":
print('=== Test logger ===')
ArchLogger.set_logger_name('TestLogger')
ArchLogger.set_output_directory('test_log')
ArchLogger.init_logger()
logger = ArchLogger.get_logger()
logger.warning('This is a warning')
logger.debug('This is a debug information, which you should not see')
logger.info('This is a normal information')
assert os.path.exists('test_log')\
and os.path.isfile(os.path.join('test_log', 'TestLogger.log')),\
'Test failed: No logfile exists'
import shutil
shutil.rmtree('test_log')
\ No newline at end of file
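
# File: reinforcement_learning/environment.py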
import abc
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Environment:
__metaclass__ = abc.ABCMeta
def __init__(self):
pass
@abc.abstractmethod
def reset(self):
pass
@abc.abstractmethod
def step(self, action):
pass
@abc.abstractmethod
def close(self):
pass
import rospy
import thread
import numpy as np
import time
from std_msgs.msg import Float32MultiArray, Bool, Int32, MultiArrayDimension, Float32
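# ROS-backed environment: publishes actions and resets, and subscribes to the state,
# terminal-state and reward topics configured by the trainer script.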
class RosEnvironment(Environment):
def __init__(self,
ros_node_name='RosTrainingAgent',
timeout_in_s=3,
state_topic='state',
action_topic='action',
reset_topic='reset',
terminal_state_topic='terminal',
reward_topic='reward'):
super(RosEnvironment, self).__init__()
self.__timeout_in_s = timeout_in_s
self.__waiting_for_state_update = False
self.__waiting_for_terminal_update = False
self.__last_received_state = 0
self.__last_received_terminal = True
self.__last_received_reward = 0.0
self.__waiting_for_reward_update = False
rospy.loginfo("Initialize node {0}".format(ros_node_name))
self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))
self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))
rospy.init_node(ros_node_name, anonymous=True)
self.__state_subscriber = rospy.Subscriber(state_topic, Float32MultiArray, self.__state_callback)
rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))
self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))
self.__reward_subscriber = rospy.Subscriber(reward_topic, Float32, self.__reward_callback)
rospy.loginfo('Reward Subscriber registered with topic {}'.format(reward_topic))
rate = rospy.Rate(10)
thread.start_new_thread(rospy.spin, ())
time.sleep(2)
def reset(self):
time.sleep(0.5)
reset_message = Bool()
reset_message.data = True
self.__waiting_for_state_update = True
self.__reset_publisher.publish(reset_message)
while self.__last_received_terminal:
self.__wait_for_new_state(self.__reset_publisher, reset_message)
return self.__last_received_state
def step(self, action):
action_rospy = Int32()
action_rospy.data = action
logger.debug('Send action: {}'.format(action))
self.__waiting_for_state_update = True
self.__waiting_for_terminal_update = True
self.__waiting_for_reward_update = True
self.__step_publisher.publish(action_rospy)
self.__wait_for_new_state(self.__step_publisher, action_rospy)
next_state = self.__last_received_state
terminal = self.__last_received_terminal
reward = self.__last_received_reward
rospy.logdebug('Calculated reward: {}'.format(reward))
return next_state, reward, terminal, 0
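    # Block until the state, terminal and reward callbacks have all fired; on timeout the last
    # message is re-published, and after three consecutive timeouts the application terminates.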
def __wait_for_new_state(self, publisher, msg):
time_of_timeout = time.time() + self.__timeout_in_s
timeout_counter = 0
while(self.__waiting_for_state_update
or self.__waiting_for_terminal_update or self.__waiting_for_reward_update):
is_timeout = (time.time() > time_of_timeout)
if (is_timeout):
if timeout_counter < 3:
rospy.logwarn("Timeout occured: Retry message")
publisher.publish(msg)
timeout_counter += 1
time_of_timeout = time.time() + self.__timeout_in_s
else:
rospy.logerr("Timeout 3 times in a row: Terminate application")
exit()
            time.sleep(0.1)  # 100/1000 is integer division (0) under Python 2 and would busy-wait
def close(self):
rospy.signal_shutdown('Program ended!')
def __state_callback(self, data):
self.__last_received_state = np.array(data.data, dtype='float32')
rospy.logdebug('Received state: {}'.format(self.__last_received_state))
self.__waiting_for_state_update = False
def __terminal_state_callback(self, data):
self.__last_received_terminal = data.data
rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
self.__waiting_for_terminal_update = False
def __reward_callback(self, data):
self.__last_received_reward = float(data.data)
logger.debug('Received reward: {}'.format(self.__last_received_reward))
self.__waiting_for_reward_update = False
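
# File: replay memory module (filename not shown on this page; presumably reinforcement_learning/replay_memory.py)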
import numpy as np
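# Builds an 'online', 'buffer' (buffered) or 'combined' replay memory from configuration parameters.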
class ReplayMemoryBuilder(object):
def __init__(self):
self.__supported_methods = ['online', 'buffer', 'combined']
def build_by_params(
self,
state_dim,
method='online',
state_dtype='float32',
action_dim=(1,),
action_dtype='uint8',
rewards_dtype='float32',
memory_size=1000,
sample_size=32
):
assert state_dim is not None
assert action_dim is not None
assert method in self.__supported_methods
if method == 'online':
return self.build_online_memory(
state_dim=state_dim, state_dtype=state_dtype,
action_dtype=action_dtype, action_dim=action_dim,
rewards_dtype=rewards_dtype)
else:
assert memory_size is not None and memory_size > 0
assert sample_size is not None and sample_size > 0
if method == 'buffer':
return self.build_buffered_memory(
state_dim=state_dim, sample_size=sample_size,
memory_size=memory_size, state_dtype=state_dtype,
action_dim=action_dim, action_dtype=action_dtype,
rewards_dtype=rewards_dtype)
else:
return self.build_combined_memory(
state_dim=state_dim, sample_size=sample_size,
memory_size=memory_size, state_dtype=state_dtype,
action_dim=action_dim, action_dtype=action_dtype,
rewards_dtype=rewards_dtype)
def build_buffered_memory(
self, state_dim, memory_size, sample_size, state_dtype, action_dim,
action_dtype, rewards_dtype
):
assert memory_size > 0
assert sample_size > 0
return ReplayMemory(
state_dim, size=memory_size, sample_size=sample_size,
state_dtype=state_dtype, action_dim=action_dim,
action_dtype=action_dtype, rewards_dtype=rewards_dtype)
def build_combined_memory(
self, state_dim, memory_size, sample_size, state_dtype, action_dim,
action_dtype, rewards_dtype
):
assert memory_size > 0
assert sample_size > 0
return CombinedReplayMemory(
state_dim, size=memory_size, sample_size=sample_size,
state_dtype=state_dtype, action_dim=action_dim,
action_dtype=action_dtype, rewards_dtype=rewards_dtype)
def build_online_memory(
self, state_dim, state_dtype, action_dtype, action_dim, rewards_dtype
):
return OnlineReplayMemory(
state_dim, state_dtype=state_dtype, action_dim=action_dim,
action_dtype=action_dtype, rewards_dtype=rewards_dtype)
class ReplayMemory(object):
def __init__(
self,
state_dim,
sample_size,
size=1000,
action_dim=(1,),
state_dtype='float32',
action_dtype='uint8',
rewards_dtype='float32'
):
assert size > 0, "Size must be greater than zero"
assert type(state_dim) is tuple, "State dimension must be a tuple"
assert type(action_dim) is tuple, "Action dimension must be a tuple"
assert sample_size > 0
self._size = size
self._sample_size = sample_size
self._cur_size = 0
self._pointer = 0
self._state_dim = state_dim
self._state_dtype = state_dtype
self._action_dim = action_dim
self._action_dtype = action_dtype
self._rewards_dtype = rewards_dtype
self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
self._actions = np.zeros(
(self._size,) + action_dim, dtype=action_dtype)
self._rewards = np.array([0] * self._size, dtype=rewards_dtype)
self._next_states = np.zeros(
(self._size,) + state_dim, dtype=state_dtype)
self._terminals = np.array([0] * self._size, dtype='bool')
@property
def sample_size(self):
return self._sample_size
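    # Store one transition; the write pointer wraps around so the oldest entries are
    # overwritten once the memory is full.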
def append(self, state, action, reward, next_state, terminal):
self._states[self._pointer] = state
self._actions[self._pointer] = action
self._rewards[self._pointer] = reward
self._next_states[self._pointer] = next_state
self._terminals[self._pointer] = terminal
self._pointer = self._pointer + 1
if self._pointer == self._size:
self._pointer = 0
self._cur_size = min(self._size, self._cur_size + 1)
def at(self, index):
return self._states[index],\
self._actions[index],\
self._rewards[index],\
self._next_states[index],\
self._terminals[index]
def is_sample_possible(self, batch_size=None):
batch_size = batch_size if batch_size is not None\
else self._sample_size
return self._cur_size >= batch_size
def sample(self, batch_size=None):
batch_size = batch_size if batch_size is not None\
else self._sample_size
assert self._cur_size >= batch_size,\
"Size of replay memory must be larger than batch size"
i = 0
states = np.zeros((
batch_size,)+self._state_dim, dtype=self._state_dtype)
actions = np.zeros(
(batch_size,)+self._action_dim, dtype=self._action_dtype)
rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
next_states = np.zeros(
(batch_size,)+self._state_dim, dtype=self._state_dtype)
terminals = np.zeros(batch_size, dtype='bool')
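        # Draw batch_size transitions uniformly at random (with replacement) from the filled part of the memory.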
while i < batch_size:
rnd_index = np.random.randint(low=0, high=self._cur_size)
states[i] = self._states.take(rnd_index, axis=0)
actions[i] = self._actions.take(rnd_index, axis=0)
rewards[i] = self._rewards.take(rnd_index, axis=0)
next_states[i] = self._next_states.take(rnd_index, axis=0)
terminals[i] = self._terminals.take(rnd_index, axis=0)
i += 1
return states, actions, rewards, next_states, terminals
class OnlineReplayMemory(ReplayMemory):
def __init__(
self, state_dim, state_dtype='float32', action_dim=(1,),
action_dtype='uint8', rewards_dtype='float32'
):
super(OnlineReplayMemory, self).__init__(
state_dim, sample_size=1, size=1, state_dtype=state_dtype,
action_dim=action_dim, action_dtype=action_dtype,
rewards_dtype=rewards_dtype)
class CombinedReplayMemory(ReplayMemory):
def __init__(
self, state_dim, sample_size, size=1000, state_dtype='float32',
action_dim=(1,), action_dtype='uint8', rewards_dtype='float32'
):