浏览代码

fix pyright error with wrappers + raise error instead of string + use right registry

saleml 2 年之前
父节点
当前提交
f6a103d05b
共有 4 个文件被更改,包括 22 次插入20 次删除
  1. 1 1
      gym_minigrid/minigrid.py
  2. 2 1
      gym_minigrid/roomgrid.py
  3. 14 13
      gym_minigrid/wrappers.py
  4. 5 5
      run_tests.py

+ 1 - 1
gym_minigrid/minigrid.py

@@ -1,7 +1,7 @@
 import hashlib
 import hashlib
-from abc import abstractmethod
 import math
 import math
 import string
 import string
+from abc import abstractmethod
 from enum import IntEnum
 from enum import IntEnum
 
 
 import gym
 import gym

+ 2 - 1
gym_minigrid/roomgrid.py

@@ -209,7 +209,8 @@ class RoomGrid(MiniGridEnv):
             obj = Box(color)
             obj = Box(color)
         else:
         else:
             raise ValueError(
             raise ValueError(
-                f"{kind} object kind is not available in this environment.")
+                f"{kind} object kind is not available in this environment."
+            )
 
 
         return self.place_in_room(i, j, obj)
         return self.place_in_room(i, j, obj)
 
 

+ 14 - 13
gym_minigrid/wrappers.py

@@ -5,11 +5,12 @@ from functools import reduce
 import gym
 import gym
 import numpy as np
 import numpy as np
 from gym import spaces
 from gym import spaces
+from gym.core import ObservationWrapper, Wrapper
 
 
 from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
 from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, Goal
 
 
 
 
-class ReseedWrapper(gym.Wrapper):
+class ReseedWrapper(Wrapper):
     """
     """
     Wrapper to always regenerate an environment with the same set of seeds.
     Wrapper to always regenerate an environment with the same set of seeds.
     This can be used to force an environment to always keep the same
     This can be used to force an environment to always keep the same
@@ -31,7 +32,7 @@ class ReseedWrapper(gym.Wrapper):
         return obs, reward, done, info
         return obs, reward, done, info
 
 
 
 
-class ActionBonus(gym.Wrapper):
+class ActionBonus(Wrapper):
     """
     """
     Wrapper which adds an exploration bonus.
     Wrapper which adds an exploration bonus.
     This is a reward to encourage exploration of less
     This is a reward to encourage exploration of less
@@ -66,7 +67,7 @@ class ActionBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
 
 
-class StateBonus(gym.Wrapper):
+class StateBonus(Wrapper):
     """
     """
     Adds an exploration bonus based on which positions
     Adds an exploration bonus based on which positions
     are visited on the grid.
     are visited on the grid.
@@ -102,7 +103,7 @@ class StateBonus(gym.Wrapper):
         return self.env.reset(**kwargs)
         return self.env.reset(**kwargs)
 
 
 
 
-class ImgObsWrapper(gym.ObservationWrapper):
+class ImgObsWrapper(ObservationWrapper):
     """
     """
     Use the image as the only observation output, no language/mission.
     Use the image as the only observation output, no language/mission.
     """
     """
@@ -115,7 +116,7 @@ class ImgObsWrapper(gym.ObservationWrapper):
         return obs["image"]
         return obs["image"]
 
 
 
 
-class OneHotPartialObsWrapper(gym.ObservationWrapper):
+class OneHotPartialObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to get a one-hot encoding of a partially observable
     Wrapper to get a one-hot encoding of a partially observable
     agent view as observation.
     agent view as observation.
@@ -155,7 +156,7 @@ class OneHotPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": out}
         return {**obs, "image": out}
 
 
 
 
-class RGBImgObsWrapper(gym.ObservationWrapper):
+class RGBImgObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to use fully observable RGB image as observation,
     Wrapper to use fully observable RGB image as observation,
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -187,7 +188,7 @@ class RGBImgObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img}
         return {**obs, "image": rgb_img}
 
 
 
 
-class RGBImgPartialObsWrapper(gym.ObservationWrapper):
+class RGBImgPartialObsWrapper(ObservationWrapper):
     """
     """
     Wrapper to use partially observable RGB image as observation.
     Wrapper to use partially observable RGB image as observation.
     This can be used to have the agent to solve the gridworld in pixel space.
     This can be used to have the agent to solve the gridworld in pixel space.
@@ -218,7 +219,7 @@ class RGBImgPartialObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": rgb_img_partial}
         return {**obs, "image": rgb_img_partial}
 
 
 
 
