|
@@ -1,7 +1,7 @@
|
|
import os
|
|
import os
|
|
import numpy
|
|
import numpy
|
|
import gym
|
|
import gym
|
|
-from gym.spaces.box import Box
|
|
|
|
|
|
+from gym import spaces
|
|
|
|
|
|
try:
|
|
try:
|
|
import gym_minigrid
|
|
import gym_minigrid
|
|
@@ -15,7 +15,9 @@ def make_env(env_id, seed, rank, log_dir):
|
|
|
|
|
|
env.seed(seed + rank)
|
|
env.seed(seed + rank)
|
|
|
|
|
|
- #env = FlatObsWrapper(env)
|
|
|
|
|
|
+ # Maxime: until RL code supports dict observations, squash observations into a flat vector
|
|
|
|
+ if isinstance(env.observation_space, spaces.Dict):
|
|
|
|
+ env = FlatObsWrapper(env)
|
|
|
|
|
|
# If the input has shape (W,H,3), wrap for PyTorch convolutions
|
|
# If the input has shape (W,H,3), wrap for PyTorch convolutions
|
|
obs_shape = env.observation_space.shape
|
|
obs_shape = env.observation_space.shape
|
|
@@ -30,7 +32,7 @@ class WrapPyTorch(gym.ObservationWrapper):
|
|
def __init__(self, env=None):
|
|
def __init__(self, env=None):
|
|
super(WrapPyTorch, self).__init__(env)
|
|
super(WrapPyTorch, self).__init__(env)
|
|
obs_shape = self.observation_space.shape
|
|
obs_shape = self.observation_space.shape
|
|
- self.observation_space = Box(
|
|
|
|
|
|
+ self.observation_space = spaces.Box(
|
|
self.observation_space.low[0,0,0],
|
|
self.observation_space.low[0,0,0],
|
|
self.observation_space.high[0,0,0],
|
|
self.observation_space.high[0,0,0],
|
|
[obs_shape[2], obs_shape[1], obs_shape[0]]
|
|
[obs_shape[2], obs_shape[1], obs_shape[0]]
|