import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils

from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX

class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.
    """

    def __init__(self, env, seeds=[0], seed_idx=0):
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
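
# Example usage (a sketch; assumes a registered MiniGrid id such as
# 'MiniGrid-Empty-8x8-v0'):
#
#     env = ReseedWrapper(gym.make('MiniGrid-Empty-8x8-v0'), seeds=[0, 1])
#     env.reset()  # layout generated from seed 0
#     env.reset()  # layout generated from seed 1
#     env.reset()  # back to seed 0, and so on
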
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
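
# The bonus decays as 1/sqrt(N), where N counts visits to the
# (position, direction, action) triple: the first visit adds 1.0 to the
# reward, the second ~0.707, the tenth ~0.316. Sketch:
#
#     env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#     env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
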
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
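
# Same decay schedule as ActionBonus, but keyed on the agent position
# alone, so revisits to a cell share one count regardless of direction
# or action:
#
#     env = StateBonus(gym.make('MiniGrid-Empty-8x8-v0'))
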
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']
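
# Sketch: unwraps the observation dict so obs is the raw image array
# rather than {'image': ..., 'direction': ..., 'mission': ...}, which is
# the form most standard image-based RL code expects:
#
#     env = ImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # numpy array, shape (7, 7, 3) by default
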
class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size
        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )

    def observation(self, obs):
        img = obs['image']
        # Index the image space explicitly: self.observation_space is a
        # Dict space, which has no .shape of its own
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }
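
# Each cell's (type, color, state) triple becomes three concatenated
# one-hot segments, so the default 7x7 view yields an array of shape
# (7, 7, num_bits). With the stock MiniGrid tables that is
# 11 + 6 + 3 = 20 bits, but num_bits is computed from the tables, so it
# tracks any changes to them:
#
#     env = OneHotPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
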
class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as the 'image' observation.
    This can be used to have the agent solve the gridworld in pixel space.
    Combine with ImgObsWrapper to also drop the language/mission.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        rgb_img = env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }
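
# Renders the whole grid at tile_size pixels per cell, so an 8x8 grid
# with tile_size=8 gives a 64x64x3 image. A common pixel-space pipeline
# stacks this with ImgObsWrapper to strip the mission string:
#
#     env = ImgObsWrapper(RGBImgObsWrapper(gym.make('MiniGrid-Empty-8x8-v0')))
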
class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as the 'image'
    observation. This can be used to have the agent solve the gridworld
    in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size
        obs_shape = env.observation_space['image'].shape

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }
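
# Same idea as RGBImgObsWrapper, but only the agent's egocentric view is
# rendered: the default 7x7 view at tile_size=8 gives a 56x56x3 image.
#
#     env = ImgObsWrapper(RGBImgPartialObsWrapper(gym.make('MiniGrid-Empty-8x8-v0')))
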
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.
    """

    def __init__(self, env):
        super().__init__(env)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }
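
# The image here is a symbolic (width, height, 3) encoding, not pixels:
# each cell holds (object index, color index, state), and the agent's
# cell is overwritten with the agent marker, carrying its direction in
# the state channel:
#
#     env = FullyObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
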
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        # The observation is a single flat vector, so the space is 1-D
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Without this branch, an unexpected character would
                    # leave chNo unbound and crash with an opaque error
                    raise ValueError('mission contains unsupported character: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
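
# The flat vector has length imgSize + 27 * maxStrLen: the 27 character
# codes cover 'a'-'z' plus space. With the default 7x7x3 partial view
# and maxStrLen=96 that is 147 + 2592 = 2739 entries:
#
#     env = FlatObsWrapper(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()  # shape (2739,)
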
class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # The egocentric view must be an odd-sized square so the agent
        # sits on its bottom-center cell
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)
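
# Sketch: widen or shrink the egocentric view, then stack other partial
# observation wrappers on top:
#
#     env = ViewSizeWrapper(gym.make('MiniGrid-Empty-8x8-v0'), agent_view_size=5)
#     obs = env.reset()
#     obs['image'].shape  # (5, 5, 3)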