@@ -128,9 +128,11 @@ def main():
     for j in range(num_updates):
         for step in range(args.num_steps):
             # Sample actions
-            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
-                                                                      Variable(rollouts.states[step], volatile=True),
-                                                                      Variable(rollouts.masks[step], volatile=True))
+            value, action, action_log_prob, states = actor_critic.act(
+                Variable(rollouts.observations[step], volatile=True),
+                Variable(rollouts.states[step], volatile=True),
+                Variable(rollouts.masks[step], volatile=True)
+            )
             cpu_actions = action.data.squeeze(1).cpu().numpy()

             # Obser reward and next obs
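A note for reviewers: the volatile=True flags kept above are the pre-0.4 PyTorch idiom for disabling autograd during rollout collection; this PR only rewraps the call. On PyTorch 0.4+, where Variable merged into Tensor and volatile was removed, the same step would sit under torch.no_grad() instead. A minimal sketch of that equivalent, reusing the actor_critic and rollouts names from the diff (everything else here is illustrative, not code from this repository):

import torch

def sample_step(actor_critic, rollouts, step):
    # No gradients are needed while collecting experience, so run the
    # forward pass under no_grad() instead of passing volatile Variables.
    with torch.no_grad():
        value, action, action_log_prob, states = actor_critic.act(
            rollouts.observations[step],
            rollouts.states[step],
            rollouts.masks[step])
    return value, action, action_log_prob, states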
			
			
@@ -155,17 +157,21 @@ def main():
             update_current_obs(obs)
             rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

-        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
-                                  Variable(rollouts.states[-1], volatile=True),
-                                  Variable(rollouts.masks[-1], volatile=True))[0].data
+        next_value = actor_critic(
+            Variable(rollouts.observations[-1], volatile=True),
+            Variable(rollouts.states[-1], volatile=True),
+            Variable(rollouts.masks[-1], volatile=True)
+        )[0].data

         rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

         if args.algo in ['a2c', 'acktr']:
-            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
-                                                                                           Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
-                                                                                           Variable(rollouts.masks[:-1].view(-1, 1)),
-                                                                                           Variable(rollouts.actions.view(-1, action_shape)))
+            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
+                Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
+                Variable(rollouts.masks[:-1].view(-1, 1)),
+                Variable(rollouts.actions.view(-1, action_shape))
+            )

             values = values.view(args.num_steps, args.num_processes, 1)
             action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)
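For context on the hunk above: rollouts.compute_returns bootstraps from next_value and, when args.use_gae is set, applies Generalized Advantage Estimation (Schulman et al., 2015). A standalone sketch of that recurrence, assuming [T, N, 1]-shaped tensors for T steps and N processes (an assumption about RolloutStorage's layout, not code from this PR):

import torch

def compute_gae_returns(rewards, values, masks, next_value, gamma, tau):
    # masks: [T+1, N, 1], zero where an episode ended, so neither the
    # bootstrap value nor the advantage leaks across episode boundaries.
    # delta_t = r_t + gamma * V_{t+1} * mask_{t+1} - V_t
    # A_t     = delta_t + gamma * tau * mask_{t+1} * A_{t+1}
    T = rewards.size(0)
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # [T+1, N, 1]
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] * masks[t + 1] - values[t]
        gae = delta + gamma * tau * masks[t + 1] * gae
        returns[t] = gae + values[t]  # return_t = A_t + V_t
    return returns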
			
			
@@ -205,11 +211,9 @@ def main():

             for e in range(args.ppo_epoch):
                 if args.recurrent_policy:
-                    data_generator = rollouts.recurrent_generator(advantages,
-                                                            args.num_mini_batch)
+                    data_generator = rollouts.recurrent_generator(advantages, args.num_mini_batch)
                 else:
-                    data_generator = rollouts.feed_forward_generator(advantages,
-                                                            args.num_mini_batch)
+                    data_generator = rollouts.feed_forward_generator(advantages, args.num_mini_batch)

                 for sample in data_generator:
                     observations_batch, states_batch, actions_batch, \
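The two generators above differ in what they may shuffle: feed_forward_generator can sample the num_steps * num_processes transitions independently, while recurrent_generator must keep each process's trajectory contiguous so hidden states can be re-propagated in order. A minimal sketch of the feed-forward variant's contract, yielding batches in the order the diff unpacks them (assumed behavior with illustrative tensor arguments; the real RolloutStorage method is a bound method and may differ in detail):

import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

def feed_forward_generator(observations, states, actions, returns, masks,
                           old_log_probs, advantages, num_mini_batch):
    # observations/states/returns/masks carry T+1 entries; actions and
    # advantages carry T. Flatten time and process dims, then sample.
    T, N = actions.size(0), actions.size(1)
    sampler = BatchSampler(SubsetRandomSampler(range(T * N)),
                           (T * N) // num_mini_batch, drop_last=False)
    for indices in sampler:
        idx = torch.LongTensor(indices)
        yield (observations[:-1].view(-1, *observations.size()[2:])[idx],
               states[:-1].view(-1, states.size(-1))[idx],
               actions.view(-1, actions.size(-1))[idx],
               returns[:-1].view(-1, 1)[idx],
               masks[:-1].view(-1, 1)[idx],
               old_log_probs.view(-1, 1)[idx],
               advantages.view(-1, 1)[idx])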
			
			
@@ -217,10 +221,12 @@ def main():
                             adv_targ = sample

                     # Reshape to do in a single forward pass for all steps
-                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(observations_batch),
-                                                                                                   Variable(states_batch),
-                                                                                                   Variable(masks_batch),
-                                                                                                   Variable(actions_batch))
+                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(
+                        Variable(observations_batch),
+                        Variable(states_batch),
+                        Variable(masks_batch),
+                        Variable(actions_batch)
+                    )

                     adv_targ = Variable(adv_targ)
                     ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
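For context: the ratio on the last line feeds PPO's clipped surrogate objective, which appears a few lines below this hunk (not shown in the diff). The standard form from Schulman et al. (2017), as a self-contained helper with illustrative names:

import torch

def ppo_clipped_action_loss(ratio, adv_targ, clip_param):
    # L_CLIP = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)],
    # negated because the optimizer minimizes.
    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
    return -torch.min(surr1, surr2).mean()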