|
@@ -3,6 +3,7 @@
|
|
|
import random
|
|
|
import numpy as np
|
|
|
import gym
|
|
|
+import gym_minigrid
|
|
|
from gym_minigrid.register import env_list
|
|
|
from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
|
|
|
|
|
@@ -133,6 +134,61 @@ for env_idx, env_name in enumerate(env_list):
|
|
|
|
|
|
##############################################################################
|
|
|
|
|
|
+print('testing extra observations')
|
|
|
+class EmptyEnvWithExtraObs(gym_minigrid.envs.EmptyEnv5x5):
|
|
|
+ """
|
|
|
+ Custom environment with an extra observation
|
|
|
+ """
|
|
|
+ def __init__(self) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.observation_space['size'] = spaces.Box(
|
|
|
+ low=0,
|
|
|
+ high=np.iinfo(np.uint).max,
|
|
|
+ shape=(2,),
|
|
|
+ dtype=np.uint
|
|
|
+ )
|
|
|
+
|
|
|
+ def reset(self):
|
|
|
+ obs = super().reset()
|
|
|
+ obs['size'] = np.array([self.width, self.height])
|
|
|
+ return obs
|
|
|
+
|
|
|
+ def step(self, action):
|
|
|
+ obs, reward, done, info = super().step(action)
|
|
|
+ obs['size'] = np.array([self.width, self.height])
|
|
|
+ return obs, reward, done, info
|
|
|
+
|
|
|
+wrappers = [
|
|
|
+ OneHotPartialObsWrapper,
|
|
|
+ RGBImgObsWrapper,
|
|
|
+ RGBImgPartialObsWrapper,
|
|
|
+ FullyObsWrapper,
|
|
|
+]
|
|
|
+for wrapper in wrappers:
|
|
|
+ env1 = wrapper(EmptyEnvWithExtraObs())
|
|
|
+ env2 = wrapper(gym.make('MiniGrid-Empty-5x5-v0'))
|
|
|
+
|
|
|
+ env1.seed(0)
|
|
|
+ env2.seed(0)
|
|
|
+
|
|
|
+ obs1 = env1.reset()
|
|
|
+ obs2 = env2.reset()
|
|
|
+ assert 'size' in obs1
|
|
|
+ assert obs1['size'].shape == (2,)
|
|
|
+ assert (obs1['size'] == [5,5]).all()
|
|
|
+ for key in obs2:
|
|
|
+ assert np.array_equal(obs1[key], obs2[key])
|
|
|
+
|
|
|
+ obs1, reward1, done1, _ = env1.step(0)
|
|
|
+ obs2, reward2, done2, _ = env2.step(0)
|
|
|
+ assert 'size' in obs1
|
|
|
+ assert obs1['size'].shape == (2,)
|
|
|
+ assert (obs1['size'] == [5,5]).all()
|
|
|
+ for key in obs2:
|
|
|
+ assert np.array_equal(obs1[key], obs2[key])
|
|
|
+
|
|
|
+##############################################################################
|
|
|
+
|
|
|
print('testing agent_sees method')
|
|
|
env = gym.make('MiniGrid-DoorKey-6x6-v0')
|
|
|
goal_pos = (env.grid.width - 2, env.grid.height - 2)
|