浏览代码

np.ndarray agent_pos type annotation

Rodrigo Perez-Vicente 2 年之前
父节点
当前提交
36e05852a8

+ 3 - 1
gym_minigrid/envs/crossing.py

@@ -1,5 +1,7 @@
 import itertools as itt
 
+import numpy as np
+
 from gym_minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, Wall
 from gym_minigrid.register import register
 
@@ -30,7 +32,7 @@ class CrossingEnv(MiniGridEnv):
         self.grid.wall_rect(0, 0, width, height)
 
         # Place the agent in the top-left corner
-        self.agent_pos = (1, 1)
+        self.agent_pos = np.array((1, 1))
         self.agent_dir = 0
 
         # Place a goal square in the bottom-right corner

+ 1 - 1
gym_minigrid/envs/lavagap.py

@@ -30,7 +30,7 @@ class LavaGapEnv(MiniGridEnv):
         self.grid.wall_rect(0, 0, width, height)
 
         # Place the agent in the top-left corner
-        self.agent_pos = (1, 1)
+        self.agent_pos = np.array((1, 1))
         self.agent_dir = 0
 
         # Place a goal square in the bottom-right corner

+ 3 - 1
gym_minigrid/envs/memory.py

@@ -1,3 +1,5 @@
+import numpy as np
+
 from gym_minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, Wall
 from gym_minigrid.register import register
 
@@ -58,7 +60,7 @@ class MemoryEnv(MiniGridEnv):
             self.grid.set(hallway_end + 2, j, Wall())
 
         # Fix the player's start position and orientation
-        self.agent_pos = (self._rand_int(1, hallway_end + 1), height // 2)
+        self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
         self.agent_dir = 0
 
         # Place objects

+ 4 - 7
gym_minigrid/minigrid.py

@@ -5,7 +5,6 @@ from enum import IntEnum
 
 import gym
 import numpy as np
-import numpy.typing as npt
 from gym import spaces
 
 # Size in pixels of a tile in the full-scale human view
@@ -712,8 +711,8 @@ class MiniGridEnv(gym.Env):
         self.see_through_walls = see_through_walls
 
         # Current position and direction of the agent
-        self.agent_pos = None
-        self.agent_dir = None
+        self.agent_pos: np.ndarray = None
+        self.agent_dir: int = None
 
         # Initialize the state
         self.reset()
@@ -721,9 +720,8 @@ class MiniGridEnv(gym.Env):
     def reset(self, *, seed=None, return_info=False, options=None):
         super().reset(seed=seed)
         # Current position and direction of the agent
-        NDArrayInt = npt.NDArray[np.int_]
-        self.agent_pos: NDArrayInt = None
-        self.agent_dir: int = None
+        self.agent_pos = None
+        self.agent_dir = None
 
         # Generate a new random grid at the start of each episode
         self._gen_grid(self.width, self.height)
@@ -1122,7 +1120,6 @@ class MiniGridEnv(gym.Env):
                 reward = self._reward()
             if fwd_cell is not None and fwd_cell.type == "lava":
                 done = True
-
         # Pick up an object
         elif action == self.actions.pickup:
             if fwd_cell and fwd_cell.can_pickup():

+ 7 - 3
gym_minigrid/roomgrid.py

@@ -1,3 +1,5 @@
+import numpy as np
+
 from gym_minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
 
 
@@ -163,9 +165,11 @@ class RoomGrid(MiniGridEnv):
                     room.door_pos[3] = room.neighbors[3].door_pos[1]
 
         # The agent starts in the middle, facing right
-        self.agent_pos = (
-            (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
-            (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
+        self.agent_pos = np.array(
+            (
+                (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
+                (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
+            )
         )
         self.agent_dir = 0
 
 

+ 9 - 9
gym_minigrid/wrappers.py

@@ -102,7 +102,7 @@ class StateBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
 
 
-class ImgObsWrapper(gym.core.ObservationWrapper):
+class ImgObsWrapper(gym.ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
@@ -115,7 +115,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
         return obs["image"]
 
 
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
@@ -155,7 +155,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": out}
 
 
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
+class RGBImgObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -187,7 +187,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": rgb_img}
 
 
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -218,7 +218,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": rgb_img_partial}
 
 
-class FullyObsWrapper(gym.core.ObservationWrapper):
+class FullyObsWrapper(gym.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
     """
@@ -247,7 +247,7 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": full_grid}
 
 
-class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
     """
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
@@ -365,7 +365,7 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
         return obs
 
 
-class FlatObsWrapper(gym.core.ObservationWrapper):
+class FlatObsWrapper(gym.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
@@ -459,7 +459,7 @@ class ViewSizeWrapper(gym.Wrapper):
         return {**obs, "image": image}
 
 
-class DirectionObsWrapper(gym.core.ObservationWrapper):
+class DirectionObsWrapper(gym.ObservationWrapper):
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     type = {slope , angle}
@@ -493,7 +493,7 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
         return obs
 
 
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
+class SymbolicObsWrapper(gym.ObservationWrapper):
     """
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are