Browse Source

Added code to automatially use flat obs wrapper when needed

Maxime Chevalier-Boisvert 7 years ago
parent
commit
6db3f6bb87
3 changed files with 13 additions and 5 deletions
  1. 5 0
      gym_minigrid/envs/gotodoor.py
  2. 3 2
      gym_minigrid/wrappers.py
  3. 5 3
      pytorch_rl/envs.py

+ 5 - 0
gym_minigrid/envs/gotodoor.py

@@ -1,3 +1,4 @@
+from gym import spaces
 from gym_minigrid.minigrid import *
 from gym_minigrid.minigrid import *
 from gym_minigrid.register import register
 from gym_minigrid.register import register
 
 
@@ -14,6 +15,10 @@ class GoToDoorEnv(MiniGridEnv):
         assert size >= 5
         assert size >= 5
         super().__init__(gridSize=size, maxSteps=10*size)
         super().__init__(gridSize=size, maxSteps=10*size)
 
 
+        self.observation_space = spaces.Dict({
+            'image': self.observation_space
+        })
+
         self.reward_range = (-1, 1)
         self.reward_range = (-1, 1)
 
 
     def _genGrid(self, width, height):
     def _genGrid(self, width, height):

+ 3 - 2
gym_minigrid/wrappers.py

@@ -86,12 +86,13 @@ class FlatObsWrapper(gym.core.ObservationWrapper):
         self.maxStrLen = maxStrLen
         self.maxStrLen = maxStrLen
         self.numCharCodes = 27
         self.numCharCodes = 27
 
 
-        obsSize = batch_numel = reduce(operator.mul, self.observation_space.shape, 1)
+        imgSpace = env.observation_space.spaces['image']
+        imgSize = reduce(operator.mul, imgSpace.shape, 1)
 
 
         self.observation_space = spaces.Box(
         self.observation_space = spaces.Box(
             low=0,
             low=0,
             high=255,
             high=255,
-            shape=obsSize + self.numCharCodes * self.maxStrLen
+            shape=imgSize + self.numCharCodes * self.maxStrLen
         )
         )
 
 
         self.cachedStr = None
         self.cachedStr = None

+ 5 - 3
pytorch_rl/envs.py

@@ -1,7 +1,7 @@
 import os
 import os
 import numpy
 import numpy
 import gym
 import gym
-from gym.spaces.box import Box
+from gym import spaces
 
 
 try:
 try:
     import gym_minigrid
     import gym_minigrid
@@ -15,7 +15,9 @@ def make_env(env_id, seed, rank, log_dir):
 
 
         env.seed(seed + rank)
         env.seed(seed + rank)
 
 
-        #env = FlatObsWrapper(env)
+        # Maxime: until RL code supports dict observations, squash observations into a flat vector
+        if isinstance(env.observation_space, spaces.Dict):
+            env = FlatObsWrapper(env)
 
 
         # If the input has shape (W,H,3), wrap for PyTorch convolutions
         # If the input has shape (W,H,3), wrap for PyTorch convolutions
         obs_shape = env.observation_space.shape
         obs_shape = env.observation_space.shape
@@ -30,7 +32,7 @@ class WrapPyTorch(gym.ObservationWrapper):
     def __init__(self, env=None):
     def __init__(self, env=None):
         super(WrapPyTorch, self).__init__(env)
         super(WrapPyTorch, self).__init__(env)
         obs_shape = self.observation_space.shape
         obs_shape = self.observation_space.shape
-        self.observation_space = Box(
+        self.observation_space = spaces.Box(
             self.observation_space.low[0,0,0],
             self.observation_space.low[0,0,0],
             self.observation_space.high[0,0,0],
             self.observation_space.high[0,0,0],
             [obs_shape[2], obs_shape[1], obs_shape[0]]
             [obs_shape[2], obs_shape[1], obs_shape[0]]