-class FullyObsWrapper(gym.ObservationWrapper):
+class FullyObsWrapper(ObservationWrapper):
     """
     """
     Fully observable gridworld using a compact grid encoding
     Fully observable gridworld using a compact grid encoding
     """
     """
@@ -247,7 +248,7 @@ class FullyObsWrapper(gym.ObservationWrapper):
         return {**obs, "image": full_grid}
         return {**obs, "image": full_grid}
 
 
 
 
-class DictObservationSpaceWrapper(gym.ObservationWrapper):
+class DictObservationSpaceWrapper(ObservationWrapper):
     """
     """
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     Transforms the observation space (that has a textual component) to a fully numerical observation space,
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
     where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.
@@ -365,7 +366,7 @@ class DictObservationSpaceWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class FlatObsWrapper(gym.ObservationWrapper):
+class FlatObsWrapper(ObservationWrapper):
     """
     """
     Encode mission strings using a one-hot scheme,
     Encode mission strings using a one-hot scheme,
     and combine these with observed images into one flat array
     and combine these with observed images into one flat array
@@ -424,7 +425,7 @@ class FlatObsWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class ViewSizeWrapper(gym.Wrapper):
+class ViewSizeWrapper(Wrapper):
     """
     """
     Wrapper to customize the agent field of view size.
     Wrapper to customize the agent field of view size.
     This cannot be used with fully observable wrappers.
     This cannot be used with fully observable wrappers.
@@ -459,7 +460,7 @@ class ViewSizeWrapper(gym.Wrapper):
         return {**obs, "image": image}
         return {**obs, "image": image}
 
 
 
 
-class DirectionObsWrapper(gym.ObservationWrapper):
+class DirectionObsWrapper(ObservationWrapper):
     """
     """
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y2 )/( x2 - x1)
     type = {slope , angle}
     type = {slope , angle}
@@ -493,7 +494,7 @@ class DirectionObsWrapper(gym.ObservationWrapper):
         return obs
         return obs
 
 
 
 
-class SymbolicObsWrapper(gym.ObservationWrapper):
+class SymbolicObsWrapper(ObservationWrapper):
     """
     """
     Fully observable grid with a symbolic state representation.
     Fully observable grid with a symbolic state representation.
     The symbol is a triple of (X, Y, IDX), where X and Y are
     The symbol is a triple of (X, Y, IDX), where X and Y are

+ 5 - 5
run_tests.py

@@ -5,10 +5,10 @@ import random
 import gym
 import gym
 import numpy as np
 import numpy as np
 from gym import spaces
 from gym import spaces
+from gym.envs.registration import registry
 
 
-from gym_minigrid.envs.empty import EmptyEnv5x5
+from gym_minigrid.envs.empty import EmptyEnv
 from gym_minigrid.minigrid import Grid
 from gym_minigrid.minigrid import Grid
-from gym_minigrid.register import env_list
 from gym_minigrid.wrappers import (
 from gym_minigrid.wrappers import (
     DictObservationSpaceWrapper,
     DictObservationSpaceWrapper,
     FlatObsWrapper,
     FlatObsWrapper,
@@ -21,7 +21,7 @@ from gym_minigrid.wrappers import (
     ViewSizeWrapper,
     ViewSizeWrapper,
 )
 )
 
 
-# Test importing wrappers
+env_list = [key for key in registry.keys() if key.startswith("MiniGrid")]
 
 
 
 
 print("%d environments registered" % len(env_list))
 print("%d environments registered" % len(env_list))
@@ -151,13 +151,13 @@ for env_idx, env_name in enumerate(env_list):
 print("testing extra observations")
 print("testing extra observations")
 
 
 
 
-class EmptyEnvWithExtraObs(EmptyEnv5x5):
+class EmptyEnvWithExtraObs(EmptyEnv):
     """
     """
     Custom environment with an extra observation
     Custom environment with an extra observation
     """
     """
 
 
     def __init__(self, **kwargs) -> None:
     def __init__(self, **kwargs) -> None:
-        super().__init__(**kwargs)
+        super().__init__(size=5, **kwargs)
         self.observation_space["size"] = spaces.Box(
         self.observation_space["size"] = spaces.Box(
             low=0,
             low=0,
             high=1000,  # gym does not like np.iinfo(np.uint).max,
             high=1000,  # gym does not like np.iinfo(np.uint).max,