123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- import math
- import operator
- from functools import reduce
- import numpy as np
- import gym
- from gym import error, spaces, utils
class ActionBonus(gym.core.Wrapper):
    """
    Reward-shaping wrapper that pays an exploration bonus.

    Each step is keyed on (agent position, agent direction, action), where
    position and direction are read from the unwrapped environment after
    the step has been applied. The bonus is 1 / sqrt(n), with n the number
    of times that key has been seen so far (including the current step),
    so rarely seen keys earn more.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counter, one entry per (position, direction, action) key
        self.counts = {}

    def step(self, action):
        """
        Step the wrapped env and add the exploration bonus to its reward.
        """
        obs, reward, done, info = self.env.step(action)

        # NOTE(review): FullyObsWrapper in this file reads agent_pos /
        # agent_dir, while this class reads agentPos / agentDir -- confirm
        # which spelling the underlying environment actually exposes.
        base_env = self.unwrapped
        key = (base_env.agentPos, base_env.agentDir, action)

        # Unseen keys start from zero visits
        visits = self.counts.get(key, 0) + 1
        self.counts[key] = visits

        reward += 1 / math.sqrt(visits)

        return obs, reward, done, info
class StateBonus(gym.core.Wrapper):
    """
    Reward-shaping wrapper that pays a bonus for visiting grid positions.

    The agent position is read from the unwrapped environment after the
    step has been applied. The bonus is 1 / sqrt(n), with n the number of
    times that position has been seen so far (including the current step).
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counter, one entry per agent position
        self.counts = {}

    def step(self, action):
        """
        Step the wrapped env and add the position bonus to its reward.
        """
        obs, reward, done, info = self.env.step(action)

        # The counter is indexed by the position reached after the update.
        # NOTE(review): FullyObsWrapper in this file reads agent_pos, while
        # this class reads agentPos -- confirm which spelling the underlying
        # environment actually exposes.
        key = self.unwrapped.agentPos

        # Unseen positions start from zero visits
        visits = self.counts.get(key, 0) + 1
        self.counts[key] = visits

        reward += 1 / math.sqrt(visits)

        return obs, reward, done, info
class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Expose only the 'image' entry of the wrapped env's dict observation,
    dropping every other key.
    """

    def __init__(self, env):
        super().__init__(env)
        # hack to pass values to super wrapper: copy the wrapped env's
        # instance attributes onto this wrapper
        self.__dict__.update(vars(env))
        # The advertised space is just the image sub-space of the wrapped env
        image_space = env.observation_space.spaces['image']
        self.observation_space = image_space

    def observation(self, obs):
        """
        Return the 'image' entry of the observation dict.
        """
        image = obs['image']
        return image
class FullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding

    The observation is the encoded grid of the wrapped environment with the
    agent's cell overwritten by [255, agent_dir, 0].
    """

    def __init__(self, env):
        super().__init__(env)
        self.__dict__.update(vars(env))  # hack to pass values to super wrapper
        grid_size = self.env.grid_size
        self.observation_space = spaces.Box(
            low=0,
            # observation() writes 255 into the agent's cell, so the upper
            # bound must be the uint8 maximum rather than the grid size;
            # otherwise returned observations fall outside the declared space.
            high=255,
            shape=(grid_size, grid_size, 3),  # number of cells
            dtype='uint8'
        )

    def observation(self, obs):
        """
        Return the encoded full grid with the agent's cell marked.

        The incoming partial observation is ignored; the grid is re-encoded
        from the wrapped environment on every call.
        """
        full_grid = self.env.grid.encode()
        agent_pos = self.env.agent_pos
        # Mark the agent: 255 in the first channel, its direction in the second
        full_grid[agent_pos[0]][agent_pos[1]] = np.array(
            [255, self.env.agent_dir, 0]
        )
        return full_grid
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array

    The mission is lower-cased and each character is mapped to one of
    numCharCodes codes: 'a'-'z' map to 0-25 and the space character to 26.
    """

    def __init__(self, env, maxStrLen=64):
        """
        :param env: environment whose observations are dicts with an
            'image' array and a 'mission' string
        :param maxStrLen: maximum supported mission length, in characters
        """
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        # NOTE(review): the declared space is (1, N) uint8, while
        # observation() returns a 1-D array whose dtype follows
        # np.concatenate of the image with a float32 encoding -- confirm
        # whether consumers rely on the declared shape/dtype before
        # changing either side.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        # Last mission string seen (as received) and its one-hot encoding
        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        """
        Flatten the image and append the one-hot encoded mission string.

        :raises ValueError: if the mission contains a character other than
            a letter a-z (any case) or a space
        """
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string. The cache is keyed on the
        # mission exactly as received; keying on the lower-cased text would
        # make every mission containing an uppercase letter miss the cache.
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, "mission string too long"

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes),
                dtype='float32'
            )

            for idx, ch in enumerate(mission.lower()):
                if 'a' <= ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Without this branch chNo would be unbound (first
                    # character) or silently reuse the previous character's
                    # code, corrupting the encoding.
                    raise ValueError(
                        'unsupported character %r in mission string' % ch
                    )
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
|