Sfoglia il codice sorgente

get_frame and arguments

Rodrigo Perez-Vicente 2 anni fa
parent
commit
d8fe5eb4fd

+ 1 - 1
gym_minigrid/manual_control.py

@@ -17,7 +17,7 @@ def reset(env, window):
         print("Mission: %s" % env.mission)
         window.set_caption(env.mission)
 
-    img = env.get_render()
+    img = env.get_frame()
 
     redraw(window, img)
 

+ 19 - 15
gym_minigrid/minigrid.py

@@ -861,14 +861,9 @@ class MiniGridEnv(gym.Env):
         highlight: bool = True,
         tile_size: int = TILE_PIXELS,
         agent_pov: bool = False,
-        **kwargs,
     ):
         # Rendering attributes
         self.render_mode = render_mode
-        self.highlight = highlight
-        self.tile_size = tile_size
-        # Agent's Point of View
-        self.agent_pov = agent_pov
 
         # Initialize mission
         self.mission = mission_space.sample()
@@ -925,9 +920,11 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.carrying = None
 
-        render_frame = partial(self._render, highlight=highlight, tile_size=tile_size)
+        frame_rendering = partial(
+            self._render, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+        )
 
-        self.renderer = Renderer(self.render_mode, render_frame)
+        self.renderer = Renderer(self.render_mode, frame_rendering)
 
         # Initialize the state
         self.reset()
@@ -1510,20 +1507,23 @@ class MiniGridEnv(gym.Env):
 
         return img
 
-    def get_render(self, highlight, tile_size):
+    def get_frame(
+        self,
+        highlight: bool = True,
+        tile_size: int = TILE_PIXELS,
+        agent_pov: bool = False,
+    ):
         "Returns an image corresponding to the whole environment or the agent's pov"
-        if self.agent_pov:
+        if agent_pov:
             return self.get_pov_render(tile_size)
         else:
             return self.get_full_render(highlight, tile_size)
 
-    def _render(
-        self, mode: str = "human", highlight: bool = True, tile_size: int = TILE_PIXELS
-    ):
+    def _render(self, mode: str = "human", **kwargs):
         """
         Render the whole-grid human view
         """
-        img = self.get_render(highlight, tile_size)
+        img = self.get_frame(**kwargs)
 
         mode = self.render_mode
         if mode == "human":
@@ -1540,16 +1540,20 @@ class MiniGridEnv(gym.Env):
         mode: str = "human",
         highlight: Optional[bool] = None,
         tile_size: Optional[int] = None,
+        agent_pov: Optional[bool] = None,
     ):
         if self.render_mode is not None:
             assert (
-                highlight is None and tile_size is None
+                highlight is None and tile_size is None and agent_pov is None
             ), "Unexpected argument for render. Specify render arguments at environment initialization."
             return self.renderer.get_renders()
         else:
             highlight = highlight if highlight is not None else True
             tile_size = tile_size if tile_size is not None else TILE_PIXELS
-            return self._render(mode=mode, highlight=highlight, tile_size=tile_size)
+            agent_pov = agent_pov if agent_pov is not None else False
+            return self._render(
+                mode=mode, highlight=highlight, tile_size=tile_size, agent_pov=agent_pov
+            )
 
     def close(self):
         if self.window:

+ 0 - 1
gym_minigrid/window.py

@@ -88,6 +88,5 @@ class Window:
         """
         Close the window
         """
-
         plt.close()
         self.closed = True

+ 3 - 8
gym_minigrid/wrappers.py

@@ -164,10 +164,6 @@ class RGBImgObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
 
-        # Rendering attributes
-        self.highlight = True
-        self.tile_size = tile_size
-
         self.tile_size = tile_size
 
         new_image_space = spaces.Box(
@@ -182,7 +178,7 @@ class RGBImgObsWrapper(ObservationWrapper):
         )
 
     def observation(self, obs):
-        rgb_img = self.get_full_render()
+        rgb_img = self.get_frame(highlight=True, tile_size=self.tile_size)
 
         return {**obs, "image": rgb_img}
 
@@ -196,8 +192,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
     def __init__(self, env, tile_size=8):
         super().__init__(env, new_step_api=env.new_step_api)
 
-        # Rendering attributes
-        self.unwrapped.agent_pov = True
+        # Rendering attributes for observations
         self.tile_size = tile_size
 
         obs_shape = env.observation_space.spaces["image"].shape
@@ -213,7 +208,7 @@ class RGBImgPartialObsWrapper(ObservationWrapper):
         )
 
     def observation(self, obs):
-        rgb_img_partial = self.get_pov_render()
+        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
 
         return {**obs, "image": rgb_img_partial}