Browse Source

fix pyright error with wrappers + raise error instead of string + use right registry

saleml 2 years ago
parent
commit
f6a103d05b
4 changed files with 22 additions and 20 deletions
  1. 1 1
      gym_minigrid/minigrid.py
  2. 2 1
      gym_minigrid/roomgrid.py
  3. 14 13
      gym_minigrid/wrappers.py
  4. 5 5
      run_tests.py

+ 1 - 1
gym_minigrid/minigrid.py

@@ -1,7 +1,7 @@
 import hashlib
-from abc import abstractmethod
 import math
 import string
+from abc import abstractmethod
 from enum import IntEnum
 
 import gym

+ 2 - 1
gym_minigrid/roomgrid.py

@@ -209,7 +209,8 @@ class RoomGrid(MiniGridEnv):
             obj = Box(color)
         else:
             raise ValueError(
-                f"{kind} object kind is not available in this environment.")
+                f"{kind} object kind is not available in this environment."
+            )
 
         return self.place_in_room(i, j, obj)
 

+ 14 - 13
gym_minigrid/wrappers.py

@@ -5,11 +5,12 @@ from functools import reduce
 import gym
 import numpy as np
 from gym import spaces
+from gym.core import ObservationWrapper, Wrapper
 
 from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
 
 
-class ReseedWrapper(gym.Wrapper):
+class ReseedWrapper(Wrapper):
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
@@ -31,7 +32,7 @@ class ReseedWrapper(gym.Wrapper):
         return obs, reward, done, info
 
 
-class ActionBonus(gym.Wrapper):
+class ActionBonus(Wrapper):
     """
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
@@ -66,7 +67,7 @@ class ActionBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
 
 
-class StateBonus(gym.Wrapper):
+class StateBonus(Wrapper):
     """
     Adds an exploration bonus based on which positions
     are visited on the grid.
@@ -102,7 +103,7 @@ class StateBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
 
 
-class ImgObsWrapper(gym.ObservationWrapper):
+class ImgObsWrapper(ObservationWrapper):
     """
     Use the image as the only observation output, no language/mission.
     """
@@ -115,7 +116,7 @@ class ImgObsWrapper(gym.ObservationWrapper):
         return obs["image"]
 
 
-class OneHotPartialObsWrapper(gym.ObservationWrapper):
+class OneHotPartialObsWrapper(ObservationWrapper):
     """
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
@@ -155,7 +156,7 @@ class OneHotPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": out}
 
 
-class RGBImgObsWrapper(gym.ObservationWrapper):
+class RGBImgObsWrapper(ObservationWrapper):
     """
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -187,7 +188,7 @@ class RGBImgObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img}
 
 
-class RGBImgPartialObsWrapper(gym.ObservationWrapper):
+class RGBImgPartialObsWrapper(ObservationWrapper):
     """
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -218,7 +219,7 @@ class RGBImgPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img_partial}
 
 
-class FullyObsWrapper(gym.ObservationWrapper):
+class FullyObsWrapper(ObservationWrapper):
     """
     Fully observable gridworld using a compact grid encoding
     """
@@ -247,7 +248,7 @@ class FullyObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": full_grid}
 
 
-class DictObservationSpaceWrapper(gym.ObservationWrapper):
+class DictObservationSpaceWrapper(ObservationWrapper):
     """
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
@@ -365,7 +366,7 @@ class DictObservationSpaceWrapper(gym.ObservationWrapper):
         return obs
 
 
-class FlatObsWrapper(gym.ObservationWrapper):
+class FlatObsWrapper(ObservationWrapper):
     """
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
@@ -424,7 +425,7 @@ class FlatObsWrapper(gym.ObservationWrapper):
         return obs
 
 
-class ViewSizeWrapper(gym.Wrapper):
+class ViewSizeWrapper(Wrapper):
     """
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
@@ -459,7 +460,7 @@ class ViewSizeWrapper(gym.Wrapper):
         return {**obs, "image": image}
 
 
-class DirectionObsWrapper(gym.ObservationWrapper):
+class DirectionObsWrapper(ObservationWrapper):
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     type = {slope , angle}
@@ -493,7 +494,7 @@ class DirectionObsWrapper(gym.ObservationWrapper):
         return obs
 
 
-class SymbolicObsWrapper(gym.ObservationWrapper):
+class SymbolicObsWrapper(ObservationWrapper):
     """
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are

+ 5 - 5
run_tests.py

@@ -5,10 +5,10 @@ import random
 import gym
 import numpy as np
 from gym import spaces
+from gym.envs.registration import registry
 
-from gym_minigrid.envs.empty import EmptyEnv5x5
+from gym_minigrid.envs.empty import EmptyEnv
 from gym_minigrid.minigrid import Grid
-from gym_minigrid.register import env_list
 from gym_minigrid.wrappers import (
     DictObservationSpaceWrapper,
     FlatObsWrapper,
@@ -21,7 +21,7 @@ from gym_minigrid.wrappers import (
     ViewSizeWrapper,
 )
 
-# Test importing wrappers
+env_list = [key for key in registry.keys() if key.startswith("MiniGrid")]
 
 
 print("%d environments registered" % len(env_list))
@@ -151,13 +151,13 @@ for env_idx, env_name in enumerate(env_list):
 print("testing extra observations")
 
 
-class EmptyEnvWithExtraObs(EmptyEnv5x5):
+class EmptyEnvWithExtraObs(EmptyEnv):
     """
     Custom environment with an extra observation
     """
 
     def __init__(self, **kwargs) -> None:
-        super().__init__(**kwargs)
+        super().__init__(size=5, **kwargs)
         self.observation_space["size"] = spaces.Box(
             low=0,
             high=1000,  # gym does not like np.iinfo(np.uint).max,