浏览代码

Use correct shape for one hot encoding (#93)

Anders Thuesen 5 年之前
父节点
当前提交
050ce008c8
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      gym_minigrid/wrappers.py

+ 1 - 1
gym_minigrid/wrappers.py

@@ -135,7 +135,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
 
     def observation(self, obs):
         img = obs['image']
-        out = np.zeros(self.observation_space.shape, dtype='uint8')
+        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
 
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):