Explorar o código

Use correct shape for one hot encoding (#93)

Anders Thuesen %!s(int64=5) %!d(string=hai) anos
pai
achega
050ce008c8
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      gym_minigrid/wrappers.py

+ 1 - 1
gym_minigrid/wrappers.py

@@ -135,7 +135,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
 
     def observation(self, obs):
         img = obs['image']
-        out = np.zeros(self.observation_space.shape, dtype='uint8')
+        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')
 
         for i in range(img.shape[0]):
             for j in range(img.shape[1]):