environment.py 1.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
import abc
import logging
# Configure the root logger at INFO level.
# NOTE(review): calling basicConfig at import time configures logging for the
# whole process as a side effect of importing this module (and is a no-op if
# the root logger already has handlers); consider moving it to the
# application's entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger; not referenced by any code visible in this file.
logger = logging.getLogger(__name__)

class Environment(abc.ABC):
    """Abstract interface for an environment that can be reset, stepped and closed.

    Subclasses must implement ``reset``, ``step`` and ``close``; instantiating
    this class directly, or a subclass that leaves any of them unimplemented,
    raises ``TypeError``.
    """

    # Inheriting from ``abc.ABC`` is what makes ``@abc.abstractmethod`` binding.
    # A ``__metaclass__ = abc.ABCMeta`` class attribute is Python 2 syntax that
    # Python 3 silently ignores, which would leave the methods below unenforced.

    def __init__(self):
        pass

    @abc.abstractmethod
    def reset(self):
        """Reset the environment; the return value is defined by the subclass."""

    @abc.abstractmethod
    def step(self, action):
        """Apply ``action``; the return value is defined by the subclass."""

    @abc.abstractmethod
    def close(self):
        """Shut the environment down; cleanup is defined by the subclass."""

import gym
class GymEnvironment(Environment):
    """``Environment`` implementation that delegates to a Gym environment.

    The wrapped environment is created with ``gym.make(env_name)`` and seeded
    once at construction. ``reset`` and ``step`` return the wrapped
    environment's results unchanged.
    NOTE(review): relies on ``env.seed(...)``, which presumably targets the
    pre-0.26 Gym API; confirm against the pinned gym version.
    """

    def __init__(self, env_name, seed=42, **kwargs):
        """Create and seed the wrapped Gym environment.

        Args:
            env_name: Environment id passed to ``gym.make``.
            seed: Value passed to the wrapped environment's ``seed`` method.
                Defaults to 42, the value that was previously hard-coded.
            **kwargs: Forwarded to ``Environment.__init__``.
        """
        super(GymEnvironment, self).__init__(**kwargs)
        self.__seed = seed
        self.__env = gym.make(env_name)
        self.__env.seed(self.__seed)

    @property
    def state_dim(self):
        """Shape of the wrapped environment's observation space."""
        return self.__env.observation_space.shape

    @property
    def number_of_actions(self):
        """Value of the wrapped environment's ``action_space.n``.

        NOTE(review): ``n`` presumably exists only on discrete action spaces;
        confirm before using this with continuous-action environments.
        """
        return self.__env.action_space.n

    @property
    def rewards_dtype(self):
        """Dtype name used for rewards; always the string ``'float32'``."""
        return 'float32'

    def reset(self):
        """Reset the wrapped environment and return its result unchanged."""
        return self.__env.reset()

    def step(self, action):
        """Forward ``action`` to the wrapped ``step`` and return its result unchanged."""
        return self.__env.step(action)

    def close(self):
        """Close the wrapped environment."""
        self.__env.close()

    def action_space(self):
        """Return the wrapped environment's action space object."""
        # The attribute must be returned; merely evaluating it yields None.
        return self.__env.action_space

    def is_in_action_space(self, action):
        """Return the result of ``action_space.contains(action)``."""
        return self.__env.action_space.contains(action)

    def sample_action(self):
        """Return the result of ``action_space.sample()``."""
        return self.__env.action_space.sample()

    def render(self, *args, **kwargs):
        """Render the wrapped environment and return the wrapped result.

        Any positional/keyword arguments (e.g. a render mode) are forwarded
        to the wrapped ``render``; ``render()`` with no arguments behaves as
        before.
        """
        return self.__env.render(*args, **kwargs)