Skip to content
Snippets Groups Projects
Select Git revision
  • master
1 result

vision_wrapping_classes.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    vision_wrapping_classes.py 4.56 KiB
    """
    Wrapper classes for the QubeEnv of the Quanser driver. The wrappers expose an OpenAI Gym interface.
    
    These wrappers can be used to get consistent environments with the following state representation:
    [cos(theta), sin(theta), cos(alpha), sin(alpha), theta_dot, alpha_dot].
    The classes in this module instead return preprocessed camera images (or, in the
    simulator, rendered frames) as observations.
    
    The reward functions provided can be found in rl_reward_functions.py.
    
    A wrapper always needs to be used as a context manager, e.g.
        with Wrapper(...) as wrapper:
    to ensure safe closure of the camera and the Qube!
    
    @Author: Steffen Bleher
    """
    import numpy as np
    import cv2
    
    from gym_brt.blackfly.blackfly import Blackfly
    from gym_brt.blackfly.image_preprocessor import ImagePreprocessor
    from gym_brt.blackfly.image_preprocessor import IMAGE_SHAPE
    from gym_brt.envs.qube_swingup_env import QubeSwingupEnv
    from gym_brt.data.config.configuration import FREQUENCY
    from gym import ObservationWrapper, spaces
    
    
    class BlackFlyWrapper(ObservationWrapper):
        """
        Use images from a BlackFly camera as observation
        rather than the observation the environment provides.

        Every observation returned by the wrapped environment is discarded and
        replaced by a freshly grabbed, preprocessed camera frame.

        Use it as a context manager (``with BlackFlyWrapper(env) as wrapper:``)
        so that the camera is released and the wrapped environment is closed.
        """

        def __init__(self, env, no_image_normalization=False, additional_process=None):
            """
            :param env: environment whose observations are replaced by images
            :param no_image_normalization: if True, frames are passed through
                ``preprocess_image`` only; otherwise through
                ``preprocess_and_normalize_image``
            :param additional_process: currently ignored; kept for API compatibility
            """
            super(BlackFlyWrapper, self).__init__(env)
            self.observation_space = spaces.Box(low=0, high=255,
                                                shape=IMAGE_SHAPE, dtype=np.float32)
            self.camera = Blackfly(exposure_time=1000)
            # Acquisition is started exactly once, here; __enter__ must not start
            # it a second time.
            self.camera.start_acquisition()
            self.preprocessor = ImagePreprocessor(False, IMAGE_SHAPE)
            self.no_image_normalization = no_image_normalization

        def _get_state(self):
            """Grab one camera frame and return it preprocessed (and optionally normalized)."""
            image = self.camera.get_image()
            if self.no_image_normalization:
                return self.preprocessor.preprocess_image(image)
            return self.preprocessor.preprocess_and_normalize_image(image)

        def __enter__(self):
            # Acquisition was already started in __init__; calling
            # start_acquisition() again here would start an already-running camera.
            print('start camera')
            return super().__enter__()

        def __exit__(self, type, value, traceback):
            """Stop acquisition, release the camera, then close the wrapped env.

            Every step is attempted even if an earlier one fails, so the wrapped
            environment is always closed. Exceptions raised inside the ``with``
            body are not suppressed.
            """
            try:
                try:
                    self.camera.end_acquisition()
                except Exception:
                    # Best effort: a failed stop must not prevent releasing the
                    # camera or closing the environment.
                    print('could not end camera')
                finally:
                    self.camera.__exit__(type, value, traceback)
            finally:
                # Always close the wrapped environment, even if camera cleanup raised.
                super().__exit__(type, value, traceback)

        def observation(self, observation):
            """Ignore the environment's observation and return a camera image instead."""
            return self._get_state()
    
    
    class VisionQubeBeginDownEnv(QubeSwingupEnv):
        """
        Qube swing-up environment that returns images instead of the numeric state.

        Frames come from a BlackFly camera on the real hardware, or from
        ``self.render("rgb_array", ...)`` when ``use_simulator`` is True. Only the
        'mujoco' simulation mode is supported for rendering.

        Use it as a context manager so the camera and the Qube are released.
        """

        def __init__(self, frequency=FREQUENCY, batch_size=2048, use_simulator=False, simulation_mode='ode', integration_steps=1,
                     encoder_reset_steps=int(1e8), no_image_normalization=False):
            """
            :param frequency: forwarded to the base Qube environment initializer
            :param batch_size: forwarded to the base Qube environment initializer
            :param use_simulator: if True, frames are rendered; otherwise a
                BlackFly camera is opened and acquisition is started
            :param simulation_mode: forwarded to the base initializer; must be
                'mujoco' when ``use_simulator`` is True
            :param integration_steps: forwarded to the base Qube environment initializer
            :param encoder_reset_steps: forwarded to the base Qube environment initializer
            :param no_image_normalization: if True, frames are passed through
                ``preprocess_image`` only; otherwise through
                ``preprocess_and_normalize_image``
            :raises ValueError: if ``use_simulator`` is True and ``simulation_mode``
                is not 'mujoco'
            """
            # Reject unsupported simulator modes before the base environment is
            # initialised, so nothing is set up for a configuration we refuse.
            if use_simulator and simulation_mode != 'mujoco':
                raise ValueError(f"Unsupported simulation type '{simulation_mode}'. "
                                 f"Valid ones are 'mujoco'")

            # NOTE(review): super(QubeSwingupEnv, self) resolves past QubeSwingupEnv,
            # so its own __init__ (if it defines one) is skipped — confirm intended.
            super(QubeSwingupEnv, self).__init__(frequency, batch_size, use_simulator, simulation_mode, integration_steps, encoder_reset_steps, )
            self.out_shape = IMAGE_SHAPE
            self.observation_space = spaces.Box(low=0, high=255,
                                                shape=IMAGE_SHAPE, dtype=np.float32)

            if not use_simulator:
                # Real hardware: acquisition is started exactly once, here.
                self.camera = Blackfly(exposure_time=1000)
                self.camera.start_acquisition()
            self.preprocessor = ImagePreprocessor(False, IMAGE_SHAPE)

            self.use_simulator = use_simulator
            self.simulation_mode = simulation_mode
            self.no_image_normalization = no_image_normalization

        def _get_state(self):
            """Return the current frame, preprocessed (and optionally normalized).

            :raises ValueError: if the simulator is used with a mode other than 'mujoco'
            """
            if self.use_simulator:
                if self.simulation_mode != 'mujoco':
                    raise ValueError(f"Unsupported simulation type '{self.simulation_mode}'. "
                                     f"Valid ones are 'mujoco'")
                # NOTE(review): width comes from out_shape[0] and height from
                # out_shape[1]. If IMAGE_SHAPE is (height, width, ...), these are
                # swapped for non-square images — confirm against IMAGE_SHAPE.
                image = self.render("rgb_array", width=self.out_shape[0], height=self.out_shape[1])
                # Swaps the R and B channels (BGR2RGB and RGB2BGR are the same
                # permutation), presumably to match the camera's channel order.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image = self.camera.get_image()
            if self.no_image_normalization:
                return self.preprocessor.preprocess_image(image)
            return self.preprocessor.preprocess_and_normalize_image(image)

        def __enter__(self):
            # In hardware mode, camera acquisition is already started in __init__,
            # so nothing is started here.
            print('start camera')
            return super().__enter__()

        def __exit__(self, type, value, traceback):
            """Release the camera (hardware mode only), then close the environment.

            Every step is attempted even if an earlier one fails, so the base
            environment is always closed. Exceptions raised inside the ``with``
            body are not suppressed.
            """
            try:
                if not self.use_simulator:
                    try:
                        self.camera.end_acquisition()
                    except Exception:
                        # Best effort: a failed stop must not prevent releasing
                        # the camera or closing the Qube.
                        print('could not end camera')
                    finally:
                        self.camera.__exit__(type, value, traceback)
            finally:
                # Always close the base environment, even if camera cleanup raised.
                super().__exit__(type, value, traceback)