Browse Source

Merge branch 'master' of github.com:maximecb/gym-minigrid into maximecb-classifier1

Maxime Chevalier-Boisvert 5 năm trước cách đây
mục cha
commit
936cb68a68
4 tập tin đã thay đổi với 77 bổ sung và 2 xóa
  1. 1 0
      README.md
  2. 7 0
      gym_minigrid/minigrid.py
  3. 68 1
      gym_minigrid/wrappers.py
  4. 1 1
      setup.py

+ 1 - 0
README.md

@@ -29,6 +29,7 @@ Please use this bibtex if you want to cite this repository in your publications:
 ```
 
 List of publications & submissions using MiniGrid (please open a pull request to add missing entries):
+- [Learning Effective Subgoals with Multi-Task Hierarchical Reinforcement Learning](http://surl.tirl.info/proceedings/SURL-2019_paper_10.pdf) (Tsinghua University, August 2019)
 - [Learning distant cause and effect using only local and immediate credit assignment](https://arxiv.org/abs/1905.11589) (Incubator 491, May 2019)
 - [Learning World Graphs to Accelerate Hierarchical Reinforcement Learning](https://arxiv.org/abs/1907.00664) (Salesforce Research, 2019)
 - [Modeling the Long Term Future in Model-Based Reinforcement Learning](https://openreview.net/forum?id=SkgQBn0cF7) (Mila, ICLR 2019)

+ 7 - 0
gym_minigrid/minigrid.py

@@ -49,6 +49,13 @@ OBJECT_TO_IDX = {
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 
+# Map of state names to integer indices; its length sizes the one-hot state block
+STATE_TO_IDX = {
+    'open'  : 0,
+    'closed': 1,
+    'locked': 2,
+}
+
 # Map of agent direction indices to vectors
 DIR_TO_VEC = [
     # Pointing right (positive X)

+ 68 - 1
gym_minigrid/wrappers.py

@@ -5,7 +5,7 @@ from functools import reduce
 import numpy as np
 import gym
 from gym import error, spaces, utils
-from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
+from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
 from .minigrid import CELL_PIXELS
 
 class ReseedWrapper(gym.core.Wrapper):
@@ -112,6 +112,47 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
         return obs['image']
 
+class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper that replaces the 'image' entry of the observation with a one-hot
+    encoding of the agent's partial view; 'mission' is passed through.
+    Per cell, the channels are [object type | color | state], with widths
+    len(OBJECT_TO_IDX), len(COLOR_TO_IDX) and len(STATE_TO_IDX).
+    """
+
+    def __init__(self, env, tile_size=8):
+        super().__init__(env)
+
+        # tile_size is stored but not used by this wrapper's encoding
+        self.tile_size = tile_size
+
+        obs_shape = env.observation_space['image'].shape
+        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
+
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obs_shape[0], obs_shape[1], num_bits),
+            dtype='uint8'
+        )
+
+    def observation(self, obs):
+        img = obs['image']
+        out = np.zeros(self.observation_space.shape, dtype='uint8')
+        color_base = len(OBJECT_TO_IDX)
+        state_base = color_base + len(COLOR_TO_IDX)
+
+        for i in range(img.shape[0]):
+            for j in range(img.shape[1]):
+                # int() keeps the offset sums out of uint8 arithmetic
+                obj_type, color, state = (int(v) for v in img[i, j, :3])
+                out[i, j, obj_type] = 1
+                out[i, j, color_base + color] = 1
+                # State block must be indexed by the cell's state, not its color
+                out[i, j, state_base + state] = 1
+
+        return {'mission': obs['mission'], 'image': out}
+
 class RGBImgObsWrapper(gym.core.ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as the only observation output,
@@ -139,6 +180,32 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
             tile_size=self.tile_size
         )
 
+class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper that swaps the 'image' entry of the observation for an RGB
+    rendering of the agent's partial view; 'mission' is passed through.
+    Lets an agent solve the gridworld in pixel space.
+    """
+
+    def __init__(self, env, tile_size=8):
+        super().__init__(env)
+
+        self.tile_size = tile_size
+
+        # Declared pixel shape: the first two view dimensions scaled by tile_size
+        dim0, dim1 = env.observation_space['image'].shape[:2]
+        pixel_shape = (dim0 * tile_size, dim1 * tile_size, 3)
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=pixel_shape, dtype='uint8'
+        )
+
+    def observation(self, obs):
+        # Rendering is delegated to the unwrapped environment
+        rgb_view = self.unwrapped.get_obs_render(
+            obs['image'], tile_size=self.tile_size, mode='rgb_array'
+        )
+        return {'mission': obs['mission'], 'image': rgb_view}
+
 class FullyObsWrapper(gym.core.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding

+ 1 - 1
setup.py

@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name='gym_minigrid',
-    version='0.0.4',
+    version='0.0.5',
     keywords='memory, environment, agent, rl, openaigym, openai-gym, gym',
     url='https://github.com/maximecb/gym-minigrid',
     description='Minimalistic gridworld package for OpenAI Gym',