environment.py 5.28 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
import abc
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import reward_rewardFunction_executor

class RewardFunction(object):
    """Python front-end for the generated C++ reward executor.

    A single executor instance is created and initialised per object and is
    reused for every ``reward`` call.
    """

    def __init__(self):
        executor = reward_rewardFunction_executor.reward_rewardFunction_executor()
        self.__reward_wrapper = executor
        executor.init()

    def reward(self, state, terminal):
        """Evaluate the reward for one transition.

        :param state: array-like exposing ``astype`` (numpy array expected —
            TODO confirm); it is converted to ``double`` before the call.
        :param terminal: truthy when the transition ended the episode.
        :return: the ``reward`` field of the executor's output.
        """
        converted_state = state.astype('double')
        is_terminal = bool(terminal)

        request = reward_rewardFunction_executor.reward_rewardFunction_input()
        request.state = converted_state
        request.isTerminal = is_terminal

        result = self.__reward_wrapper.execute(request)
        return result.reward



class Environment:
    """Abstract interface that every training environment implements.

    Subclasses must provide ``reset``, ``step`` and ``close``. The
    constructor attaches a fresh ``RewardFunction`` as ``_reward_function``.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        self._reward_function = RewardFunction()

    @abc.abstractmethod
    def reset(self):
        """Start a new episode; concrete subclasses return the first state."""

    @abc.abstractmethod
    def step(self, action):
        """Apply ``action``; concrete subclasses return the transition."""

    @abc.abstractmethod
    def close(self):
        """Release whatever resources the concrete environment holds."""

import rospy
import thread
import numpy as np
import time
Nicola Gatto's avatar
Nicola Gatto committed
45
from std_msgs.msg import Float32MultiArray, Bool, Int32, MultiArrayDimension, Float32
46 47 48 49

class RosEnvironment(Environment):
    """Environment that talks to an external simulator over ROS topics.

    Actions are published as ``Int32`` and reset requests as ``Bool``.
    States arrive as ``Float32MultiArray`` and terminal flags as ``Bool``.
    Rewards are computed locally by the inherited reward function.
    """

    # Seconds to sleep between checks of the flags cleared by the callbacks.
    # This is a float literal on purpose: under Python 2, ``100/1000`` is 0.
    _POLL_INTERVAL_IN_S = 0.1

    def __init__(self,
        ros_node_name='RosTrainingAgent',
        timeout_in_s=60,
        state_topic='state',
        action_topic='action',
        reset_topic='reset',
        terminal_state_topic='terminal',
        reward_topic='reward',
        state_shape=(8,)):
        """Initialise the ROS node, publishers and subscribers.

        :param ros_node_name: base name of the ROS node (registered anonymously).
        :param timeout_in_s: seconds to wait for a reply before re-publishing.
        :param state_topic: topic delivering ``Float32MultiArray`` states.
        :param action_topic: topic on which ``Int32`` actions are published.
        :param reset_topic: topic on which ``Bool`` reset requests are published.
        :param terminal_state_topic: topic delivering ``Bool`` terminal flags.
        :param reward_topic: accepted for interface compatibility; not used here.
        :param state_shape: shape every received state is reshaped to
            (defaults to the previously hard-coded ``(8,)``).
        """
        super(RosEnvironment, self).__init__()
        self.__timeout_in_s = timeout_in_s
        self.__state_shape = state_shape
        self.__in_reset = False
        self.__waiting_for_state_update = False
        self.__waiting_for_terminal_update = False
        self.__last_received_state = 0
        self.__last_received_terminal = True

        rospy.loginfo("Initialize node {0}".format(ros_node_name))
        # Initialise the node before creating publishers/subscribers, which is
        # the order rospy documents.
        rospy.init_node(ros_node_name, anonymous=True)

        self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
        rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))

        self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
        rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))

        self.__state_subscriber = rospy.Subscriber(state_topic, Float32MultiArray, self.__state_callback)
        rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))

        self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
        rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))

        # Run rospy.spin() in a background thread so the constructor returns.
        thread.start_new_thread(rospy.spin, ())
        # Presumably gives topic connections time to establish — confirm.
        time.sleep(2)

    def reset(self):
        """Request a new episode and return its first state.

        Publishes a reset request and waits (with retries) for a fresh state.
        It then waits until the reported terminal flag is cleared.
        """
        self.__in_reset = True
        time.sleep(0.5)
        reset_message = Bool()
        reset_message.data = True
        self.__waiting_for_state_update = True
        self.__waiting_for_terminal_update = False
        self.__reset_publisher.publish(reset_message)
        self.__wait_for_new_state(self.__reset_publisher, reset_message)
        # Sleep between checks instead of busy-spinning a CPU core while the
        # terminal callback thread updates the flag.
        while self.__last_received_terminal:
            time.sleep(self._POLL_INTERVAL_IN_S)
        self.__in_reset = False
        return self.__last_received_state

    def step(self, action):
        """Publish ``action`` and return ``(next_state, reward, terminal, 0)``."""
        action_rospy = Int32()
        action_rospy.data = action

        logger.debug('Send action: {}'.format(action))

        # Set both flags before publishing so a fast reply is not missed.
        self.__waiting_for_state_update = True
        self.__waiting_for_terminal_update = True
        self.__step_publisher.publish(action_rospy)
        self.__wait_for_new_state(self.__step_publisher, action_rospy)
        next_state = self.__last_received_state
        terminal = self.__last_received_terminal
        reward = self.__calc_reward(next_state, terminal)
        rospy.logdebug('Calculated reward: {}'.format(reward))

        return next_state, reward, terminal, 0

    def __wait_for_new_state(self, publisher, msg):
        """Block until the pending state/terminal updates have arrived.

        ``msg`` is re-published on ``publisher`` each time ``timeout_in_s``
        elapses without a reply. Once three retries have been used up, the
        next timeout exits the process.
        """
        time_of_timeout = time.time() + self.__timeout_in_s
        timeout_counter = 0
        while (self.__waiting_for_state_update
               or self.__waiting_for_terminal_update):
            if time.time() > time_of_timeout:
                if timeout_counter < 3:
                    rospy.logwarn("Timeout occured: Retry message")
                    publisher.publish(msg)
                    timeout_counter += 1
                    time_of_timeout = time.time() + self.__timeout_in_s
                else:
                    rospy.logerr("Timeout 3 times in a row: Terminate application")
                    # sys.exit() does not depend on the ``site`` module, unlike exit().
                    import sys
                    sys.exit()
            time.sleep(self._POLL_INTERVAL_IN_S)

    def close(self):
        """Shut down the ROS node."""
        rospy.signal_shutdown('Program ended!')

    def __state_callback(self, data):
        """Store an incoming state message and clear the pending-state flag."""
        self.__last_received_state = np.array(data.data, dtype='float32').reshape(self.__state_shape)
        rospy.logdebug('Received state: {}'.format(self.__last_received_state))
        self.__waiting_for_state_update = False

    def __terminal_state_callback(self, data):
        """Store an incoming terminal flag and clear the pending-terminal flag."""
        self.__last_received_terminal = data.data
        rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
        logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
        self.__waiting_for_terminal_update = False

    def __calc_reward(self, state, terminal):
        """Compute the reward via the C++-backed reward function."""
        return self._reward_function.reward(state, terminal)