浏览代码

new render api

Rodrigo Perez-Vicente 2 年之前
父节点
当前提交
dec6ded9a0
共有 1 个文件被更改,包括 40 次插入11 次删除
  1. 40 11
      gym_minigrid/minigrid.py

+ 40 - 11
gym_minigrid/minigrid.py

@@ -2,12 +2,14 @@ import hashlib
 import math
 import math
 from abc import abstractmethod
 from abc import abstractmethod
 from enum import IntEnum
 from enum import IntEnum
+from functools import partial
 from typing import Any, Callable, Optional, Union
 from typing import Any, Callable, Optional, Union
 
 
 import gym
 import gym
 import numpy as np
 import numpy as np
 from gym import spaces
 from gym import spaces
 from gym.utils import seeding
 from gym.utils import seeding
+from gym.utils.renderer import Renderer
 
 
 # Size in pixels of a tile in the full-scale human view
 # Size in pixels of a tile in the full-scale human view
 from gym_minigrid.rendering import (
 from gym_minigrid.rendering import (
@@ -865,6 +867,7 @@ class MiniGridEnv(gym.Env):
         self.render_mode = render_mode
         self.render_mode = render_mode
         self.highlight = highlight
         self.highlight = highlight
         self.tile_size = tile_size
         self.tile_size = tile_size
+        # Agent's Point of View
         self.agent_pov = agent_pov
         self.agent_pov = agent_pov
 
 
         # Initialize mission
         # Initialize mission
@@ -922,6 +925,10 @@ class MiniGridEnv(gym.Env):
         self.grid = Grid(width, height)
         self.grid = Grid(width, height)
         self.carrying = None
         self.carrying = None
 
 
+        render_frame = partial(self._render, highlight=highlight, tile_size=tile_size)
+
+        self.renderer = Renderer(self.render_mode, render_frame)
+
         # Initialize the state
         # Initialize the state
         self.reset()
         self.reset()
 
 
@@ -955,6 +962,10 @@ class MiniGridEnv(gym.Env):
         # Return first observation
         # Return first observation
         obs = self.gen_obs()
         obs = self.gen_obs()
 
 
+        # Reset Renderer
+        self.renderer.reset()
+        self.renderer.render_step()
+
         if not return_info:
         if not return_info:
             return obs
             return obs
         else:
         else:
@@ -1375,6 +1386,9 @@ class MiniGridEnv(gym.Env):
 
 
         obs = self.gen_obs()
         obs = self.gen_obs()
 
 
+        # Reset Renderer
+        self.renderer.render_step()
+
         return obs, reward, terminated, truncated, {}
         return obs, reward, terminated, truncated, {}
 
 
     def gen_obs_grid(self, agent_view_size=None):
     def gen_obs_grid(self, agent_view_size=None):
@@ -1432,7 +1446,7 @@ class MiniGridEnv(gym.Env):
 
 
         return obs
         return obs
 
 
-    def get_pov_render(self):
+    def get_pov_render(self, tile_size):
         """
         """
         Render an agent's POV observation for visualization
         Render an agent's POV observation for visualization
         """
         """
@@ -1440,7 +1454,7 @@ class MiniGridEnv(gym.Env):
 
 
         # Render the whole grid
         # Render the whole grid
         img = grid.render(
         img = grid.render(
-            self.tile_size,
+            tile_size,
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
             agent_dir=3,
             agent_dir=3,
             highlight_mask=vis_mask,
             highlight_mask=vis_mask,
@@ -1448,13 +1462,10 @@ class MiniGridEnv(gym.Env):
 
 
         return img
         return img
 
 
-    def get_full_render(self):
+    def get_full_render(self, highlight, tile_size):
         """
         """
         Render a non-paratial observation for visualization
         Render a non-paratial observation for visualization
         """
         """
-        tile_size = self.tile_size
-        highlight = self.highlight
-
         # Compute which cells are visible to the agent
         # Compute which cells are visible to the agent
         _, vis_mask = self.gen_obs_grid()
         _, vis_mask = self.gen_obs_grid()
 
 
@@ -1499,18 +1510,20 @@ class MiniGridEnv(gym.Env):
 
 
         return img
         return img
 
 
-    def get_render(self):
+    def get_render(self, highlight, tile_size):
         "Returns an image corresponding to the whole environment or the agent's pov"
         "Returns an image corresponding to the whole environment or the agent's pov"
         if self.agent_pov:
         if self.agent_pov:
-            return self.get_pov_render()
+            return self.get_pov_render(tile_size)
         else:
         else:
-            return self.get_full_render()
+            return self.get_full_render(highlight, tile_size)
 
 
-    def render(self):
+    def _render(
+        self, mode: str = "human", highlight: bool = True, tile_size: int = TILE_PIXELS
+    ):
         """
         """
         Render the whole-grid human view
         Render the whole-grid human view
         """
         """
-        img = self.get_render()
+        img = self.get_render(highlight, tile_size)
 
 
         mode = self.render_mode
         mode = self.render_mode
         if mode == "human":
         if mode == "human":
@@ -1522,6 +1535,22 @@ class MiniGridEnv(gym.Env):
         else:
         else:
             return img
             return img
 
 
+    def render(
+        self,
+        mode: str = "human",
+        highlight: Optional[bool] = None,
+        tile_size: Optional[int] = None,
+    ):
+        if self.render_mode is not None:
+            assert (
+                highlight is None and tile_size is None
+            ), "Unexpected argument for render. Specify render arguments at environment initialization."
+            return self.renderer.get_renders()
+        else:
+            highlight = highlight if highlight is not None else True
+            tile_size = tile_size if tile_size is not None else TILE_PIXELS
+            return self._render(mode=mode, highlight=highlight, tile_size=tile_size)
+
     def close(self):
     def close(self):
         if self.window:
         if self.window:
             self.window.close()
             self.window.close()