Explorar o código

Classical env and wrappers (#6, #13, #22) (#24)

* Classical env and wrappers (#6, #13, #22)

* Add Classical-v0 4 rooms env #

* Add image wrapper

* Add full state wrapper

* Updated according to #24

* Changed name to FourRooms

* Fix obs space in ObsWrapper

* Add test for FullObsWrapper

* revert

* Updated according to #24

* Changed name to FourRooms

* Fix obs space in ObsWrapper

* Add test for FullObsWrapper

* Removed doors

* Removed test env #24

* Revert minigrid
d3sm0 hai 6 anos
pai
achega
f1a2080a32

+ 1 - 0
gym_minigrid/envs/__init__.py

@@ -14,3 +14,4 @@ from gym_minigrid.envs.playground_v0 import *
 from gym_minigrid.envs.redbluedoors import *
 from gym_minigrid.envs.obstructedmaze import *
 from gym_minigrid.envs.memory import *
+from gym_minigrid.envs.fourrooms import *

+ 79 - 0
gym_minigrid/envs/fourrooms.py

@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from gym_minigrid.minigrid import *
+from gym_minigrid.register import register
+
+
+class FourRoomsEnv(MiniGridEnv):
+    """
+    Classical four-rooms gridworld environment.
+    Agent and goal positions can be given; when None, they are set at random.
+    """
+
+    def __init__(self, agent_pos=None, goal_pos=None):
+        self._agent_default_pos = agent_pos
+        self._goal_default_pos = goal_pos
+        super().__init__(grid_size=19, max_steps=100)
+
+    def _gen_grid(self, width, height):
+        """Build the walls and doorways, then place the agent and the goal."""
+        self.grid = Grid(width, height)
+
+        # Generate the surrounding walls
+        self.grid.horz_wall(0, 0)
+        self.grid.horz_wall(0, height - 1)
+        self.grid.vert_wall(0, 0)
+        self.grid.vert_wall(width - 1, 0)
+
+        room_w = width // 2
+        room_h = height // 2
+
+        # For each row of rooms
+        for j in range(0, 2):
+            # For each column
+            for i in range(0, 2):
+                xL = i * room_w
+                yT = j * room_h
+                xR = xL + room_w
+                yB = yT + room_h
+
+                # Right wall and door
+                if i + 1 < 2:
+                    self.grid.vert_wall(xR, yT, room_h)
+                    pos = (xR, self._rand_int(yT + 1, yB - 1))
+                    self.grid.set(*pos, None)
+
+                # Bottom wall and door
+                if j + 1 < 2:
+                    self.grid.horz_wall(xL, yB, room_w)
+                    pos = (self._rand_int(xL + 1, xR - 1), yB)
+                    self.grid.set(*pos, None)
+
+        # Use the fixed agent position when given, else place it at random
+        if self._agent_default_pos is not None:
+            self.start_pos = self._agent_default_pos
+            self.grid.set(*self._agent_default_pos, None)
+            self.start_dir = self._rand_int(0, 4)  # assuming random start direction
+        else:
+            self.place_agent()
+
+        if self._goal_default_pos is not None:
+            goal = Goal()
+            self.grid.set(*self._goal_default_pos, goal)
+            # Both attributes hold the full (x, y) position, not one coordinate each
+            goal.init_pos = goal.cur_pos = self._goal_default_pos
+        else:
+            self.place_obj(Goal())
+
+        self.mission = 'Reach the goal'
+
+    def step(self, action):  # plain pass-through to MiniGridEnv.step
+        obs, reward, done, info = MiniGridEnv.step(self, action)
+        return obs, reward, done, info
+
+
+register(  # make FourRoomsEnv constructible by its environment id
+    entry_point='gym_minigrid.envs:FourRoomsEnv',
+    id='MiniGrid-FourRooms-v0',
+)

+ 1 - 1
gym_minigrid/minigrid.py

@@ -1308,4 +1308,4 @@ class MiniGridEnv(gym.Env):
         elif mode == 'pixmap':
             return r.getPixmap()
 
-        return r
+        return r

+ 36 - 0
gym_minigrid/wrappers.py

@@ -74,6 +74,42 @@ class StateBonus(gym.core.Wrapper):
 
         return obs, reward, done, info
 
+
+class ImgObsWrapper(gym.core.ObservationWrapper):
+    """
+    Expose only the 'image' entry of the wrapped observation
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+        vars(self).update(env.__dict__)  # hack to pass values to super wrapper
+        self.observation_space = env.observation_space['image']
+
+    def observation(self, obs):
+        return obs['image']
+
+
+class FullyObsWrapper(gym.core.ObservationWrapper):
+    """
+    Fully observable gridworld: the observation is the rendered grid array
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
+        self.observation_space = spaces.Box(
+            low=0, high=255, dtype='uint8',
+            # grid_size cells per side, times 32 (presumably pixels per cell)
+            shape=(self.env.grid_size * 32, self.env.grid_size * 32, 3),
+        )
+
+    def observation(self, obs):
+        if self.env.grid_render is None:
+            # Nothing rendered yet: all-zero frame in the declared uint8 dtype
+            return np.zeros(shape=self.observation_space.shape, dtype='uint8')
+        return self.env.grid_render.getArray()
+
+
 class FlatObsWrapper(gym.core.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,

+ 3 - 0
run_tests.py

@@ -86,3 +86,6 @@ for i in range(0, 500):
     assert agent_sees_goal == goal_visible
     if done:
         env.reset()
+
+#############################################################################
+