|
@@ -569,19 +569,17 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
(2835,)
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, env, maxStrLen=96):
|
|
|
+ def __init__(self, env, maxStrLen: int = 96):
|
|
|
super().__init__(env)
|
|
|
|
|
|
self.maxStrLen = maxStrLen
|
|
|
self.numCharCodes = 28
|
|
|
|
|
|
- imgSpace = env.observation_space.spaces["image"]
|
|
|
- imgSize = reduce(operator.mul, imgSpace.shape, 1)
|
|
|
-
|
|
|
+ img_size = np.prod(env.observation_space["image"].shape)
|
|
|
self.observation_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
- shape=(imgSize + self.numCharCodes * self.maxStrLen,),
|
|
|
+ shape=(img_size + self.numCharCodes * self.maxStrLen,),
|
|
|
dtype="uint8",
|
|
|
)
|
|
|
|
|
@@ -598,12 +596,11 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
), f"mission string too long ({len(mission)} chars)"
|
|
|
mission = mission.lower()
|
|
|
|
|
|
- strArray = np.zeros(
|
|
|
- shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
|
|
|
- )
|
|
|
+ str_array = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="uint8")
|
|
|
+ # as `numCharCodes` < 255 then we can use `uint8`
|
|
|
|
|
|
for idx, ch in enumerate(mission):
|
|
|
- if ch >= "a" and ch <= "z":
|
|
|
+ if "a" <= ch <= "z":
|
|
|
chNo = ord(ch) - ord("a")
|
|
|
elif ch == " ":
|
|
|
chNo = ord("z") - ord("a") + 1
|
|
@@ -613,11 +610,11 @@ class FlatObsWrapper(ObservationWrapper):
|
|
|
raise ValueError(
|
|
|
f"Character {ch} is not available in mission string."
|
|
|
)
|
|
|
- assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
|
|
|
- strArray[idx, chNo] = 1
|
|
|
+ assert chNo < self.numCharCodes, f"{ch} : {chNo:d}"
|
|
|
+ str_array[idx, chNo] = 1
|
|
|
|
|
|
self.cachedStr = mission
|
|
|
- self.cachedArray = strArray
|
|
|
+ self.cachedArray = str_array
|
|
|
|
|
|
obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
|
|
|
|