|
@@ -3,9 +3,9 @@ import operator
|
|
|
from functools import reduce
|
|
|
|
|
|
import numpy as np
|
|
|
-
|
|
|
import gym
|
|
|
from gym import error, spaces, utils
|
|
|
+from .minigrid import OBJECT_TO_IDX
|
|
|
|
|
|
class ActionBonus(gym.core.Wrapper):
|
|
|
"""
|
|
@@ -111,7 +111,12 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
|
|
|
def observation(self, obs):
|
|
|
env = self.unwrapped
|
|
|
full_grid = env.grid.encode()
|
|
|
- full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([255, env.agent_dir, 0])
|
|
|
+ full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
|
|
|
+ OBJECT_TO_IDX['agent'],
|
|
|
+ env.agent_dir,
|
|
|
+ 0
|
|
|
+ ])
|
|
|
+
|
|
|
return full_grid
|
|
|
|
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
|