import numpy as np


class ReplayMemoryBuilder(object):
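    """Factory for replay memories: builds the 'online', 'buffer' or
    'combined' variant from a common set of parameters."""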
    def __init__(self):
        self.__supported_methods = ['online', 'buffer', 'combined']

    def build_by_params(
        self,
        state_dim,
        method='online',
        state_dtype='float32',
        action_dim=(1,),
        action_dtype='uint8',
        rewards_dtype='float32',
        memory_size=1000,
        sample_size=32
    ):
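        """Build a replay memory of the requested ``method`` ('online',
        'buffer' or 'combined') for the given state/action shapes and
        dtypes."""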
        assert state_dim is not None
        assert action_dim is not None
        assert method in self.__supported_methods

        if method == 'online':
            return self.build_online_memory(
                state_dim=state_dim, state_dtype=state_dtype,
                action_dtype=action_dtype, action_dim=action_dim,
                rewards_dtype=rewards_dtype)
        else:
            assert memory_size is not None and memory_size > 0
            assert sample_size is not None and sample_size > 0
            if method == 'buffer':
                return self.build_buffered_memory(
                    state_dim=state_dim, sample_size=sample_size,
                    memory_size=memory_size, state_dtype=state_dtype,
                    action_dim=action_dim, action_dtype=action_dtype,
                    rewards_dtype=rewards_dtype)
            else:
                return self.build_combined_memory(
                    state_dim=state_dim, sample_size=sample_size,
                    memory_size=memory_size, state_dtype=state_dtype,
                    action_dim=action_dim, action_dtype=action_dtype,
                    rewards_dtype=rewards_dtype)

    def build_buffered_memory(
        self, state_dim, memory_size, sample_size, state_dtype, action_dim,
        action_dtype, rewards_dtype
    ):
        assert memory_size > 0
        assert sample_size > 0
        return ReplayMemory(
            state_dim, size=memory_size, sample_size=sample_size,
            state_dtype=state_dtype, action_dim=action_dim,
            action_dtype=action_dtype, rewards_dtype=rewards_dtype)

    def build_combined_memory(
        self, state_dim, memory_size, sample_size, state_dtype, action_dim,
        action_dtype, rewards_dtype
    ):
        assert memory_size > 0
        assert sample_size > 0
        return CombinedReplayMemory(
            state_dim, size=memory_size, sample_size=sample_size,
            state_dtype=state_dtype, action_dim=action_dim,
            action_dtype=action_dtype, rewards_dtype=rewards_dtype)

    def build_online_memory(
        self, state_dim, state_dtype, action_dtype, action_dim, rewards_dtype
    ):
        return OnlineReplayMemory(
            state_dim, state_dtype=state_dtype, action_dim=action_dim,
            action_dtype=action_dtype, rewards_dtype=rewards_dtype)


class ReplayMemory(object):
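    """Fixed-size ring buffer that stores (state, action, reward,
    next_state, terminal) transitions and supports uniform random
    sampling of mini-batches."""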
    def __init__(
        self,
        state_dim,
        sample_size,
        size=1000,
        action_dim=(1,),
        state_dtype='float32',
        action_dtype='uint8',
        rewards_dtype='float32'
    ):
        assert size > 0, "Size must be greater than zero"
        assert type(state_dim) is tuple, "State dimension must be a tuple"
        assert type(action_dim) is tuple, "Action dimension must be a tuple"
        assert sample_size > 0
        self._size = size
        self._sample_size = sample_size
        self._cur_size = 0
        self._pointer = 0
        self._state_dim = state_dim
        self._state_dtype = state_dtype
        self._action_dim = action_dim
        self._action_dtype = action_dtype
        self._rewards_dtype = rewards_dtype
        self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
        self._actions = np.zeros(
            (self._size,) + action_dim, dtype=action_dtype)
        self._rewards = np.zeros(self._size, dtype=rewards_dtype)
        self._next_states = np.zeros(
            (self._size,) + state_dim, dtype=state_dtype)
        self._terminals = np.zeros(self._size, dtype='bool')

    @property
    def sample_size(self):
        return self._sample_size

    def append(self, state, action, reward, next_state, terminal):
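        """Store a transition at the current write position, overwriting
        the oldest entry once the buffer is full."""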
        self._states[self._pointer] = state
        self._actions[self._pointer] = action
        self._rewards[self._pointer] = reward
        self._next_states[self._pointer] = next_state
        self._terminals[self._pointer] = terminal

        self._pointer = self._pointer + 1
        if self._pointer == self._size:
            self._pointer = 0

        self._cur_size = min(self._size, self._cur_size + 1)

    def at(self, index):
        return self._states[index],\
            self._actions[index],\
            self._rewards[index],\
            self._next_states[index],\
            self._terminals[index]

    def is_sample_possible(self, batch_size=None):
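        """Return True if the memory holds at least ``batch_size``
        (default: the configured sample size) transitions."""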
        batch_size = batch_size if batch_size is not None\
            else self._sample_size
        return self._cur_size >= batch_size

    def sample(self, batch_size=None):
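        """Draw a mini-batch of transitions uniformly at random (with
        replacement) from the stored entries."""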
        batch_size = batch_size if batch_size is not None\
            else self._sample_size
        assert self._cur_size >= batch_size,\
            "Size of replay memory must be larger than batch size"
        i = 0
        states = np.zeros((
            batch_size,)+self._state_dim, dtype=self._state_dtype)
        actions = np.zeros(
            (batch_size,)+self._action_dim, dtype=self._action_dtype)
        rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
        next_states = np.zeros(
            (batch_size,)+self._state_dim, dtype=self._state_dtype)
        terminals = np.zeros(batch_size, dtype='bool')

        while i < batch_size:
            rnd_index = np.random.randint(low=0, high=self._cur_size)
            states[i] = self._states.take(rnd_index, axis=0)
            actions[i] = self._actions.take(rnd_index, axis=0)
            rewards[i] = self._rewards.take(rnd_index, axis=0)
            next_states[i] = self._next_states.take(rnd_index, axis=0)
            terminals[i] = self._terminals.take(rnd_index, axis=0)
            i += 1

        return states, actions, rewards, next_states, terminals


class OnlineReplayMemory(ReplayMemory):
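    """Degenerate replay memory of size one: only the most recent
    transition is stored and sampled (online learning)."""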
    def __init__(
        self, state_dim, state_dtype='float32', action_dim=(1,),
        action_dtype='uint8', rewards_dtype='float32'
    ):
        super(OnlineReplayMemory, self).__init__(
            state_dim, sample_size=1, size=1, state_dtype=state_dtype,
            action_dim=action_dim, action_dtype=action_dtype,
            rewards_dtype=rewards_dtype)


class CombinedReplayMemory(ReplayMemory):
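    """Replay memory that always includes the most recent transition in
    every sampled batch (combined experience replay)."""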
    def __init__(
        self, state_dim, sample_size, size=1000, state_dtype='float32',
        action_dim=(1,), action_dtype='uint8', rewards_dtype='float32'
    ):
        super(CombinedReplayMemory, self).__init__(
            state_dim=state_dim, sample_size=(sample_size - 1), size=size,
            state_dtype=state_dtype, action_dim=action_dim,
            action_dtype=action_dtype, rewards_dtype=rewards_dtype)

        self._last_state = np.zeros((1,) + state_dim, dtype=state_dtype)
        self._last_action = np.zeros((1,) + action_dim, dtype=action_dtype)
        self._last_reward = np.array([0], dtype=rewards_dtype)
        self._last_next_state = np.zeros((1,) + state_dim, dtype=state_dtype)
        self._last_terminal = np.array([0], dtype='bool')

    def append(self, state, action, reward, next_state, terminal):
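        """Store the transition in the underlying ring buffer and keep a
        copy of it as the most recent transition."""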
        super(CombinedReplayMemory, self).append(
            state, action, reward, next_state, terminal)
        self._last_state = state
        self._last_action = action
        self._last_reward = reward
        self._last_next_state = next_state
        self._last_terminal = terminal

    def sample(self, batch_size=None):
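        """Sample ``batch_size - 1`` transitions from the ring buffer and
        append the most recent transition to complete the batch."""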
        batch_size = (batch_size-1) if batch_size is not None\
            else self._sample_size
        states, actions, rewards, next_states, terminals = super(
            CombinedReplayMemory, self).sample(batch_size=batch_size)
        states = np.append(states, [self._last_state], axis=0)
        actions = np.append(actions, [self._last_action], axis=0)
        rewards = np.append(rewards, [self._last_reward], axis=0)
        next_states = np.append(next_states, [self._last_next_state], axis=0)
        terminals = np.append(terminals, [self._last_terminal], axis=0)
        return states, actions, rewards, next_states, terminals
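

# Example usage (illustrative sketch; the state shape, memory size and
# sample size below are arbitrary values, not taken from any particular
# agent configuration):
#
#     memory = ReplayMemoryBuilder().build_by_params(
#         state_dim=(4,), method='buffer', memory_size=10000, sample_size=32)
#     memory.append(state, action, reward, next_state, terminal)
#     if memory.is_sample_possible():
#         states, actions, rewards, next_states, terminals = memory.sample()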