
Added wrapper for one-hot string encoding. Fixed bugs in goto env.

Maxime Chevalier-Boisvert, 7 years ago
commit 360927639b
4 changed files with 70 additions and 12 deletions
  1. gym_minigrid/envs/gotoobject.py (+11 -2)
  2. gym_minigrid/wrappers.py (+55 -0)
  3. pytorch-rl/envs.py (+4 -7)
  4. pytorch-rl/model.py (+0 -3)

gym_minigrid/envs/gotoobject.py (+11 -2)

@@ -37,8 +37,8 @@ class GoToObjectEnv(MiniGridEnv):
         objs = []
         objPos = []
 
-        # For each object to be generated
-        for i in range(0, self.numObjs):
+        # Until we have generated all the objects
+        while len(objs) < self.numObjs:
             objType = self._randElem(types)
             objColor = self._randElem(colors)
 
@@ -129,7 +129,16 @@ class GoToObjectEnv(MiniGridEnv):
 
         return obs, reward, done, info
 
+class GotoEnv8x8N2(GoToObjectEnv):
+    def __init__(self):
+        super().__init__(size=8, numObjs=2)
+
 register(
     id='MiniGrid-GoToObject-6x6-N2-v0',
     entry_point='gym_minigrid.envs:GoToObjectEnv'
 )
+
+register(
+    id='MiniGrid-GoToObject-8x8-N2-v0',
+    entry_point='gym_minigrid.envs:GotoEnv8x8N2'
+)
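
With the new registration in place, the 8x8 variant can be created by id like any other Gym environment. A minimal sketch, assuming gym-minigrid is installed and that importing gym_minigrid.envs runs the register() calls above:

    import gym
    from gym_minigrid import envs  # importing the envs package registers the MiniGrid-* ids

    # Instantiate the newly registered 8x8, two-object variant
    env = gym.make('MiniGrid-GoToObject-8x8-N2-v0')

    obs = env.reset()
    print(obs['mission'])  # the environment's natural-language goal string
    obs, reward, done, info = env.step(env.action_space.sample())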

gym_minigrid/wrappers.py (+55 -0)

@@ -1,5 +1,11 @@
 import math
+import operator
+from functools import reduce
+
+import numpy as np
+
 import gym
+from gym import error, spaces, utils
 
 class ActionBonus(gym.core.Wrapper):
     """
@@ -67,3 +73,52 @@ class StateBonus(gym.core.Wrapper):
         reward += bonus
 
         return obs, reward, done, info
+
+class FlatObsWrapper(gym.core.ObservationWrapper):
+    """
+    Encode mission strings using a one-hot scheme,
+    and combine these with observed images into one flat array
+    """
+
+    def __init__(self, env, maxStrLen=48):
+        super().__init__(env)
+
+        self.maxStrLen = maxStrLen
+        self.numCharCodes = 27
+
+        obsSize = reduce(operator.mul, self.observation_space.shape, 1)
+
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obsSize + self.numCharCodes * self.maxStrLen,)
+        )
+
+        self.cachedStr = None
+        self.cachedArray = None
+
+    def _observation(self, obs):
+        image = obs['image']
+        mission = obs['mission']
+
+        # Cache the last-encoded mission string
+        if mission != self.cachedStr:
+            assert len(mission) <= self.maxStrLen, "mission string too long"
+            mission = mission.lower()
+
+            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes))
+
+            # One-hot encode each character: 'a'-'z' map to 0-25, space maps to 26
+            for idx, ch in enumerate(mission):
+                if ch >= 'a' and ch <= 'z':
+                    chNo = ord(ch) - ord('a')
+                elif ch == ' ':
+                    chNo = ord('z') - ord('a') + 1
+                else:
+                    assert False, 'unsupported character "%s" in mission string' % ch
+                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
+                strArray[idx, chNo] = 1
+
+            self.cachedStr = mission
+            self.cachedArray = strArray
+
+        obs = np.hstack((image.flatten(), self.cachedArray.flatten()))
+
+        return obs
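
A minimal usage sketch for the new wrapper (illustrative, not part of the commit; assumes a registered MiniGrid environment id):

    import gym
    from gym_minigrid import envs  # assumed to register the MiniGrid-* ids
    from gym_minigrid.wrappers import FlatObsWrapper

    env = FlatObsWrapper(gym.make('MiniGrid-GoToObject-8x8-N2-v0'))

    obs = env.reset()
    # The dict observation is flattened into a single 1-D array:
    # the image pixels followed by the one-hot encoded mission string.
    print(obs.shape)

A flat vector like this is the kind of input a fully connected policy (such as MLPPolicy in pytorch-rl/model.py) expects, which is presumably why the wrapper call is sketched, commented out, in pytorch-rl/envs.py below.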

pytorch-rl/envs.py (+4 -7)

@@ -4,7 +4,6 @@ import gym
 
 from gym.spaces.box import Box
 
-from baselines import bench
 from baselines.common.atari_wrappers import make_atari, wrap_deepmind
 
 try:
@@ -18,30 +17,28 @@ try:
 except:
     pass
 
-
 def make_env(env_id, seed, rank, log_dir):
     def _thunk():
         env = gym.make(env_id)
+
         is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
         if is_atari:
             env = make_atari(env_id)
         env.seed(seed + rank)
-        if log_dir is not None:
-            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
         if is_atari:
             env = wrap_deepmind(env)
+
+        #env = FlatObsWrapper(env)
+
         # If the input has shape (W,H,3), wrap for PyTorch convolutions
         obs_shape = env.observation_space.shape
         if len(obs_shape) == 3 and obs_shape[2] == 3:
             env = WrapPyTorch(env)
 
-        #env = StateBonus(env)
-
         return env
 
     return _thunk
 
-
 class WrapPyTorch(gym.ObservationWrapper):
     def __init__(self, env=None):
         super(WrapPyTorch, self).__init__(env)
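
For context, a sketch of how the thunks returned by make_env are typically consumed, assuming SubprocVecEnv from baselines is available (this usage is not part of the commit):

    from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

    num_processes = 4
    env_fns = [make_env('MiniGrid-GoToObject-8x8-N2-v0', seed=1, rank=i, log_dir=None)
               for i in range(num_processes)]
    envs = SubprocVecEnv(env_fns)  # each thunk is called inside its own worker process

    obs = envs.reset()  # batched observations, one row per worker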

pytorch-rl/model.py (+0 -3)

@@ -13,7 +13,6 @@ def weights_init(m):
         if m.bias is not None:
             m.bias.data.fill_(0)
 
-
 class FFPolicy(nn.Module):
     def __init__(self):
         super(FFPolicy, self).__init__()
@@ -111,7 +110,6 @@ class CNNPolicy(FFPolicy):
                 x = torch.cat(outputs, 0)
         return self.critic_linear(x), x, states
 
-
 def weights_init_mlp(m):
     classname = m.__class__.__name__
     if classname.find('Linear') != -1:
@@ -120,7 +118,6 @@ def weights_init_mlp(m):
         if m.bias is not None:
             m.bias.data.fill_(0)
 
-
 class MLPPolicy(FFPolicy):
     def __init__(self, num_inputs, action_space):
         super(MLPPolicy, self).__init__()