import math
import operator
from functools import reduce

import numpy as np

import gym
from gym import error, spaces, utils


class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state,action) pairs.

    The bonus added on each step is 1 / sqrt(N), where N is the number of
    times this wrapper has seen the (agent position, agent direction,
    action) triple, counting the current step.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts keyed by (position, direction, action)
        self.counts = {}

    def step(self, action):
        """Step the wrapped env and add the exploration bonus to the reward."""
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        # Attribute names follow FullyObsWrapper (agent_pos / agent_dir).
        # tuple() keeps the key hashable even if the position is an array.
        key = (tuple(env.agent_pos), env.agent_dir, action)

        # Update the count for this (s,a) pair
        visits = self.counts.get(key, 0) + 1
        self.counts[key] = visits

        reward += 1 / math.sqrt(visits)

        return obs, reward, done, info


class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    The bonus added on each step is 1 / sqrt(N), where N is the number of
    times this wrapper has seen the agent position, counting the current
    step.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts keyed by agent position
        self.counts = {}

    def step(self, action):
        """Step the wrapped env and add the exploration bonus to the reward."""
        obs, reward, done, info = self.env.step(action)

        # Key based on which we index the counts.
        # We use the position after the update; tuple() makes it a real,
        # hashable tuple (a bare parenthesised value is not a tuple).
        env = self.unwrapped
        key = tuple(env.agent_pos)

        # Update the count for this key
        visits = self.counts.get(key, 0) + 1
        self.counts[key] = visits

        reward += 1 / math.sqrt(visits)

        return obs, reward, done, info


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.
    """

    def __init__(self, env):
        super().__init__(env)
        # Hack to pass values to super wrapper: copies every instance
        # attribute of `env` onto this wrapper.
        self.__dict__.update(vars(env))
        self.observation_space = env.observation_space.spaces['image']

    def observation(self, obs):
        """Return only the 'image' entry of the observation dict."""
        return obs['image']


class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding
    """

    def __init__(self, env):
        super().__init__(env)
        # Hack to pass values to super wrapper: copies every instance
        # attribute of `env` onto this wrapper.
        self.__dict__.update(vars(env))
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        """
        Return the encoded full grid with the agent's cell overwritten.

        The incoming `obs` is ignored; the result is built from the wrapped
        env's grid. The agent's cell is set to [255, agent_dir, 0].
        """
        env = self.env
        full_grid = env.grid.encode()
        i, j = env.agent_pos[0], env.agent_pos[1]
        full_grid[i][j] = np.array([255, env.agent_dir, 0])
        return full_grid


class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array

    Each of the first `maxStrLen` character slots gets a one-hot row of
    `numCharCodes` entries: 'a'-'z' map to codes 0-25 and ' ' maps to 26.
    """

    def __init__(self, env, maxStrLen=64):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        # NOTE(review): the declared space is (1, N) uint8 while
        # observation() returns a 1-D array of length N that includes a
        # float32 mission encoding. Left unchanged because callers may
        # depend on either; confirm which one is intended.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        # Last mission string seen (as received) and its one-hot encoding
        self.cachedStr = None
        self.cachedArray = None

    def _encode_mission(self, mission):
        """One-hot encode `mission` into a (maxStrLen, numCharCodes) array."""
        text = mission.lower()
        assert len(text) <= self.maxStrLen, "mission string too long"

        encoded = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')
        space_code = ord('z') - ord('a') + 1

        for idx, ch in enumerate(text):
            if 'a' <= ch <= 'z':
                code = ord(ch) - ord('a')
            elif ch == ' ':
                code = space_code
            else:
                # Unsupported character (digit, punctuation, ...): leave
                # its row all-zero rather than reusing a stale code.
                continue
            assert code < self.numCharCodes, '%s : %d' % (ch, code)
            encoded[idx, code] = 1

        return encoded

    def observation(self, obs):
        """
        Flatten the image and the encoded mission into a single 1-D array.

        The mission encoding is cached and only recomputed when the mission
        string differs from the previous one.
        """
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string. The string is stored as
        # received (not lowercased) so the comparison above actually hits,
        # and only after encoding succeeded so a failure cannot leave a
        # mismatched string/array pair behind.
        if mission != self.cachedStr:
            self.cachedArray = self._encode_mission(mission)
            self.cachedStr = mission

        return np.concatenate((image.flatten(), self.cachedArray.flatten()))