import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (env.agentPos, env.agentDir, action)

        # Get the count for this (s,a) pair
        preCnt = 0
        if tup in self.counts:
            preCnt = self.counts[tup]

        # Update the count for this (s,a) pair
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        # Bonus decays as the pair is visited more often
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agentPos)  # convert to a tuple so the position is hashable

        # Get the count for this key
        preCnt = 0
        if tup in self.counts:
            preCnt = self.counts[tup]

        # Update the count for this key
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        # Bonus decays as the position is visited more often
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output
    """

    def __init__(self, env):
        super().__init__(env)
        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        return obs['image']


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        super().__init__(env)
        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
        self.observation_space = spaces.Box(
            low=0,
            high=255,  # the agent marker below writes 255, so the bound must cover it
            shape=(self.env.grid_size, self.env.grid_size, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        full_grid = self.env.grid.encode()
        # Mark the agent's cell with a distinct value and its facing direction
        full_grid[self.env.agent_pos[0]][self.env.agent_pos[1]] = np.array([255, self.env.agent_dir, 0])
        return full_grid


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array
    """

    def __init__(self, env, maxStrLen=64):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 'a'-'z' plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, 'mission string too long'
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError('unsupported character in mission string: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
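

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the wrapper implementations
# above). It assumes a MiniGrid-style environment id such as
# 'MiniGrid-Empty-8x8-v0' is registered by a separate package
# (e.g. gym_minigrid); swap in whatever environment id and registration
# import your setup actually provides.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import gym_minigrid  # noqa: F401 -- assumed package that registers the envs

    # Build a base environment, then stack wrappers from this module
    env = gym.make('MiniGrid-Empty-8x8-v0')  # assumed env id
    env = StateBonus(env)                    # add count-based exploration bonus
    env = FlatObsWrapper(env)                # flatten image + one-hot mission

    obs = env.reset()
    print('flat observation shape:', obs.shape)

    # Take a few random steps; the reward includes the 1/sqrt(count) bonus
    for _ in range(5):
        obs, reward, done, info = env.step(env.action_space.sample())
        print('reward (with exploration bonus):', reward)
        if done:
            obs = env.reset()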