Explorar o código

RL code updates from upstream repos, removed Atari dependencies

Maxime Chevalier-Boisvert %!s(int64=7) %!d(string=hai) anos
pai
achega
780c75e2cd
Modificáronse 2 ficheiros con 4 adicións e 21 borrados
  1. 1 13
      pytorch_rl/envs.py
  2. 3 8
      pytorch_rl/kfac.py

+ 1 - 13
pytorch_rl/envs.py

@@ -1,19 +1,12 @@
 import os
 import numpy
 import gym
-
 from gym.spaces.box import Box
 
-from baselines.common.atari_wrappers import make_atari, wrap_deepmind
-
-try:
-    import pybullet_envs
-except ImportError:
-    pass
-
 try:
     import gym_minigrid
     from gym_minigrid.wrappers import *
+    #from gym_minigrid.envs import *
 except:
     pass
 
@@ -21,12 +14,7 @@ def make_env(env_id, seed, rank, log_dir):
     def _thunk():
         env = gym.make(env_id)
 
-        is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
-        if is_atari:
-            env = make_atari(env_id)
         env.seed(seed + rank)
-        if is_atari:
-            env = wrap_deepmind(env)
 
         #env = FlatObsWrapper(env)
 

+ 3 - 8
pytorch_rl/kfac.py

@@ -104,7 +104,7 @@ class KFACOptimizer(optim.Optimizer):
                     split_bias(child)
 
         split_bias(model)
-            
+
         super(KFACOptimizer, self).__init__(model.parameters(), defaults)
 
         self.known_modules = {'Linear', 'Conv2d', 'AddBias'}
@@ -203,14 +203,9 @@ class KFACOptimizer(optim.Optimizer):
                 # My asynchronous implementation exists, I will add it later.
                 # Experimenting with different ways to this in PyTorch.
                 self.d_a[m], self.Q_a[m] = torch.symeig(
-                    self.m_aa[m].cpu().double(), eigenvectors=True)
+                    self.m_aa[m], eigenvectors=True)
                 self.d_g[m], self.Q_g[m] = torch.symeig(
-                    self.m_gg[m].cpu().double(), eigenvectors=True)
-                self.d_a[m], self.Q_a[m] = self.d_a[m].float(), self.Q_a[m].float()
-                self.d_g[m], self.Q_g[m] = self.d_g[m].float(), self.Q_g[m].float()
-                if self.m_aa[m].is_cuda:
-                    self.d_a[m], self.Q_a[m] = self.d_a[m].cuda(), self.Q_a[m].cuda()
-                    self.d_g[m], self.Q_g[m] = self.d_g[m].cuda(), self.Q_g[m].cuda()
+                    self.m_gg[m], eigenvectors=True)
 
                 self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
                 self.d_g[m].mul_((self.d_g[m] > 1e-6).float())