# environment.py
# Last committed by Nicola Gatto.
import abc
import logging
# Configure the root logger at INFO level as a side effect of importing this
# module.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the classes below alongside the rospy log
# functions.
logger = logging.getLogger(__name__)

# A ``__metaclass__`` class attribute is honoured only by Python 2; Python 3
# silently ignores it, which would leave every @abc.abstractmethod below
# unenforced.  Creating the base class through a direct ABCMeta call gives the
# same result on both interpreter versions.
_AbstractEnvironmentBase = abc.ABCMeta('_AbstractEnvironmentBase', (object,), {})


class Environment(_AbstractEnvironmentBase):
    """Abstract interface of an environment an agent can interact with.

    Concrete subclasses must override ``reset``, ``step`` and ``close``;
    instantiating a class that leaves any of them abstract raises
    ``TypeError``.
    """

    def __init__(self):
        pass

    @abc.abstractmethod
    def reset(self):
        """Reset the environment; must be overridden by subclasses."""
        pass

    @abc.abstractmethod
    def step(self, action):
        """Apply ``action`` to the environment; must be overridden."""
        pass

    @abc.abstractmethod
    def close(self):
        """Release the environment's resources; must be overridden."""
        pass

import rospy
import thread
import numpy as np
import time
from std_msgs.msg import Float32MultiArray, Bool, Int32, MultiArrayDimension, Float32

class RosEnvironment(Environment):
    """Environment that talks to a remote simulation over ROS topics.

    Actions and reset requests are published; the resulting state, terminal
    flag and reward are received through subscriber callbacks, which store the
    latest value and clear the matching "waiting" flag.  ``reset`` and
    ``step`` block by polling those flags.
    """

    # Seconds to sleep between two polls of the "waiting" flags.  Written as a
    # float on purpose: the former ``100/1000`` is integer division on
    # Python 2 and evaluates to 0, which turned the wait loop into a busy spin.
    _POLL_INTERVAL_IN_S = 0.1

    def __init__(self,
                 ros_node_name='RosTrainingAgent',
                 timeout_in_s=3,
                 state_topic='state',
                 action_topic='action',
                 reset_topic='reset',
                 terminal_state_topic='terminal',
                 reward_topic='reward'):
        """Create publishers/subscribers, start the ROS node and spin it.

        Args:
            ros_node_name: Name passed to ``rospy.init_node`` (registered with
                ``anonymous=True``).
            timeout_in_s: Seconds to wait for the pending updates before the
                last message is published again.
            state_topic: Topic delivering ``Float32MultiArray`` states.
            action_topic: Topic on which ``Int32`` actions are published.
            reset_topic: Topic on which ``Bool`` reset requests are published.
            terminal_state_topic: Topic delivering the ``Bool`` terminal flag.
            reward_topic: Topic delivering the ``Float32`` reward.
        """
        super(RosEnvironment, self).__init__()
        self.__timeout_in_s = timeout_in_s

        # One "waiting" flag per subscribed topic; set by reset()/step() and
        # cleared by the corresponding callback.
        self.__waiting_for_state_update = False
        self.__waiting_for_terminal_update = False
        self.__waiting_for_reward_update = False

        # Latest values received from the callbacks.  The terminal flag starts
        # as True, so the first reset() waits until a False flag is received.
        self.__last_received_state = 0
        self.__last_received_terminal = True
        self.__last_received_reward = 0.0

        rospy.loginfo("Initialize node {0}".format(ros_node_name))

        self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
        rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))

        self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
        rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))

        rospy.init_node(ros_node_name, anonymous=True)

        self.__state_subscriber = rospy.Subscriber(
            state_topic, Float32MultiArray, self.__state_callback)
        rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))

        self.__terminal_state_subscriber = rospy.Subscriber(
            terminal_state_topic, Bool, self.__terminal_state_callback)
        rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))

        self.__reward_subscriber = rospy.Subscriber(
            reward_topic, Float32, self.__reward_callback)
        rospy.loginfo('Reward Subscriber registered with topic {}'.format(reward_topic))

        # Run rospy.spin in a background thread so the constructor returns,
        # then pause for two seconds before handing control back to the caller.
        thread.start_new_thread(rospy.spin, ())
        time.sleep(2)

    def reset(self):
        """Request a reset and block until the terminal flag is False.

        Returns:
            The most recently received state (a float32 numpy array once a
            state message has arrived).
        """
        time.sleep(0.5)
        reset_message = Bool()
        reset_message.data = True
        self.__waiting_for_state_update = True
        self.__reset_publisher.publish(reset_message)
        while self.__last_received_terminal:
            self.__wait_for_new_state(self.__reset_publisher, reset_message)
            if self.__last_received_terminal:
                # No update is pending any more, so the wait above returns
                # immediately; pause between polls instead of spinning until
                # the terminal callback reports False.
                time.sleep(self._POLL_INTERVAL_IN_S)
        return self.__last_received_state

    def step(self, action):
        """Publish ``action`` and block until state, terminal and reward arrive.

        Args:
            action: Value stored in the ``data`` field of the ``Int32``
                message that is published.

        Returns:
            Tuple ``(next_state, reward, terminal, 0)``; the last element is
            always the constant 0.
        """
        action_rospy = Int32()
        action_rospy.data = action

        logger.debug('Send action: {}'.format(action))

        # Mark all three updates as pending *before* publishing so that a fast
        # reply cannot be missed.
        self.__waiting_for_state_update = True
        self.__waiting_for_terminal_update = True
        self.__waiting_for_reward_update = True
        self.__step_publisher.publish(action_rospy)
        self.__wait_for_new_state(self.__step_publisher, action_rospy)
        next_state = self.__last_received_state
        terminal = self.__last_received_terminal
        reward = self.__last_received_reward
        rospy.logdebug('Calculated reward: {}'.format(reward))

        return next_state, reward, terminal, 0

    def __wait_for_new_state(self, publisher, msg):
        """Poll until no state/terminal/reward update is pending.

        Each time ``timeout_in_s`` seconds pass without all flags being
        cleared, ``msg`` is published again on ``publisher``, at most three
        times.  On the next timeout the process is terminated via ``exit()``
        (raises ``SystemExit``).
        """
        time_of_timeout = time.time() + self.__timeout_in_s
        timeout_counter = 0
        while (self.__waiting_for_state_update
               or self.__waiting_for_terminal_update
               or self.__waiting_for_reward_update):
            if time.time() > time_of_timeout:
                if timeout_counter < 3:
                    rospy.logwarn("Timeout occured: Retry message")
                    publisher.publish(msg)
                    timeout_counter += 1
                    time_of_timeout = time.time() + self.__timeout_in_s
                else:
                    rospy.logerr("Timeout 3 times in a row: Terminate application")
                    exit()
            time.sleep(self._POLL_INTERVAL_IN_S)

    def close(self):
        """Shut down the ROS node."""
        rospy.signal_shutdown('Program ended!')

    def __state_callback(self, data):
        """Store the received state as a float32 array and clear its flag."""
        self.__last_received_state = np.array(data.data, dtype='float32')
        rospy.logdebug('Received state: {}'.format(self.__last_received_state))
        self.__waiting_for_state_update = False

    def __terminal_state_callback(self, data):
        """Store the received terminal flag and clear its waiting flag."""
        self.__last_received_terminal = data.data
        rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
        logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
        self.__waiting_for_terminal_update = False

    def __reward_callback(self, data):
        """Store the received reward as a float and clear its waiting flag."""
        self.__last_received_reward = float(data.data)
        logger.debug('Received reward: {}'.format(self.__last_received_reward))
        self.__waiting_for_reward_update = False