
Merge pull request #226 from rodrigodelazcano/latest_rendering_api

Update rendering and reset to gym v26
Mark Towers 2 years ago
parent
commit
263c99d3af
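
This drops the deprecated new_step_api flag and the gym.utils.renderer.Renderer helper and moves the package to the gym v26 conventions: render_mode is fixed when the environment is constructed, reset() returns (obs, info), and step() returns a five-tuple. A minimal usage sketch against this branch (the environment id is just one of the registered ids used in the tests):

    import gym
    import gym_minigrid  # noqa: F401  (importing the package registers the MiniGrid-* ids)

    # render_mode is now chosen once, at construction time
    env = gym.make("MiniGrid-Empty-8x8-v0", render_mode="rgb_array")

    # gym v26 reset: returns (observation, info) and takes the seed as a keyword
    obs, info = env.reset(seed=0)

    # gym v26 step: returns (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

    frame = env.render()  # in "rgb_array" mode this returns the current frame as an array
    env.close()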

+ 2 - 2
gym_minigrid/benchmark.py

@@ -8,7 +8,7 @@ from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
 def benchmark(env_id, num_resets, num_frames):
-    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    env = gym.make(env_id, render_mode="rgb_array")
     # Benchmark env.reset
     t0 = time.time()
     for i in range(num_resets):
@@ -26,7 +26,7 @@ def benchmark(env_id, num_resets, num_frames):
     frames_per_sec = num_frames / dt
 
     # Create an environment with an RGB agent observation
-    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    env = gym.make(env_id, render_mode="rgb_array")
     env = RGBImgPartialObsWrapper(env)
     env = ImgObsWrapper(env)
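
The benchmark keeps the same wrapper stack; only the gym.make call changes. A short sketch of driving that stack outside the benchmark script (environment id illustrative):

    import gym
    from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper

    env = gym.make("MiniGrid-Empty-8x8-v0", render_mode="rgb_array")
    env = RGBImgPartialObsWrapper(env)  # replace the encoded partial view with RGB pixels
    env = ImgObsWrapper(env)            # keep only the image entry of the observation dict

    obs, _ = env.reset(seed=0)  # obs is the agent's partial view as an RGB array
    env.close()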
 
 

+ 19 - 52
gym_minigrid/minigrid.py

@@ -2,14 +2,12 @@ import hashlib
 import math
 from abc import abstractmethod
 from enum import IntEnum
-from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import gym
 import numpy as np
 from gym import spaces
 from gym.utils import seeding
-from gym.utils.renderer import Renderer
 
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
@@ -827,7 +825,7 @@ class MiniGridEnv(gym.Env):
     """
 
     metadata = {
-        "render_modes": ["human", "rgb_array", "single_rgb_array"],
+        "render_modes": ["human", "rgb_array"],
         "render_fps": 10,
     }
 
@@ -837,7 +835,6 @@ class MiniGridEnv(gym.Env):
         left = 0
         right = 1
         forward = 2
-
         # Pick up an object
         pickup = 3
         # Drop an object
@@ -862,9 +859,6 @@ class MiniGridEnv(gym.Env):
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
     ):
-        # Rendering attributes
-        self.render_mode = render_mode
-
         # Initialize mission
         self.mission = mission_space.sample()
 
@@ -920,13 +914,13 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.carrying = None
 
-        frame_rendering = partial(
-            self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
-        )
-
-        self.renderer = Renderer(self.render_mode, frame_rendering)
+        # Rendering attributes
+        self.render_mode = render_mode
+        self.highlight = highlight
+        self.tile_size = tile_size
+        self.agent_pov = agent_pov
 
-    def reset(self, *, seed=None, return_info=False, options=None):
+    def reset(self, *, seed=None, options=None):
         super().reset(seed=seed)
 
         # Reinitialize episode-specific variables
@@ -953,17 +947,13 @@ class MiniGridEnv(gym.Env):
         # Step count since episode start
         self.step_count = 0
 
+        if self.render_mode == "human":
+            self.render()
+
         # Return first observation
         obs = self.gen_obs()
 
-        # Reset Renderer
-        self.renderer.reset()
-        self.renderer.render_step()
-
-        if not return_info:
-            return obs
-        else:
-            return obs, {}
+        return obs, {}
 
     def hash(self, size=16):
         """Compute a hash that uniquely identifies the current state of the environment.
@@ -1378,10 +1368,10 @@ class MiniGridEnv(gym.Env):
         if self.step_count >= self.max_steps:
             truncated = True
 
-        obs = self.gen_obs()
+        if self.render_mode == "human":
+            self.render()
 
-        # Reset Renderer
-        self.renderer.render_step()
+        obs = self.gen_obs()
 
         return obs, reward, terminated, truncated, {}
 
@@ -1529,42 +1519,19 @@ class MiniGridEnv(gym.Env):
         else:
             return self.get_full_render(highlight, tile_size)
 
-    def _render(self, mode: str = "human", **kwargs):
-        """
-        Render the whole-grid human view
-        """
-        img = self.get_frame(**kwargs)
+    def render(self):
 
-        mode = self.render_mode
-        if mode == "human":
+        img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
+
+        if self.render_mode == "human":
             if self.window is None:
                 self.window = Window("gym_minigrid")
                 self.window.show(block=False)
             self.window.set_caption(self.mission)
             self.window.show_img(img)
-        else:
+        elif self.render_mode == "rgb_array":
             return img
 
-    def render(
-        self,
-        mode: str = "human",
-        highlight: Optional[bool] = None,
-        tile_size: Optional[int] = None,
-        agent_pov: Optional[bool] = None,
-    ):
-        if self.render_mode is not None:
-            assert (
-                highlight is None and tile_size is None and agent_pov is None
-            ), "Unexpected argument for render. Specify render arguments at environment initialization."
-            return self.renderer.get_renders()
-        else:
-            highlight = highlight if highlight is not None else True
-            tile_size = tile_size if tile_size is not None else TILE_PIXELS
-            agent_pov = agent_pov if agent_pov is not None else False
-            return self._render(
-                mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
-            )
-
     def close(self):
         if self.window:
             self.window.close()
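
Summary of the minigrid.py change: the rendering options (highlight, tile_size, agent_pov) are stored on the environment at construction, render() takes no arguments and branches on self.render_mode, and reset()/step() call render() themselves when the mode is "human". A toy environment following the same pattern, purely illustrative and not part of this repo:

    import gym
    import numpy as np
    from gym import spaces


    class ToyEnv(gym.Env):
        # Illustrative only: mirrors the rendering pattern used in this diff.
        metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

        def __init__(self, render_mode=None):
            self.observation_space = spaces.Box(0, 1, shape=(4,), dtype=np.float32)
            self.action_space = spaces.Discrete(3)
            # Rendering configuration is fixed at construction, as in MiniGridEnv
            self.render_mode = render_mode

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            if self.render_mode == "human":
                self.render()
            return self.observation_space.sample(), {}

        def step(self, action):
            obs = self.observation_space.sample()
            if self.render_mode == "human":
                self.render()
            return obs, 0.0, False, False, {}

        def render(self):
            if self.render_mode == "rgb_array":
                return np.zeros((64, 64, 3), dtype=np.uint8)  # placeholder frame
            # a real "human" mode would draw the frame to a window here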

+ 13 - 13
gym_minigrid/wrappers.py

@@ -20,7 +20,7 @@ class ReseedWrapper(Wrapper):
     def __init__(self, env, seeds=[0], seed_idx=0):
         self.seeds = list(seeds)
         self.seed_idx = seed_idx
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
     def reset(self, **kwargs):
         seed = self.seeds[self.seed_idx]
@@ -39,7 +39,7 @@ class ActionBonus(gym.Wrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.counts = {}
 
     def step(self, action):
@@ -73,7 +73,7 @@ class StateBonus(Wrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.counts = {}
 
     def step(self, action):
@@ -108,7 +108,7 @@ class ImgObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.observation_space = env.observation_space.spaces["image"]
 
     def observation(self, obs):
@@ -122,7 +122,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.tile_size = tile_size
 
@@ -162,7 +162,7 @@ class RGBImgObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.tile_size = tile_size
 
@@ -190,7 +190,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         # Rendering attributes for observations
         self.tile_size = tile_size
@@ -219,7 +219,7 @@ class FullyObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         new_image_space = spaces.Box(
             low=0,
@@ -254,7 +254,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
         word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
                   if None, use the Minigrid language
         """
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         if word_dict is None:
             word_dict = self.get_minigrid_words()
@@ -367,7 +367,7 @@ class FlatObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, maxStrLen=96):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.maxStrLen = maxStrLen
         self.numCharCodes = 28
@@ -428,7 +428,7 @@ class ViewSizeWrapper(Wrapper):
     """
 
    def __init__(self, env, agent_view_size=7):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         assert agent_view_size % 2 == 1
         assert agent_view_size >= 3
@@ -463,7 +463,7 @@ class DirectionObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, type="slope"):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.goal_position: tuple = None
         self.type = type
 
@@ -498,7 +498,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         new_image_space = spaces.Box(
             low=0,
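
All of the wrapper changes are the same edit: the gym v26 Wrapper and ObservationWrapper base classes no longer take a new_step_api argument, so each constructor now calls super().__init__(env) and inherits the five-tuple step. A custom wrapper written against that API would look roughly like this (hypothetical class, not part of this repo):

    import gym


    class EpisodeStepCounter(gym.Wrapper):
        # Illustrative only: a wrapper written directly against the gym v26 API.
        def __init__(self, env):
            super().__init__(env)  # no new_step_api flag anymore
            self.episode_steps = 0

        def reset(self, **kwargs):
            self.episode_steps = 0
            return self.env.reset(**kwargs)  # (obs, info) passed through unchanged

        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self.episode_steps += 1
            return obs, reward, terminated, truncated, info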

+ 7 - 10
tests/test_envs.py

@@ -15,9 +15,6 @@ CHECK_ENV_IGNORE_WARNINGS = [
         "A Box observation space minimum value is -infinity. This is probably too low.",
         "A Box observation space maximum value is -infinity. This is probably too high.",
         "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
-        "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
-        "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
-        "Core environment is written in old step API which returns one bool instead of two. It is recommended to  norewrite the environment with new step API. ",
     ]
 ]
 
@@ -61,8 +58,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
     if env_spec.nondeterministic is True:
         return
 
-    env_1 = env_spec.make(disable_env_checker=True, new_step_api=True)
-    env_2 = env_spec.make(disable_env_checker=True, new_step_api=True)
+    env_1 = env_spec.make(disable_env_checker=True)
+    env_2 = env_spec.make(disable_env_checker=True)
 
     initial_obs_1 = env_1.reset(seed=SEED)
     initial_obs_2 = env_2.reset(seed=SEED)
@@ -105,11 +102,11 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
     "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_render_modes(spec):
-    env = spec.make(new_step_api=True)
+    env = spec.make()
 
     for mode in env.metadata.get("render_modes", []):
         if mode != "human":
-            new_env = spec.make(new_step_api=True, render_mode=mode)
+            new_env = spec.make(render_mode=mode)
 
             new_env.reset()
             new_env.step(new_env.action_space.sample())
@@ -118,7 +115,7 @@ def test_render_modes(spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
 def test_agent_sees_method(env_id):
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
 
     # Test the env.agent_sees() function
@@ -146,7 +143,7 @@ def test_agent_sees_method(env_id):
 )
 def old_run_test(env_spec):
     # Load the gym environment
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env.max_steps = min(env.max_steps, 200)
     env.reset()
     env.render()
@@ -199,7 +196,7 @@ def old_run_test(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
 def test_interactive_mode(env_id):
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     env.reset()
 
     for i in range(0, 100):
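
The test updates are mechanical: every spec.make / gym.make call loses the new_step_api flag, and render modes come from the environment metadata. Roughly what test_render_modes exercises, as a standalone sketch (env id taken from the tests):

    import gym
    import gym_minigrid  # noqa: F401  (registers the MiniGrid-* ids)

    env = gym.make("MiniGrid-DoorKey-6x6-v0")
    for mode in env.metadata.get("render_modes", []):
        if mode != "human":
            probe = gym.make("MiniGrid-DoorKey-6x6-v0", render_mode=mode)
            probe.reset()
            probe.step(probe.action_space.sample())
            probe.close()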

+ 1 - 1
tests/test_scripts.py

@@ -43,7 +43,7 @@ def test_manual_control():
                 self.key = self.close_action
 
     env_id = "MiniGrid-Empty-16x16-v0"
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     window = Window(env_id)
 
     reset(env, window)

+ 14 - 15
tests/test_wrappers.py

@@ -32,8 +32,8 @@ def test_reseed_wrapper(env_spec):
     """
     Test the ReseedWrapper with a list of SEEDS.
     """
-    unwrapped_env = env_spec.make(new_step_api=True)
-    env = env_spec.make(new_step_api=True)
+    unwrapped_env = env_spec.make()
+    env = env_spec.make()
     env = ReseedWrapper(env, seeds=SEEDS)
     env.action_space.seed(0)
 
@@ -76,8 +76,8 @@ def test_reseed_wrapper(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_state_bonus_wrapper(env_id):
-    env = gym.make(env_id, new_step_api=True)
-    wrapped_env = StateBonus(gym.make(env_id, new_step_api=True))
+    env = gym.make(env_id)
+    wrapped_env = StateBonus(gym.make(env_id))
 
     action_forward = MiniGridEnv.Actions.forward
     action_left = MiniGridEnv.Actions.left
@@ -106,8 +106,8 @@ def test_state_bonus_wrapper(env_id):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_action_bonus_wrapper(env_id):
-    env = gym.make(env_id, new_step_api=True)
-    wrapped_env = ActionBonus(gym.make(env_id, new_step_api=True))
+    env = gym.make(env_id)
+    wrapped_env = ActionBonus(gym.make(env_id))
 
     action = MiniGridEnv.Actions.forward
 
@@ -129,7 +129,7 @@ def test_action_bonus_wrapper(env_id):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_dict_observation_space_wrapper(env_spec):
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env = DictObservationSpaceWrapper(env)
     env.reset()
     mission = env.mission
@@ -157,7 +157,7 @@ def test_dict_observation_space_wrapper(env_spec):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_main_wrappers(wrapper, env_spec):
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env = wrapper(env)
     for _ in range(10):
         env.reset()
@@ -177,7 +177,7 @@ def test_main_wrappers(wrapper, env_spec):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_observation_space_wrappers(wrapper, env_spec):
-    env = wrapper(env_spec.make(disable_env_checker=True, new_step_api=True))
+    env = wrapper(env_spec.make(disable_env_checker=True))
     obs_space, wrapper_name = env.observation_space, wrapper.__name__
     assert isinstance(
         obs_space, gym.spaces.Dict
@@ -196,15 +196,14 @@ class EmptyEnvWithExtraObs(EmptyEnv):
 
     def __init__(self) -> None:
         super().__init__(size=5)
-        self.new_step_api = True
         self.observation_space["size"] = gym.spaces.Box(
             low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
         )
 
     def reset(self, **kwargs):
-        obs = super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
         obs["size"] = np.array([self.width, self.height])
-        return obs
+        return obs, info
 
     def step(self, action):
         obs, reward, terminated, truncated, info = super().step(action)
@@ -223,10 +222,10 @@ class EmptyEnvWithExtraObs(EmptyEnv):
 )
 def test_agent_sees_method(wrapper):
     env1 = wrapper(EmptyEnvWithExtraObs())
-    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0", new_step_api=True))
+    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
 
-    obs1 = env1.reset(seed=0)
-    obs2 = env2.reset(seed=0)
+    obs1, _ = env1.reset(seed=0)
+    obs2, _ = env2.reset(seed=0)
     assert "size" in obs1
     assert obs1["size"].shape == (2,)
     assert (obs1["size"] == [5, 5]).all()
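
EmptyEnvWithExtraObs shows the pattern subclasses now follow: reset() unpacks and returns (obs, info), and step() passes the five-tuple through unchanged. The same pattern on a hypothetical subclass (class name and extra observation key are illustrative):

    import gym
    import numpy as np
    from gym_minigrid.envs import EmptyEnv


    class EmptyEnvWithStepCount(EmptyEnv):
        # Illustrative only: mirrors the reset/step pattern from the test class above.
        def __init__(self):
            super().__init__(size=5)
            self.observation_space["steps"] = gym.spaces.Box(
                low=0, high=np.iinfo(np.int64).max, shape=(1,), dtype=np.int64
            )

        def reset(self, **kwargs):
            obs, info = super().reset(**kwargs)  # gym v26: reset returns a pair
            obs["steps"] = np.array([self.step_count], dtype=np.int64)
            return obs, info

        def step(self, action):
            obs, reward, terminated, truncated, info = super().step(action)  # v26 5-tuple
            obs["steps"] = np.array([self.step_count], dtype=np.int64)
            return obs, reward, terminated, truncated, info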
     assert (obs1["size"] == [5, 5]).all()