浏览代码

Updated default arguments for RL code

Maxime Chevalier-Boisvert 7 年之前
父节点
当前提交
cd33e57ae6
共有 3 个文件被更改,包括 7 次插入16 次删除
  1. 4 4
      pytorch_rl/arguments.py
  2. 2 2
      pytorch_rl/enjoy.py
  3. 1 10
      pytorch_rl/model.py

+ 4 - 4
pytorch_rl/arguments.py

@@ -27,8 +27,8 @@ def get_args():
                         help='value loss coefficient (default: 0.5)')
     parser.add_argument('--seed', type=int, default=1,
                         help='random seed (default: 1)')
-    parser.add_argument('--num-processes', type=int, default=16,
-                        help='how many training CPU processes to use (default: 16)')
+    parser.add_argument('--num-processes', type=int, default=32,
+                        help='how many training CPU processes to use (default: 32)')
     parser.add_argument('--num-steps', type=int, default=5,
                         help='number of forward steps in A2C (default: 5)')
     parser.add_argument('--ppo-epoch', type=int, default=4,
@@ -37,8 +37,8 @@ def get_args():
                         help='number of batches for ppo (default: 32)')
     parser.add_argument('--clip-param', type=float, default=0.2,
                         help='ppo clip parameter (default: 0.2)')
-    parser.add_argument('--num-stack', type=int, default=4,
-                        help='number of frames to stack (default: 4)')
+    parser.add_argument('--num-stack', type=int, default=1,
+                        help='number of frames to stack (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10,
                         help='log interval, one log per n updates (default: 10)')
     parser.add_argument('--save-interval', type=int, default=100,

+ 2 - 2
pytorch_rl/enjoy.py

@@ -14,8 +14,8 @@ from envs import make_env
 parser = argparse.ArgumentParser(description='RL')
 parser.add_argument('--seed', type=int, default=1,
                     help='random seed (default: 1)')
-parser.add_argument('--num-stack', type=int, default=4,
-                    help='number of frames to stack (default: 4)')
+parser.add_argument('--num-stack', type=int, default=1,
+                    help='number of frames to stack (default: 1)')
 parser.add_argument('--log-interval', type=int, default=10,
                     help='log interval, one log per n updates (default: 10)')
 parser.add_argument('--env-name', default='PongNoFrameskip-v4',

+ 1 - 10
pytorch_rl/model.py

@@ -232,15 +232,6 @@ class CNNPolicy(FFPolicy):
         x = F.relu(x)
 
         if hasattr(self, 'gru'):
-            if inputs.size(0) == states.size(0):
-                x = states = self.gru(x, states * masks)
-            else:
-                x = x.view(-1, states.size(0), x.size(1))
-                masks = masks.view(-1, states.size(0), 1)
-                outputs = []
-                for i in range(x.size(0)):
-                    hx = states = self.gru(x[i], states * masks[i])
-                    outputs.append(hx)
-                x = torch.cat(outputs, 0)
+            x = states = self.gru(x, states * masks)
 
         return self.critic_linear(x), x, states