Преглед на файлове

Implemented one-hot observation wrapper

Maxime Chevalier-Boisvert преди 5 години
родител
ревизия
9c57465a0a
променени са 2 файла, в които са добавени 49 реда и е изтрит 1 ред
  1. 7 0
      gym_minigrid/minigrid.py
  2. 42 1
      gym_minigrid/wrappers.py

+ 7 - 0
gym_minigrid/minigrid.py

@@ -49,6 +49,13 @@ OBJECT_TO_IDX = {
 
 IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
 
+# Map of state names to integer indices; its length sizes the state part of the one-hot observation in wrappers.py
+STATE_TO_IDX = {
+    'open'  : 0,
+    'closed': 1,
+    'locked': 2,
+}
+
 # Map of agent direction indices to vectors
 DIR_TO_VEC = [
     # Pointing right (positive X)

+ 42 - 1
gym_minigrid/wrappers.py

@@ -5,7 +5,7 @@ from functools import reduce
 import numpy as np
 import gym
 from gym import error, spaces, utils
-from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
+from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
 from .minigrid import CELL_PIXELS
 
 class ReseedWrapper(gym.core.Wrapper):
@@ -112,6 +112,47 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
     def observation(self, obs):
         return obs['image']  # keep only the 'image' entry of the dict observation
 
+class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+    """
+    Wrapper to get a one-hot encoding of a partially observable
+    agent view as observation (type, color and state slots per cell).
+    """
+
+    def __init__(self, env, tile_size=8):
+        super().__init__(env)
+
+        self.tile_size = tile_size
+
+        obs_shape = env.observation_space['image'].shape
+        # One slot per object type, per color and per state, concatenated
+        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)
+
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obs_shape[0], obs_shape[1], num_bits),
+            dtype='uint8'
+        )
+
+    def observation(self, obs):
+        img = obs['image']
+        out = np.zeros(self.observation_space.shape, dtype='uint8')
+        # Last-axis layout: [object type | color | state]
+        color_offset = len(OBJECT_TO_IDX)
+        state_offset = color_offset + len(COLOR_TO_IDX)
+        for i in range(img.shape[0]):
+            for j in range(img.shape[1]):
+                obj_type, color, state = img[i, j, :3]
+                out[i, j, obj_type] = 1
+                out[i, j, color_offset + color] = 1
+                # The state segment is indexed by the cell's state, not its color
+                out[i, j, state_offset + state] = 1
+
+        return {
+            'mission': obs['mission'],
+            'image': out
+        }
+
 class RGBImgObsWrapper(gym.core.ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as the only observation output,