Parcourir la source

new render api

Rodrigo Perez-Vicente il y a 2 ans
Parent
commit
dec6ded9a0
1 fichiers modifiés avec 40 ajouts et 11 suppressions
  1. 40 11
      gym_minigrid/minigrid.py

+ 40 - 11
gym_minigrid/minigrid.py

@@ -2,12 +2,14 @@ import hashlib
 import math
 from abc import abstractmethod
 from enum import IntEnum
+from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import gym
 import numpy as np
 from gym import spaces
 from gym.utils import seeding
+from gym.utils.renderer import Renderer
 
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
@@ -865,6 +867,7 @@ class MiniGridEnv(gym.Env):
         self.render_mode = render_mode
         self.highlight = highlight
         self.tile_size = tile_size
+        # Agent's Point of View
         self.agent_pov = agent_pov
 
         # Initialize mission
@@ -922,6 +925,10 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.carrying = None
 
+        render_frame = partial(self._render, highlight=highlight, tile_size=tile_size)
+
+        self.renderer = Renderer(self.render_mode, render_frame)
+
         # Initialize the state
         self.reset()
 
@@ -955,6 +962,10 @@ class MiniGridEnv(gym.Env):
         # Return first observation
         obs = self.gen_obs()
 
+        # Reset Renderer
+        self.renderer.reset()
+        self.renderer.render_step()
+
         if not return_info:
             return obs
         else:
@@ -1375,6 +1386,9 @@ class MiniGridEnv(gym.Env):
 
         obs = self.gen_obs()
 
+        # Reset Renderer
+        self.renderer.render_step()
+
         return obs, reward, terminated, truncated, {}
 
     def gen_obs_grid(self, agent_view_size=None):
@@ -1432,7 +1446,7 @@ class MiniGridEnv(gym.Env):
 
         return obs
 
-    def get_pov_render(self):
+    def get_pov_render(self, tile_size):
         """
         Render an agent's POV observation for visualization
         """
@@ -1440,7 +1454,7 @@ class MiniGridEnv(gym.Env):
 
         # Render the whole grid
         img = grid.render(
-            self.tile_size,
+            tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_dir=3,
             highlight_mask=vis_mask,
@@ -1448,13 +1462,10 @@ class MiniGridEnv(gym.Env):
 
         return img
 
-    def get_full_render(self):
+    def get_full_render(self, highlight, tile_size):
         """
         Render a non-paratial observation for visualization
         """
-        tile_size = self.tile_size
-        highlight = self.highlight
-
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
 
@@ -1499,18 +1510,20 @@ class MiniGridEnv(gym.Env):
 
         return img
 
-    def get_render(self):
+    def get_render(self, highlight, tile_size):
         "Returns an image corresponding to the whole environment or the agent's pov"
         if self.agent_pov:
-            return self.get_pov_render()
+            return self.get_pov_render(tile_size)
         else:
-            return self.get_full_render()
+            return self.get_full_render(highlight, tile_size)
 
-    def render(self):
+    def _render(
+        self, mode: str = "human", highlight: bool = True, tile_size: int = TILE_PIXELS
+    ):
         """
         Render the whole-grid human view
         """
-        img = self.get_render()
+        img = self.get_render(highlight, tile_size)
 
         mode = self.render_mode
         if mode == "human":
@@ -1522,6 +1535,22 @@ class MiniGridEnv(gym.Env):
         else:
             return img
 
+    def render(
+        self,
+        mode: str = "human",
+        highlight: Optional[bool] = None,
+        tile_size: Optional[int] = None,
+    ):
+        if self.render_mode is not None:
+            assert (
+                highlight is None and tile_size is None
+            ), "Unexpected argument for render. Specify render arguments at environment initialization."
+            return self.renderer.get_renders()
+        else:
+            highlight = highlight if highlight is not None else True
+            tile_size = tile_size if tile_size is not None else TILE_PIXELS
+            return self._render(mode=mode, highlight=highlight, tile_size=tile_size)
+
     def close(self):
         if self.window:
             self.window.close()