Преглед изворни кода

Merge pull request #226 from rodrigodelazcano/latest_rendering_api

Update rendering and reset to gym v26
Mark Towers пре 2 година
родитељ
комит
263c99d3af
6 измењених фајлова са 56 додато и 93 уклоњено
  1. 2 2
      gym_minigrid/benchmark.py
  2. 19 52
      gym_minigrid/minigrid.py
  3. 13 13
      gym_minigrid/wrappers.py
  4. 7 10
      tests/test_envs.py
  5. 1 1
      tests/test_scripts.py
  6. 14 15
      tests/test_wrappers.py

+ 2 - 2
gym_minigrid/benchmark.py

@@ -8,7 +8,7 @@ from gym_minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
 
 
 def benchmark(env_id, num_resets, num_frames):
-    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    env = gym.make(env_id, render_mode="rgb_array")
     # Benchmark env.reset
     t0 = time.time()
     for i in range(num_resets):
@@ -26,7 +26,7 @@ def benchmark(env_id, num_resets, num_frames):
     frames_per_sec = num_frames / dt
 
     # Create an environment with an RGB agent observation
-    env = gym.make(env_id, new_step_api=True, render_mode="rgb_array")
+    env = gym.make(env_id, render_mode="rgb_array")
     env = RGBImgPartialObsWrapper(env)
     env = ImgObsWrapper(env)
 

+ 19 - 52
gym_minigrid/minigrid.py

@@ -2,14 +2,12 @@ import hashlib
 import math
 from abc import abstractmethod
 from enum import IntEnum
-from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import gym
 import numpy as np
 from gym import spaces
 from gym.utils import seeding
-from gym.utils.renderer import Renderer
 
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
@@ -827,7 +825,7 @@ class MiniGridEnv(gym.Env):
     """
 
     metadata = {
-        "render_modes": ["human", "rgb_array", "single_rgb_array"],
+        "render_modes": ["human", "rgb_array"],
         "render_fps": 10,
     }
 
@@ -837,7 +835,6 @@ class MiniGridEnv(gym.Env):
         left = 0
         right = 1
         forward = 2
-
         # Pick up an object
         pickup = 3
         # Drop an object
@@ -862,9 +859,6 @@ class MiniGridEnv(gym.Env):
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
     ):
-        # Rendering attributes
-        self.render_mode = render_mode
-
         # Initialize mission
         self.mission = mission_space.sample()
 
@@ -920,13 +914,13 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.carrying = None
 
-        frame_rendering = partial(
-            self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
-        )
-
-        self.renderer = Renderer(self.render_mode, frame_rendering)
+        # Rendering attributes
+        self.render_mode = render_mode
+        self.highlight = highlight
+        self.tile_size = tile_size
+        self.agent_pov = agent_pov
 
-    def reset(self, *, seed=None, return_info=False, options=None):
+    def reset(self, *, seed=None, options=None):
         super().reset(seed=seed)
 
         # Reinitialize episode-specific variables
@@ -953,17 +947,13 @@ class MiniGridEnv(gym.Env):
         # Step count since episode start
         self.step_count = 0
 
+        if self.render_mode == "human":
+            self.render()
+
         # Return first observation
         obs = self.gen_obs()
 
-        # Reset Renderer
-        self.renderer.reset()
-        self.renderer.render_step()
-
-        if not return_info:
-            return obs
-        else:
-            return obs, {}
+        return obs, {}
 
     def hash(self, size=16):
         """Compute a hash that uniquely identifies the current state of the environment.
@@ -1378,10 +1368,10 @@ class MiniGridEnv(gym.Env):
         if self.step_count >= self.max_steps:
             truncated = True
 
-        obs = self.gen_obs()
+        if self.render_mode == "human":
+            self.render()
 
-        # Reset Renderer
-        self.renderer.render_step()
+        obs = self.gen_obs()
 
         return obs, reward, terminated, truncated, {}
 
@@ -1529,42 +1519,19 @@ class MiniGridEnv(gym.Env):
         else:
             return self.get_full_render(highlight, tile_size)
 
-    def _render(self, mode: str = "human", **kwargs):
-        """
-        Render the whole-grid human view
-        """
-        img = self.get_frame(**kwargs)
+    def render(self):
 
-        mode = self.render_mode
-        if mode == "human":
+        img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
+
+        if self.render_mode == "human":
             if self.window is None:
                 self.window = Window("gym_minigrid")
                 self.window.show(block=False)
             self.window.set_caption(self.mission)
             self.window.show_img(img)
