environment.py 5.09 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import abc
import logging
# Configure the root logger at INFO level as soon as this module is imported.
# This is a process-wide side effect; per the logging docs it does nothing if
# the root logger already has handlers attached.
logging.basicConfig(level=logging.INFO)
# Module-level logger used for the debug output emitted by this module.
logger = logging.getLogger(name=__name__)

class Environment(object):
    """Abstract interface of an environment an agent interacts with.

    Concrete environments must implement ``reset``, ``step`` and ``close``.
    """

    # Python 2 spelling of ``metaclass=abc.ABCMeta``: under Python 2 this is
    # what makes the ``@abc.abstractmethod`` declarations below enforceable.
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        """Initialise the environment; the base class keeps no state."""

    @abc.abstractmethod
    def reset(self):
        """Start a new episode (to be implemented by subclasses)."""

    @abc.abstractmethod
    def step(self, action):
        """Apply ``action`` to the environment (to be implemented by subclasses)."""

    @abc.abstractmethod
    def close(self):
        """Release the environment's resources (to be implemented by subclasses)."""

import rospy
import thread
import numpy as np
import time
from std_msgs.msg import Float32MultiArray, Bool, Int32, MultiArrayDimension, Float32

class RosEnvironment(Environment):
    """Environment backed by a ROS process that is driven over topics.

    Message flow (types as used by the publishers/subscribers below):

    * actions are published as ``std_msgs/Int32`` on ``action_topic``;
    * reset requests are published as ``std_msgs/Bool`` (``True``) on
      ``reset_topic``;
    * states are received as ``std_msgs/Float32MultiArray`` on ``state_topic``;
    * terminal flags are received as ``std_msgs/Bool`` on
      ``terminal_state_topic``;
    * rewards are received as ``std_msgs/Float32`` on ``reward_topic``.

    ``reset`` and ``step`` block, by polling, until the expected replies have
    arrived. A request is re-published whenever ``timeout_in_s`` seconds pass
    without a complete reply. After three unanswered retries the process is
    terminated.
    """

    # Sleep between polls in the wait loops (2 ms). This must stay a float
    # expression: under Python 2, ``1/500`` is integer division and evaluates
    # to 0, which turns the wait loops into hot busy-spins.
    _POLL_INTERVAL_IN_S = 1.0 / 500

    def __init__(self,
                 ros_node_name='RosTrainingAgent',
                 timeout_in_s=60,
                 state_topic='state',
                 action_topic='action',
                 reset_topic='reset',
                 terminal_state_topic='terminal',
                 reward_topic='reward',
                 state_shape=(8,)):
        """Create the ROS node, its publishers and subscribers, and start spinning.

        :param ros_node_name: base name of the ROS node (registered with
            ``anonymous=True``).
        :param timeout_in_s: seconds to wait for replies before a request is
            re-published.
        :param state_topic: topic on which states are received.
        :param action_topic: topic on which actions are published.
        :param reset_topic: topic on which reset requests are published.
        :param terminal_state_topic: topic on which terminal flags are received.
        :param reward_topic: topic on which rewards are received.
        :param state_shape: shape every received state array is reshaped to;
            defaults to ``(8,)``.
        """
        super(RosEnvironment, self).__init__()
        self.__timeout_in_s = timeout_in_s
        self.__state_shape = tuple(state_shape)
        self.__in_reset = False

        # Each flag is raised before a request is published and cleared by
        # the matching subscriber callback once its reply has arrived.
        self.__waiting_for_state_update = False
        self.__waiting_for_terminal_update = False
        self.__last_received_state = 0
        self.__last_received_terminal = True
        self.__last_received_reward = 0.0
        self.__waiting_for_reward_update = False

        rospy.loginfo("Initialize node {0}".format(ros_node_name))

        self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
        rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))

        self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
        rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))

        rospy.init_node(ros_node_name, anonymous=True)

        self.__state_subscriber = rospy.Subscriber(state_topic, Float32MultiArray, self.__state_callback)
        rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))

        self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
        rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))

        self.__reward_subscriber = rospy.Subscriber(reward_topic, Float32, self.__reward_callback)
        rospy.loginfo('Reward Subscriber registered with topic {}'.format(reward_topic))

        # rospy.spin() blocks until shutdown, so it is run on its own thread.
        thread.start_new_thread(rospy.spin, ())
        # Presumably gives the topic connections time to be established before
        # the first request is published.
        time.sleep(2)

    def reset(self):
        """Request a reset and return the first state of the new episode.

        Publishes ``True`` on the reset topic, waits for a fresh state, and
        then waits until the last received terminal flag is ``False``.
        """
        self.__in_reset = True
        reset_message = Bool()
        reset_message.data = True

        # Only a new state is awaited through the retry loop; the terminal
        # flag is polled separately below.
        self.__waiting_for_state_update = True
        self.__waiting_for_terminal_update = False
        self.__waiting_for_reward_update = False

        self.__reset_publisher.publish(reset_message)
        self.__wait_for_new_state(self.__reset_publisher, reset_message)

        # Sleep between polls instead of spinning on ``pass``.
        # NOTE(review): this wait has no timeout; it blocks until a False
        # terminal flag arrives.
        while self.__last_received_terminal:
            time.sleep(self._POLL_INTERVAL_IN_S)
        self.__in_reset = False
        return self.__last_received_state

    def step(self, action):
        """Publish ``action`` and wait for the resulting transition.

        :param action: action value sent as ``std_msgs/Int32``.
        :returns: ``(next_state, reward, terminal, 0)``; the last element is
            always 0.
        """
        action_rospy = Int32()
        action_rospy.data = action

        logger.debug('Send action: %s', action)

        # All three replies (state, terminal flag, reward) must arrive.
        self.__waiting_for_state_update = True
        self.__waiting_for_terminal_update = True
        self.__waiting_for_reward_update = True
        self.__step_publisher.publish(action_rospy)
        self.__wait_for_new_state(self.__step_publisher, action_rospy)

        next_state = self.__last_received_state
        terminal = self.__last_received_terminal
        reward = self.__last_received_reward

        logger.debug('Transition: (%s, %s, %s, %s)', action, reward, next_state, terminal)
        return next_state, reward, terminal, 0

    def __wait_for_new_state(self, publisher, msg):
        """Poll until every raised waiting flag has been cleared by its callback.

        Whenever ``timeout_in_s`` elapses without all replies, ``msg`` is
        re-published on ``publisher``. After three such retries the process
        is terminated with a non-zero exit status.

        :raises SystemExit: after the third unanswered retry.
        """
        time_of_timeout = time.time() + self.__timeout_in_s
        timeout_counter = 0
        while (self.__waiting_for_state_update
               or self.__waiting_for_terminal_update
               or self.__waiting_for_reward_update):
            if time.time() > time_of_timeout:
                if timeout_counter >= 3:
                    rospy.logerr("Timeout 3 times in a row: Terminate application")
                    # ``raise SystemExit`` instead of the ``site`` helper
                    # ``exit()``: it does not depend on the site module, does
                    # not close stdin, and reports failure via exit status 1.
                    raise SystemExit(1)
                rospy.logwarn("Timeout occured: Retry message")
                publisher.publish(msg)
                timeout_counter += 1
                time_of_timeout = time.time() + self.__timeout_in_s
            time.sleep(self._POLL_INTERVAL_IN_S)

    def close(self):
        """Shut down the ROS node."""
        rospy.signal_shutdown('Program ended!')

    def __state_callback(self, data):
        """Store an incoming state (reshaped to ``state_shape``) and clear its flag."""
        self.__last_received_state = np.array(data.data, dtype='float32').reshape(self.__state_shape)
        logger.debug('Received state: %s', self.__last_received_state)
        self.__waiting_for_state_update = False

    def __terminal_state_callback(self, data):
        """Store an incoming terminal flag and clear its waiting flag."""
        # Builtin bool: ``np.bool`` was only an alias of it and was removed
        # in NumPy 1.24.
        self.__last_received_terminal = bool(data.data)
        logger.debug('Received terminal flag: %s', self.__last_received_terminal)
        self.__waiting_for_terminal_update = False

    def __reward_callback(self, data):
        """Store an incoming reward as ``np.float32`` and clear its waiting flag."""
        self.__last_received_reward = np.float32(data.data)
        logger.debug('Received reward: %s', self.__last_received_reward)
        self.__waiting_for_reward_update = False