|
@@ -6,17 +6,32 @@ import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
import gym
|
|
|
import gym_minigrid
|
|
|
+from gym_minigrid.wrappers import *
|
|
|
+
|
|
|
+fig = None
|
|
|
+imshow_obj = None
|
|
|
+
|
|
|
+def redraw(img):
|
|
|
+ global imshow_obj
|
|
|
+
|
|
|
+ if not args.agent_view:
|
|
|
+ img = env.render('rgb_array', tile_size=args.tile_size)
|
|
|
+
|
|
|
+ # Show the first image of the environment
|
|
|
+ if imshow_obj is None:
|
|
|
+ imshow_obj = ax.imshow(img, interpolation='bilinear')
|
|
|
+
|
|
|
+ imshow_obj.set_data(img)
|
|
|
+ fig.canvas.draw()
|
|
|
|
|
|
def reset():
|
|
|
- env.reset()
|
|
|
+ obs = env.reset()
|
|
|
|
|
|
if hasattr(env, 'mission'):
|
|
|
print('Mission: %s' % env.mission)
|
|
|
plt.xlabel(env.mission)
|
|
|
|
|
|
- img = env.render('rgb_array', tile_size=args.tile_size)
|
|
|
- imshow_obj.set_data(img)
|
|
|
- fig.canvas.draw()
|
|
|
+ redraw(obs)
|
|
|
|
|
|
def step(action):
|
|
|
obs, reward, done, info = env.step(action)
|
|
@@ -26,9 +41,7 @@ def step(action):
|
|
|
print('done!')
|
|
|
reset()
|
|
|
|
|
|
- img = env.render('rgb_array', tile_size=args.tile_size)
|
|
|
- imshow_obj.set_data(img)
|
|
|
- fig.canvas.draw()
|
|
|
+ redraw(obs)
|
|
|
|
|
|
def key_handler(event):
|
|
|
print('pressed', event.key)
|
|
@@ -79,11 +92,21 @@ parser.add_argument(
|
|
|
help="size at which to render tiles",
|
|
|
default=32
|
|
|
)
|
|
|
+parser.add_argument(
|
|
|
+ '--agent_view',
|
|
|
+ default=False,
|
|
|
+ help="Draw the agent's partially observable view",
|
|
|
+ action='store_true'
|
|
|
+)
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
env = gym.make(args.env_name)
|
|
|
|
|
|
+if args.agent_view:
|
|
|
+ env = RGBImgPartialObsWrapper(env)
|
|
|
+ env = ImgObsWrapper(env)
|
|
|
+
|
|
|
fig, ax = plt.subplots()
|
|
|
|
|
|
# Keyboard handler
|
|
@@ -96,10 +119,6 @@ fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
|
|
|
ax.set_xticks([], [])
|
|
|
ax.set_yticks([], [])
|
|
|
|
|
|
-# Show the first image of the environment
|
|
|
-img = env.render('rgb_array', tile_size=args.tile_size)
|
|
|
-imshow_obj = ax.imshow(img, interpolation='bilinear')
|
|
|
-
|
|
|
reset()
|
|
|
|
|
|
# Show the plot, enter the matplotlib event loop
|