-        else:
+        elif self.render_mode == "rgb_array":
             return img
 
-    def render(
-        self,
-        mode: str = "human",
-        highlight: Optional[bool] = None,
-        tile_size: Optional[int] = None,
-        agent_pov: Optional[bool] = None,
-    ):
-        if self.render_mode is not None:
-            assert (
-                highlight is None and tile_size is None and agent_pov is None
-            ), "Unexpected argument for render. Specify render arguments at environment initialization."
-            return self.renderer.get_renders()
-        else:
-            highlight = highlight if highlight is not None else True
-            tile_size = tile_size if tile_size is not None else TILE_PIXELS
-            agent_pov = agent_pov if agent_pov is not None else False
-            return self._render(
-                mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
-            )
-
     def close(self):
         if self.window:
             self.window.close()

+ 13 - 13
gym_minigrid/wrappers.py

@@ -20,7 +20,7 @@ class ReseedWrapper(Wrapper):
     def __init__(self, env, seeds=[0], seed_idx=0):
         self.seeds = list(seeds)
         self.seed_idx = seed_idx
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
     def reset(self, **kwargs):
         seed = self.seeds[self.seed_idx]
@@ -39,7 +39,7 @@ class ActionBonus(gym.Wrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.counts = {}
 
     def step(self, action):
@@ -73,7 +73,7 @@ class StateBonus(Wrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.counts = {}
 
     def step(self, action):
@@ -108,7 +108,7 @@ class ImgObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.observation_space = env.observation_space.spaces["image"]
 
     def observation(self, obs):
@@ -122,7 +122,7 @@ class OneHotPartialObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.tile_size = tile_size
 
@@ -162,7 +162,7 @@ class RGBImgObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.tile_size = tile_size
 
@@ -190,7 +190,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, tile_size=8):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         # Rendering attributes for observations
         self.tile_size = tile_size
@@ -219,7 +219,7 @@ class FullyObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         new_image_space = spaces.Box(
             low=0,
@@ -254,7 +254,7 @@ class DictObservationSpaceWrapper(ObservationWrapper):
         word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
                   if None, use the Minigrid language
         """
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         if word_dict is None:
             word_dict = self.get_minigrid_words()
@@ -367,7 +367,7 @@ class FlatObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, maxStrLen=96):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         self.maxStrLen = maxStrLen
         self.numCharCodes = 28
@@ -428,7 +428,7 @@ class ViewSizeWrapper(Wrapper):
     """
 
     def __init__(self, env, agent_view_size=7):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         assert agent_view_size % 2 == 1
         assert agent_view_size >= 3
@@ -463,7 +463,7 @@ class DirectionObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env, type="slope"):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
         self.goal_position: tuple = None
         self.type = type
 
@@ -498,7 +498,7 @@ class SymbolicObsWrapper(ObservationWrapper):
     """
 
     def __init__(self, env):
-        super().__init__(env, new_step_api=env.new_step_api)
+        super().__init__(env)
 
         new_image_space = spaces.Box(
             low=0,

+ 7 - 10
tests/test_envs.py

@@ -15,9 +15,6 @@ CHECK_ENV_IGNORE_WARNINGS = [
         "A Box observation space minimum value is -infinity. This is probably too low.",
         "A Box observation space maximum value is -infinity. This is probably too high.",
         "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
-        "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
-        "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
-        "Core environment is written in old step API which returns one bool instead of two. It is recommended to  norewrite the environment with new step API. ",
     ]
 ]
 
@@ -61,8 +58,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
     if env_spec.nondeterministic is True:
         return
 
-    env_1 = env_spec.make(disable_env_checker=True, new_step_api=True)
-    env_2 = env_spec.make(disable_env_checker=True, new_step_api=True)
+    env_1 = env_spec.make(disable_env_checker=True)
+    env_2 = env_spec.make(disable_env_checker=True)
 
     initial_obs_1 = env_1.reset(seed=SEED)
     initial_obs_2 = env_2.reset(seed=SEED)
@@ -105,11 +102,11 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
     "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_render_modes(spec):
-    env = spec.make(new_step_api=True)
+    env = spec.make()
 
     for mode in env.metadata.get("render_modes", []):
         if mode != "human":
-            new_env = spec.make(new_step_api=True, render_mode=mode)
+            new_env = spec.make(render_mode=mode)
 
             new_env.reset()
             new_env.step(new_env.action_space.sample())
@@ -118,7 +115,7 @@ def test_render_modes(spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
 def test_agent_sees_method(env_id):
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     goal_pos = (env.grid.width - 2, env.grid.height - 2)
 
     # Test the env.agent_sees() function
@@ -146,7 +143,7 @@ def test_agent_sees_method(env_id):
 )
 def old_run_test(env_spec):
     # Load the gym environment
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env.max_steps = min(env.max_steps, 200)
     env.reset()
     env.render()
@@ -199,7 +196,7 @@ def old_run_test(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
 def test_interactive_mode(env_id):
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     env.reset()
 
     for i in range(0, 100):

+ 1 - 1
tests/test_scripts.py

@@ -43,7 +43,7 @@ def test_manual_control():
                 self.key = self.close_action
 
     env_id = "MiniGrid-Empty-16x16-v0"
-    env = gym.make(env_id, new_step_api=True)
+    env = gym.make(env_id)
     window = Window(env_id)
 
     reset(env, window)

+ 14 - 15
tests/test_wrappers.py

@@ -32,8 +32,8 @@ def test_reseed_wrapper(env_spec):
     """
     Test the ReseedWrapper with a list of SEEDS.
     """
-    unwrapped_env = env_spec.make(new_step_api=True)
-    env = env_spec.make(new_step_api=True)
+    unwrapped_env = env_spec.make()
+    env = env_spec.make()
     env = ReseedWrapper(env, seeds=SEEDS)
     env.action_space.seed(0)
 
@@ -76,8 +76,8 @@ def test_reseed_wrapper(env_spec):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_state_bonus_wrapper(env_id):
-    env = gym.make(env_id, new_step_api=True)
-    wrapped_env = StateBonus(gym.make(env_id, new_step_api=True))
+    env = gym.make(env_id)
+    wrapped_env = StateBonus(gym.make(env_id))
 
     action_forward = MiniGridEnv.Actions.forward
     action_left = MiniGridEnv.Actions.left
@@ -106,8 +106,8 @@ def test_state_bonus_wrapper(env_id):
 
 @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
 def test_action_bonus_wrapper(env_id):
-    env = gym.make(env_id, new_step_api=True)
-    wrapped_env = ActionBonus(gym.make(env_id, new_step_api=True))
+    env = gym.make(env_id)
+    wrapped_env = ActionBonus(gym.make(env_id))
 
     action = MiniGridEnv.Actions.forward
 
@@ -129,7 +129,7 @@ def test_action_bonus_wrapper(env_id):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_dict_observation_space_wrapper(env_spec):
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env = DictObservationSpaceWrapper(env)
     env.reset()
     mission = env.mission
@@ -157,7 +157,7 @@ def test_dict_observation_space_wrapper(env_spec):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_main_wrappers(wrapper, env_spec):
-    env = env_spec.make(new_step_api=True)
+    env = env_spec.make()
     env = wrapper(env)
     for _ in range(10):
         env.reset()
@@ -177,7 +177,7 @@ def test_main_wrappers(wrapper, env_spec):
     "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
 )
 def test_observation_space_wrappers(wrapper, env_spec):
-    env = wrapper(env_spec.make(disable_env_checker=True, new_step_api=True))
+    env = wrapper(env_spec.make(disable_env_checker=True))
     obs_space, wrapper_name = env.observation_space, wrapper.__name__
     assert isinstance(
         obs_space, gym.spaces.Dict
@@ -196,15 +196,14 @@ class EmptyEnvWithExtraObs(EmptyEnv):
 
     def __init__(self) -> None:
         super().__init__(size=5)
-        self.new_step_api = True
         self.observation_space["size"] = gym.spaces.Box(
             low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
         )
 
     def reset(self, **kwargs):
-        obs = super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
         obs["size"] = np.array([self.width, self.height])
-        return obs
+        return obs, info
 
     def step(self, action):
         obs, reward, terminated, truncated, info = super().step(action)
@@ -223,10 +222,10 @@ class EmptyEnvWithExtraObs(EmptyEnv):
 )
 def test_agent_sees_method(wrapper):
     env1 = wrapper(EmptyEnvWithExtraObs())
-    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0", new_step_api=True))
+    env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
 
-    obs1 = env1.reset(seed=0)
-    obs2 = env2.reset(seed=0)
+    obs1, _ = env1.reset(seed=0)
+    obs2, _ = env2.reset(seed=0)
     assert "size" in obs1
     assert obs1["size"].shape == (2,)
     assert (obs1["size"] == [5, 5]).all()