Browse Source

Annotate agent_pos as np.ndarray and construct it with np.array

Rodrigo Perez-Vicente 2 years ago
parent
commit
36e05852a8

+ 3 - 1
gym_minigrid/envs/crossing.py

@@ -1,5 +1,7 @@
 import itertools as itt
 
+import numpy as np
+
 from gym_minigrid.minigrid import Goal, Grid, Lava, MiniGridEnv, Wall
 from gym_minigrid.register import register
 
@@ -30,7 +32,7 @@ class CrossingEnv(MiniGridEnv):
         self.grid.wall_rect(0, 0, width, height)
 
         # Place the agent in the top-left corner
-        self.agent_pos = (1, 1)
+        self.agent_pos = np.array((1, 1))
         self.agent_dir = 0
 
         # Place a goal square in the bottom-right corner

+ 1 - 1
gym_minigrid/envs/lavagap.py

@@ -30,7 +30,7 @@ class LavaGapEnv(MiniGridEnv):
         self.grid.wall_rect(0, 0, width, height)
 
         # Place the agent in the top-left corner
-        self.agent_pos = (1, 1)
+        self.agent_pos = np.array((1, 1))
         self.agent_dir = 0
 
         # Place a goal square in the bottom-right corner

+ 3 - 1
gym_minigrid/envs/memory.py

@@ -1,3 +1,5 @@
+import numpy as np
+
 from gym_minigrid.minigrid import Ball, Grid, Key, MiniGridEnv, Wall
 from gym_minigrid.register import register
 
@@ -58,7 +60,7 @@ class MemoryEnv(MiniGridEnv):
             self.grid.set(hallway_end + 2, j, Wall())
 
         # Fix the player's start position and orientation
-        self.agent_pos = (self._rand_int(1, hallway_end + 1), height // 2)
+        self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
         self.agent_dir = 0
 
         # Place objects

+ 4 - 7
gym_minigrid/minigrid.py

@@ -5,7 +5,6 @@ from enum import IntEnum
 
 import gym
 import numpy as np
-import numpy.typing as npt
 from gym import spaces
 
 # Size in pixels of a tile in the full-scale human view
@@ -712,8 +711,8 @@ class MiniGridEnv(gym.Env):
         self.see_through_walls = see_through_walls
 
         # Current position and direction of the agent
-        self.agent_pos = None
-        self.agent_dir = None
+        self.agent_pos: np.ndarray = None
+        self.agent_dir: int = None
 
         # Initialize the state
         self.reset()
@@ -721,9 +720,8 @@ class MiniGridEnv(gym.Env):
     def reset(self, *, seed=None, return_info=False, options=None):
         super().reset(seed=seed)
         # Current position and direction of the agent
-        NDArrayInt = npt.NDArray[np.int_]
-        self.agent_pos: NDArrayInt = None
-        self.agent_dir: int = None
+        self.agent_pos = None
+        self.agent_dir = None
 
         # Generate a new random grid at the start of each episode
         self._gen_grid(self.width, self.height)
@@ -1122,7 +1120,6 @@ class MiniGridEnv(gym.Env):
                 reward = self._reward()
             if fwd_cell is not None and fwd_cell.type == "lava":
                 done = True
-
         # Pick up an object
         elif action == self.actions.pickup:
             if fwd_cell and fwd_cell.can_pickup():

+ 7 - 3
gym_minigrid/roomgrid.py

@@ -1,3 +1,5 @@
+import numpy as np
+
 from gym_minigrid.minigrid import COLOR_NAMES, Ball, Box, Door, Grid, Key, MiniGridEnv
 
 
@@ -163,9 +165,11 @@ class RoomGrid(MiniGridEnv):
                     room.door_pos[3] = room.neighbors[3].door_pos[1]
 
         # The agent starts in the middle, facing right
-        self.agent_pos = (
-            (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
-            (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
+        self.agent_pos = np.array(
+            (
+                (self.num_cols // 2) * (self.room_size - 1) + (self.room_size // 2),
+                (self.num_rows // 2) * (self.room_size - 1) + (self.room_size // 2),
+            )
         )
         self.agent_dir = 0
 

+ 9 - 9
gym_minigrid/wrappers.py

@@ -102,7 +102,7 @@ class StateBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
 
 
-class ImgObsWrapper(gym.core.ObservationWrapper):
+class ImgObsWrapper(gym.ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
@@ -115,7 +115,7 @@ class ImgObsWrapper(gym.core.ObservationWrapper):
         return obs["image"]
 
 
-class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
+class OneHotPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
@@ -155,7 +155,7 @@ class OneHotPartialObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": out}
 
 
-class RGBImgObsWrapper(gym.core.ObservationWrapper):
+class RGBImgObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -187,7 +187,7 @@ class RGBImgObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": rgb_img}
 
 
-class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
+class RGBImgPartialObsWrapper(gym.ObservationWrapper):
     """
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -218,7 +218,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": rgb_img_partial}
 
 
-class FullyObsWrapper(gym.core.ObservationWrapper):
+class FullyObsWrapper(gym.ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
     """
@@ -247,7 +247,7 @@ class FullyObsWrapper(gym.core.ObservationWrapper):
         return {**obs, "image": full_grid}
 
 
-class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
+class DictObservationSpaceWrapper(gym.ObservationWrapper):
     """
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
@@ -365,7 +365,7 @@ class DictObservationSpaceWrapper(gym.core.ObservationWrapper):
         return obs
 
 
-class FlatObsWrapper(gym.core.ObservationWrapper):
+class FlatObsWrapper(gym.ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
@@ -459,7 +459,7 @@ class ViewSizeWrapper(gym.Wrapper):
         return {**obs, "image": image}
 
 
-class DirectionObsWrapper(gym.core.ObservationWrapper):
+class DirectionObsWrapper(gym.ObservationWrapper):
     """
    Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1) / (x2 - x1)
     type = {slope , angle}
@@ -493,7 +493,7 @@ class DirectionObsWrapper(gym.core.ObservationWrapper):
         return obs
 
 
-class SymbolicObsWrapper(gym.core.ObservationWrapper):
+class SymbolicObsWrapper(gym.ObservationWrapper):
     """
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are