@@ -6,13 +6,6 @@ import torch.nn.functional as F
 from distributions import Categorical, DiagGaussian
 from utils import orthogonal
 
-def weights_init(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
-        orthogonal(m.weight.data)
-        if m.bias is not None:
-            m.bias.data.fill_(0)
-
 class FFPolicy(nn.Module):
     def __init__(self):
         super(FFPolicy, self).__init__()
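
Note: the module-level weights_init helper removed here (orthogonal weights, zero biases
for every Conv/Linear layer) is not gone; it reappears at the bottom of the file as
weights_init_cnn, alongside a new weights_init_mlp for the fully connected policies.
For reference, a minimal standalone sketch of this style of initialization, assuming
utils.orthogonal behaves like PyTorch's built-in orthogonal initializer (not verified
against the repo):

    import torch.nn as nn

    layer = nn.Linear(64, 64)
    nn.init.orthogonal_(layer.weight)   # fill with a (semi-)orthogonal matrix
    nn.init.constant_(layer.bias, 0.0)  # zero bias, as weights_init did
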
@@ -31,71 +24,63 @@ class FFPolicy(nn.Module):
         action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
         return value, action_log_probs, dist_entropy, states
 
+def weights_init_mlp(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        m.weight.data.normal_(0, 1)
+        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
+        if m.bias is not None:
+            m.bias.data.fill_(0)
 
-class CNNPolicy(FFPolicy):
-    def __init__(self, num_inputs, action_space, use_gru):
-        super(CNNPolicy, self).__init__()
-        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
-        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
-        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
+class RecMLPPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space):
+        super(RecMLPPolicy, self).__init__()
 
-        self.linear1 = nn.Linear(32 * 7 * 7, 512)
+        self.action_space = action_space
+        assert action_space.__class__.__name__ == "Discrete"
+        num_outputs = action_space.n
 
-        if use_gru:
-            self.gru = nn.GRUCell(512, 512)
+        self.a_fc1 = nn.Linear(num_inputs, 64)
+        self.a_fc2 = nn.Linear(64, 64)
 
-        self.critic_linear = nn.Linear(512, 1)
+        self.v_fc1 = nn.Linear(num_inputs, 64)
+        self.v_fc2 = nn.Linear(64, 64)
+        self.v_fc3 = nn.Linear(64, 1)
 
-        if action_space.__class__.__name__ == "Discrete":
-            num_outputs = action_space.n
-            self.dist = Categorical(512, num_outputs)
-        elif action_space.__class__.__name__ == "Box":
-            num_outputs = action_space.shape[0]
-            self.dist = DiagGaussian(512, num_outputs)
-        else:
-            raise NotImplementedError
+        # Input size, hidden size
+        self.gru = nn.GRUCell(64, 64)
+
+        self.dist = Categorical(64, num_outputs)
 
         self.train()
         self.reset_parameters()
 
     @property
     def state_size(self):
-        if hasattr(self, 'gru'):
-            return 512
-        else:
-            return 1
+        """
+        Size of the recurrent state of the model (propagated between steps)
+        """
+        return 64
 
     def reset_parameters(self):
-        self.apply(weights_init)
-
-        relu_gain = nn.init.calculate_gain('relu')
-        self.conv1.weight.data.mul_(relu_gain)
-        self.conv2.weight.data.mul_(relu_gain)
-        self.conv3.weight.data.mul_(relu_gain)
-        self.linear1.weight.data.mul_(relu_gain)
+        self.apply(weights_init_mlp)
 
-        if hasattr(self, 'gru'):
-            orthogonal(self.gru.weight_ih.data)
-            orthogonal(self.gru.weight_hh.data)
-            self.gru.bias_ih.data.fill_(0)
-            self.gru.bias_hh.data.fill_(0)
+        orthogonal(self.gru.weight_ih.data)
+        orthogonal(self.gru.weight_hh.data)
+        self.gru.bias_ih.data.fill_(0)
+        self.gru.bias_hh.data.fill_(0)
 
         if self.dist.__class__.__name__ == "DiagGaussian":
             self.dist.fc_mean.weight.data.mul_(0.01)
 
     def forward(self, inputs, states, masks):
-        x = self.conv1(inputs / 255.0)
-        x = F.relu(x)
-
-        x = self.conv2(x)
-        x = F.relu(x)
-
-        x = self.conv3(x)
-        x = F.relu(x)
+        batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
+        inputs = inputs.view(-1, batch_numel)
 
-        x = x.view(-1, 32 * 7 * 7)
-        x = self.linear1(x)
-        x = F.relu(x)
+        x = self.a_fc1(inputs)
+        x = F.tanh(x)
+        x = self.a_fc2(x)
+        x = F.tanh(x)
 
         if hasattr(self, 'gru'):
             if inputs.size(0) == states.size(0):
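
Note: weights_init_mlp draws each Linear weight from N(0, 1), then rescales every row to
unit L2 norm (the sum(1, keepdim=True) reduces over the input dimension). A standalone
sketch of that invariant, outside the diff:

    import torch

    w = torch.randn(64, 10)                             # stand-in for m.weight.data
    w *= 1 / torch.sqrt(w.pow(2).sum(1, keepdim=True))  # same expression as above
    print(w.pow(2).sum(1))                              # each row norm is now ~1.0

Also note that the new forward uses reduce and operator; the matching imports
(import operator and, on Python 3, from functools import reduce) are presumably added
in a part of the file outside the hunks shown here.
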
@@ -104,19 +89,22 @@ class CNNPolicy(FFPolicy):
                 x = x.view(-1, states.size(0), x.size(1))
                 masks = masks.view(-1, states.size(0), 1)
                 outputs = []
+                # For every element in the batch
                 for i in range(x.size(0)):
                     hx = states = self.gru(x[i], states * masks[i])
                     outputs.append(hx)
                 x = torch.cat(outputs, 0)
-        return self.critic_linear(x), x, states
 
-def weights_init_mlp(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        m.weight.data.normal_(0, 1)
-        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
-        if m.bias is not None:
-            m.bias.data.fill_(0)
+        actions = x
+
+        x = self.v_fc1(inputs)
+        x = F.tanh(x)
+        x = self.v_fc2(x)
+        x = F.tanh(x)
+        x = self.v_fc3(x)
+        value = x
+
+        return value, actions, states
 
 class MLPPolicy(FFPolicy):
     def __init__(self, num_inputs, action_space):
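
Note: RecMLPPolicy.forward returns (value, actions, states), where actions is the 64-dim
actor feature vector consumed by self.dist, not a sampled action. The loop commented
"For every element in the batch" actually steps over the time dimension created by the
preceding view; multiplying states by masks[i] zeroes the GRU state wherever an episode
ended. A hypothetical usage sketch (assumes a gym-style Discrete space and that the
repo's distributions/utils modules are importable; untested against the repo):

    import torch
    from gym.spaces import Discrete

    policy = RecMLPPolicy(num_inputs=4, action_space=Discrete(2))
    obs = torch.randn(8, 4)                     # 8 parallel processes, 4-dim observations
    states = torch.zeros(8, policy.state_size)  # state_size is 64 here
    masks = torch.ones(8, 1)                    # 0.0 wherever an episode just ended
    value, actor_features, states = policy(obs, states, masks)
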
@@ -181,3 +169,89 @@ class MLPPolicy(FFPolicy):
         x = F.tanh(x)
 
         return value, x, states
+
+def weights_init_cnn(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
+        orthogonal(m.weight.data)
+        if m.bias is not None:
+            m.bias.data.fill_(0)
+
+class CNNPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space, use_gru):
+        super(CNNPolicy, self).__init__()
+        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
+        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
+        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
+
+        self.linear1 = nn.Linear(32 * 7 * 7, 512)
+
+        if use_gru:
+            self.gru = nn.GRUCell(512, 512)
+
+        self.critic_linear = nn.Linear(512, 1)
+
+        if action_space.__class__.__name__ == "Discrete":
+            num_outputs = action_space.n
+            self.dist = Categorical(512, num_outputs)
+        elif action_space.__class__.__name__ == "Box":
+            num_outputs = action_space.shape[0]
+            self.dist = DiagGaussian(512, num_outputs)
+        else:
+            raise NotImplementedError
+
+        self.train()
+        self.reset_parameters()
+
+    @property
+    def state_size(self):
+        if hasattr(self, 'gru'):
+            return 512
+        else:
+            return 1
+
+    def reset_parameters(self):
+        self.apply(weights_init_cnn)
+
+        relu_gain = nn.init.calculate_gain('relu')
+        self.conv1.weight.data.mul_(relu_gain)
+        self.conv2.weight.data.mul_(relu_gain)
+        self.conv3.weight.data.mul_(relu_gain)
+        self.linear1.weight.data.mul_(relu_gain)
+
+        if hasattr(self, 'gru'):
+            orthogonal(self.gru.weight_ih.data)
+            orthogonal(self.gru.weight_hh.data)
+            self.gru.bias_ih.data.fill_(0)
+            self.gru.bias_hh.data.fill_(0)
+
+        if self.dist.__class__.__name__ == "DiagGaussian":
+            self.dist.fc_mean.weight.data.mul_(0.01)
+
+    def forward(self, inputs, states, masks):
+        x = self.conv1(inputs / 255.0)
+        x = F.relu(x)
+
+        x = self.conv2(x)
+        x = F.relu(x)
+
+        x = self.conv3(x)
+        x = F.relu(x)
+
+        x = x.view(-1, 32 * 7 * 7)
+        x = self.linear1(x)
+        x = F.relu(x)
+
+        if hasattr(self, 'gru'):
+            if inputs.size(0) == states.size(0):
+                x = states = self.gru(x, states * masks)
+            else:
+                x = x.view(-1, states.size(0), x.size(1))
+                masks = masks.view(-1, states.size(0), 1)
+                outputs = []
+                for i in range(x.size(0)):
+                    hx = states = self.gru(x[i], states * masks[i])
+                    outputs.append(hx)
+                x = torch.cat(outputs, 0)
+
+        return self.critic_linear(x), x, states
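
Note: CNNPolicy itself is unchanged apart from reset_parameters now calling the renamed
weights_init_cnn; the class simply moves below MLPPolicy. The masked GRU unrolling it
shares with RecMLPPolicy is the subtle part: a one-step rollout hits the first branch
(batch matches the stored states), while an update pass reshapes the flattened T * N
batch and unrolls it step by step, resetting the hidden state via masks at episode
boundaries. A standalone sketch of that unroll (names are illustrative, not the repo's):

    import torch
    import torch.nn as nn

    T, N, H = 5, 8, 64                  # steps, processes, hidden size
    gru = nn.GRUCell(H, H)
    x = torch.randn(T * N, H)           # flattened rollout features
    states = torch.zeros(N, H)
    masks = torch.ones(T * N, 1)        # 0.0 at the start of a new episode

    x = x.view(T, N, H)
    masks = masks.view(T, N, 1)
    outputs = []
    for t in range(T):                  # unroll over time, not over the batch
        states = gru(x[t], states * masks[t])
        outputs.append(states)
    x = torch.cat(outputs, 0)           # back to (T * N, H)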