浏览代码

Added command line option to see agent view. Fixed issues with wrappers.

Maxime Chevalier-Boisvert 5 年之前
父节点
当前提交
573e5e77a2
共有 3 个文件被更改,包括 35 次插入35 次删除
  1. 4 22
      gym_minigrid/minigrid.py
  2. 1 2
      gym_minigrid/wrappers.py
  3. 30 11
      manual_control.py

+ 4 - 22
gym_minigrid/minigrid.py

@@ -1213,29 +1213,11 @@ class MiniGridEnv(gym.Env):
         grid = Grid.decode(obs)
 
         # Render the whole grid
-        img = grid.render(r, tile_size)
-
-        assert False
-
-        """
-        # Draw the agent
-        ratio = tile_size / TILE_PIXELS
-        r.push()
-        r.scale(ratio, ratio)
-        r.translate(
-            TILE_PIXELS * (0.5 + self.agent_view_size // 2),
-            TILE_PIXELS * (self.agent_view_size - 0.5)
+        img = grid.render(
+            tile_size,
+            agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
+            agent_dir = 3
         )
-        r.rotate(3 * 90)
-        r.setLineColor(255, 0, 0)
-        r.setColor(255, 0, 0)
-        r.drawPolygon([
-            (-12, 10),
-            ( 12,  0),
-            (-12, -10)
-        ])
-        r.pop()
-        """
 
         return img
 

+ 1 - 2
gym_minigrid/wrappers.py

@@ -210,8 +210,7 @@ class RGBImgPartialObsWrapper(gym.core.ObservationWrapper):
 
         rgb_img_partial = env.get_obs_render(
             obs['image'],
-            tile_size=self.tile_size,
-            mode='rgb_array'
+            tile_size=self.tile_size
         )
 
         return {

+ 30 - 11
manual_control.py

@@ -6,17 +6,32 @@ import matplotlib.pyplot as plt
 import numpy as np
 import gym
 import gym_minigrid
+from gym_minigrid.wrappers import *
+
+fig = None
+imshow_obj = None
+
+def redraw(img):
+    global imshow_obj
+
+    if not args.agent_view:
+        img = env.render('rgb_array', tile_size=args.tile_size)
+
+    # Show the first image of the environment
+    if imshow_obj is None:
+        imshow_obj = ax.imshow(img, interpolation='bilinear')
+
+    imshow_obj.set_data(img)
+    fig.canvas.draw()
 
 def reset():
-    env.reset()
+    obs = env.reset()
 
     if hasattr(env, 'mission'):
         print('Mission: %s' % env.mission)
         plt.xlabel(env.mission)
 
-    img = env.render('rgb_array', tile_size=args.tile_size)
-    imshow_obj.set_data(img)
-    fig.canvas.draw()
+    redraw(obs)
 
 def step(action):
     obs, reward, done, info = env.step(action)
@@ -26,9 +41,7 @@ def step(action):
         print('done!')
         reset()
 
-    img = env.render('rgb_array', tile_size=args.tile_size)
-    imshow_obj.set_data(img)
-    fig.canvas.draw()
+    redraw(obs)
 
 def key_handler(event):
     print('pressed', event.key)
@@ -79,11 +92,21 @@ parser.add_argument(
     help="size at which to render tiles",
     default=32
 )
+parser.add_argument(
+    '--agent_view',
+    default=False,
+    help="Draw the agent's partially observable view",
+    action='store_true'
+)
 
 args = parser.parse_args()
 
 env = gym.make(args.env_name)
 
+if args.agent_view:
+    env = RGBImgPartialObsWrapper(env)
+    env = ImgObsWrapper(env)
+
 fig, ax = plt.subplots()
 
 # Keyboard handler
@@ -96,10 +119,6 @@ fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
 ax.set_xticks([], [])
 ax.set_yticks([], [])
 
-# Show the first image of the environment
-img = env.render('rgb_array', tile_size=args.tile_size)
-imshow_obj = ax.imshow(img, interpolation='bilinear')
-
 reset()
 
 # Show the plot, enter the matplotlib event loop