import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils
from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX


class ReseedWrapper(gym.core.Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    The seeds are used in order, one per call to reset(), starting at
    `seed_idx` and wrapping around at the end of the list.
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        # Immutable default; always copied so the caller's sequence is
        # never shared with (or mutated through) the wrapper
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(self, **kwargs):
        """Seed the wrapped env with the next seed in the cycle, then reset it."""
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        self.env.seed(seed)
        return self.env.reset(**kwargs)

    def step(self, action):
        """Forward the step to the wrapped env unchanged."""
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state,action) pairs.

    The bonus is 1 / sqrt(N), where N is the number of times the
    (position, direction, action) key has been seen so far. Counts are
    kept for the lifetime of the wrapper; reset() does not clear them.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # The key is built from the agent state read *after* the step
        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Update the count for this (s,a) pair
        new_count = self.counts.get(tup, 0) + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    The bonus is 1 / sqrt(N), where N is the number of times the agent
    position has been seen so far. Counts are kept for the lifetime of
    the wrapper; reset() does not clear them.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Update the count for this key
        new_count = self.counts.get(tup, 0) + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']


class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Each cell of the input image holds (object type, color, state)
    indices; each is expanded to its own one-hot segment, and the three
    segments are laid out back to back along the last axis.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space['image'].shape

        # Number of bits per cell
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0], obs_shape[1], num_bits),
            dtype='uint8'
        )

    def observation(self, obs):
        img = obs['image']

        # Size the output from the *image* sub-space; the enclosing
        # observation space is a dict-like container and does not carry
        # the image shape itself
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        # Offsets of the color and state segments along the last axis
        color_offset = len(OBJECT_TO_IDX)
        state_offset = color_offset + len(COLOR_TO_IDX)

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                obj_type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, obj_type] = 1
                out[i, j, color_offset + color] = 1
                out[i, j, state_offset + state] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }


class RGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as the 'image' observation.
    The 'mission' entry of the observation is passed through unchanged.
    This can be used to have the agent to solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width * tile_size, self.env.height * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img = env.render(
            mode='rgb_array',
            highlight=False,
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img
        }


class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as the 'image'
    observation. The 'mission' entry is passed through unchanged.
    This can be used to have the agent to solve the gridworld in pixel space.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space['image'].shape
        self.observation_space.spaces['image'] = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped

        rgb_img_partial = env.get_obs_render(
            obs['image'],
            tile_size=self.tile_size
        )

        return {
            'mission': obs['mission'],
            'image': rgb_img_partial
        }


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        super().__init__(env)

        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Overwrite the agent's cell with an (agent, red, direction) triple
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])

        return {
            'mission': obs['mission'],
            'image': full_grid
        }


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array

    The mission is lower-cased and each character is mapped to one of
    27 codes: 'a'-'z' -> 0-25, space -> 26. Any other character is left
    as an all-zero row.
    """

    def __init__(self, env, maxStrLen=96):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        # NOTE(review): the declared space is (1, N) uint8, while
        # observation() returns a 1-D array built from a float32 mission
        # encoding. Left unchanged here because callers may rely on the
        # declared shape -- confirm before changing.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission))
            lowered = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(lowered):
                if 'a' <= ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Unsupported character (digit, punctuation, ...):
                    # leave its row all-zero instead of reusing a stale
                    # code from the previous character
                    continue
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            # Key the cache on the raw mission string, i.e. the same
            # value it is compared against above; storing the lower-cased
            # copy would defeat the cache for mixed-case missions
            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs


class ViewSizeWrapper(gym.core.Wrapper):
    """
    Wrapper to customize the agent field of view size.
    This cannot be used with fully observable wrappers.
    """

    def __init__(self, env, agent_view_size=7):
        super().__init__(env)

        # Override default view size
        env.unwrapped.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(agent_view_size, agent_view_size, 3),
            dtype='uint8'
        )

        # Override the environment's observation space
        # (only the 'image' entry is declared here)
        self.observation_space = spaces.Dict({
            'image': observation_space
        })

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)