소스 검색

Use correct shape for one hot encoding (#93)

Anders Thuesen 5 년 전
부모
커밋
050ce008c8
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      gym_minigrid/wrappers.py

+ 1 - 1
gym_minigrid/wrappers.py

@@ -135,7 +135,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
 
     def observation(self, obs):
         img = obs['image']
-        out = np.zeros(self.observation_space.shape, dtype='uint8')
+        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
 
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):