浏览代码

Fix `FlatObsWrapper` obs dtype (#435)

Mark Towers 10 月之前
父节点
当前提交
6762cb1f5c
共有 1 个文件被更改,包括 9 次插入12 次删除
  1. 9 12
      minigrid/wrappers.py

+ 9 - 12
minigrid/wrappers.py

@@ -569,19 +569,17 @@ class FlatObsWrapper(ObservationWrapper):
         (2835,)
     """
 
-    def __init__(self, env, maxStrLen=96):
+    def __init__(self, env, maxStrLen: int = 96):
         super().__init__(env)
 
         self.maxStrLen = maxStrLen
         self.numCharCodes = 28
 
-        imgSpace = env.observation_space.spaces["image"]
-        imgSize = reduce(operator.mul, imgSpace.shape, 1)
-
+        img_size = np.prod(env.observation_space["image"].shape)
         self.observation_space = spaces.Box(
             low=0,
             high=255,
-            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
+            shape=(img_size + self.numCharCodes * self.maxStrLen,),
             dtype="uint8",
         )
 
@@ -598,12 +596,11 @@ class FlatObsWrapper(ObservationWrapper):
             ), f"mission string too long ({len(mission)} chars)"
             mission = mission.lower()
 
-            strArray = np.zeros(
-                shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
-            )
+            str_array = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="uint8")
+            # as `numCharCodes` < 255 then we can use `uint8`
 
             for idx, ch in enumerate(mission):
-                if ch >= "a" and ch <= "z":
+                if "a" <= ch <= "z":
                     chNo = ord(ch) - ord("a")
                 elif ch == " ":
                     chNo = ord("z") - ord("a") + 1
@@ -613,11 +610,11 @@ class FlatObsWrapper(ObservationWrapper):
                     raise ValueError(
                         f"Character {ch} is not available in mission string."
                     )
-                assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
-                strArray[idx, chNo] = 1
+                assert chNo < self.numCharCodes, f"{ch} : {chNo:d}"
+                str_array[idx, chNo] = 1
 
             self.cachedStr = mission
-            self.cachedArray = strArray
+            self.cachedArray = str_array
 
         obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))