Ver código fonte

Curriculum learning experiment with Door & Key environment

Maxime Chevalier-Boisvert 7 anos atrás
pai
commit
8a733bfa98
3 arquivos alterados com 51 adições e 13 exclusões
  1. +7 -8
      basicrl/envs.py
  2. +25 -1
      basicrl/main.py
  3. +19 -4
      gym_minigrid/envs/simple_envs.py

+ 7 - 8
basicrl/envs.py

@@ -19,17 +19,16 @@ except:
     pass
 
 
-def make_env(env_id, seed, rank, log_dir):
+def make_env(env_id, seed, rank, log_dir, size=None):
     def _thunk():
+
         env = gym.make(env_id)
-        is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
-        if is_atari:
-            env = make_atari(env_id)
+
         env.seed(seed + rank)
-        if log_dir is not None:
-            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
-        if is_atari:
-            env = wrap_deepmind(env)
+
+        if size is not None:
+            env.gridSize = size
+
         # If the input has shape (W,H,3), wrap for PyTorch convolutions
         obs_shape = env.observation_space.shape
         if len(obs_shape) == 3 and obs_shape[2] == 3:

+ 25 - 1
basicrl/main.py

@@ -56,7 +56,13 @@ def main():
         viz = Visdom()
         win = None
 
-    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
+
+    paramSteps = [5,6,7,8,9,10,11,12,13,14,15,16]
+
+    roomSize = paramSteps[0]
+    paramSteps = paramSteps[1:]
+
+    envs = [make_env(args.env_name, args.seed, i, args.log_dir, roomSize)
                 for i in range(args.num_processes)]
 
     if args.num_processes > 1:
@@ -250,6 +256,9 @@ def main():
         if j % args.log_interval == 0:
             end = time.time()
             total_num_steps = (j + 1) * args.num_processes * args.num_steps
+
+            print('roomSize=%s' % roomSize)
+
             print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                 format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
@@ -258,6 +267,21 @@ def main():
                        final_rewards.min(),
                        final_rewards.max(), dist_entropy.data[0],
                        value_loss.data[0], action_loss.data[0]))
+
+            #print(final_rewards.min())
+            if final_rewards.min() > 950 and len(paramSteps) > 0:
+                roomSize = paramSteps[0]
+                paramSteps = paramSteps[1:]
+
+                envs.close()
+                envs = [make_env(args.env_name, args.seed, i, args.log_dir, roomSize) for i in range(args.num_processes)]
+                envs = SubprocVecEnv(envs)
+                obs = envs.reset()
+                update_current_obs(obs)
+
+                # Reset the rewards
+                final_rewards = torch.zeros([args.num_processes, 1])
+
         if args.vis and j % args.vis_interval == 0:
             try:
                 # Sometimes monitor doesn't properly flush the outputs

+ 19 - 4
gym_minigrid/envs/simple_envs.py

@@ -50,14 +50,29 @@ class DoorKeyEnv(MiniGridEnv):
         for i in range(0, gridSz):
             grid.set(splitIdx, i, Wall())
 
+        # Place the agent at a random position and orientation
+        self.startPos = (
+            self._randInt(1, splitIdx),
+            self._randInt(1, gridSz-1)
+        )
+        self.startDir = self._randInt(0, 4)
+
         # Place a door in the wall
         doorIdx = self._randInt(1, gridSz-2)
         grid.set(splitIdx, doorIdx, LockedDoor('yellow'))
 
-        # Place a key on the left side
-        #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
-        keyIdx = gridSz-2
-        grid.set(1, keyIdx, Key('yellow'))
+        # Place a yellow key on the left side
+        while True:
+            pos = (
+                self._randInt(1, splitIdx),
+                self._randInt(1, gridSz-1)
+            )
+            if pos == self.startPos:
+                continue
+            if grid.get(*pos) != None:
+                continue
+            grid.set(*pos, Key('yellow'))
+            break
 
         return grid