import math
import operator
from functools import reduce

import numpy as np

import gym
from gym import spaces

 
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus to the reward:
    a small extra reward that encourages exploration of less
    visited (state, action) pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        env = self.unwrapped
        tup = (env.agentPos, env.agentDir, action)

        # Get the count for this (s,a) pair
        preCnt = self.counts.get(tup, 0)

        # Update the count for this (s,a) pair
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        # The bonus decays as the pair is visited more often
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
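# A minimal usage sketch. The environment id below is an assumption: any
# MiniGrid environment registered by importing gym_minigrid (and exposing
# agentPos/agentDir) would work the same way.
#
#     import gym_minigrid  # noqa: F401 -- registers the MiniGrid-* env ids
#     env = ActionBonus(gym.make('MiniGrid-Empty-8x8-v0'))
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
#     # reward now includes the 1 / sqrt(N(s, a)) bonus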
 
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.
    """

    def __init__(self, env):
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after the environment update
        env = self.unwrapped
        tup = tuple(env.agentPos)  # tuple() makes the position hashable

        # Get the count for this key
        preCnt = self.counts.get(tup, 0)

        # Update the count for this key
        newCnt = preCnt + 1
        self.counts[tup] = newCnt

        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
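# Note: StateBonus counts visits per grid position only, whereas ActionBonus
# above counts (position, direction, action) triples. Both add 1 / sqrt(N)
# to the reward on every step, so the bonus shrinks with repeated visits:
# 1.0 on the first visit, 0.5 on the fourth, 0.1 on the hundredth.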
 
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.
    """

    def __init__(self, env, maxStrLen=64):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27  # 26 letters plus the space character

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, "mission string too long"
            mission = mission.lower()

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    raise ValueError('unsupported character in mission string: %s' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs
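# A small runnable smoke test sketching how the wrappers compose. The
# gym_minigrid import and the environment id are assumptions; substitute
# any registered MiniGrid environment.
if __name__ == '__main__':
    import gym_minigrid  # noqa: F401 -- registers the MiniGrid-* env ids

    env = gym.make('MiniGrid-Empty-8x8-v0')
    env = StateBonus(ActionBonus(env))  # count-based exploration bonuses
    env = FlatObsWrapper(env)           # flat image + one-hot mission vector

    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())
    print(obs.shape, reward, done)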
 
 