|
@@ -1,5 +1,11 @@
|
|
|
import math
|
|
|
+import operator
|
|
|
+from functools import reduce
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+
|
|
|
import gym
|
|
|
+from gym import error, spaces, utils
|
|
|
|
|
|
class ActionBonus(gym.core.Wrapper):
|
|
|
"""
|
|
@@ -67,3 +73,52 @@ class StateBonus(gym.core.Wrapper):
|
|
|
reward += bonus
|
|
|
|
|
|
return obs, reward, done, info
|
|
|
+
|
|
|
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array

    The mission string is lower-cased and every character is mapped to one
    of `numCharCodes` codes: 'a'..'z' -> 0..25 and ' ' -> 26. The encoding
    is a (maxStrLen, numCharCodes) matrix whose rows past the end of the
    mission stay all-zero. The flattened image comes first in the output,
    followed by the flattened one-hot matrix.
    """

    def __init__(self, env, maxStrLen=48):
        """
        :param env: environment to wrap; its observations are indexed with
            the keys 'image' and 'mission'
        :param maxStrLen: maximum supported mission length, in characters
        """
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 27

        # Number of scalar entries in the wrapped observation space
        obsSize = reduce(operator.mul, self.observation_space.shape, 1)

        # `shape` must be a tuple: a bare int is not a valid Box shape
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(obsSize + self.numCharCodes * self.maxStrLen,)
        )

        # Last mission string seen (as received) and its one-hot encoding
        self.cachedStr = None
        self.cachedArray = None

    def _observation(self, obs):
        """
        Flatten an observation into a single 1-D array.

        :param obs: mapping with an 'image' array and a 'mission' string
        :return: the flattened image followed by the flattened one-hot
            mission encoding
        :raises AssertionError: if the mission is longer than `maxStrLen`
        :raises ValueError: if the mission contains a character other than
            a letter in a-z/A-Z or a space
        """
        image = obs['image']
        mission = obs['mission']

        # Cache the last-encoded mission string. The cache key is the
        # mission exactly as received, so a mission containing uppercase
        # letters still hits the cache on the next call.
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, "mission string too long"

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes))

            for idx, ch in enumerate(mission.lower()):
                if 'a' <= ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Without this branch chNo would be unbound (first
                    # character) or stale from the previous iteration,
                    # silently producing a wrong encoding.
                    raise ValueError(
                        'unsupported character %r in mission string' % ch
                    )
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.hstack((image.flatten(), self.cachedArray.flatten()))

        return obs
|