Explorar o código

Made gen_obs a public method, renamed public methods.

Maxime Chevalier-Boisvert %!s(int64=7) %!d(string=hai) anos
pai
achega
15e83a570a
Modificáronse 2 ficheiros con 14 adicións e 13 borrados
  1. 1 1
      gym_minigrid/envs/putnear.py
  2. 13 12
      gym_minigrid/minigrid.py

+ 1 - 1
gym_minigrid/envs/putnear.py

@@ -88,7 +88,7 @@ class PutNearEnv(MiniGridEnv):
 
         obs, reward, done, info = super().step(action)
 
-        u, v = self.getDirVec()
+        u, v = self.get_dir_vec()
         ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
         tx, ty = self.target_pos
 

+ 13 - 12
gym_minigrid/minigrid.py

@@ -687,7 +687,7 @@ class MiniGridEnv(gym.Env):
         self.step_count = 0
 
         # Return first observation
-        obs = self._gen_obs()
+        obs = self.gen_obs()
         return obs
 
     def seed(self, seed=1337):
@@ -696,6 +696,10 @@ class MiniGridEnv(gym.Env):
 
         return [seed]
 
+    @property
+    def steps_remaining(self):
+        return self.max_steps - self.step_count
+
     def __str__(self):
         """
         Produce a pretty string of the environment's grid along with the agent.
@@ -865,10 +869,7 @@ class MiniGridEnv(gym.Env):
 
         return pos
 
-    def getStepsRemaining(self):
-        return self.max_steps - self.step_count
-
-    def getDirVec(self):
+    def get_dir_vec(self):
         """
         Get the direction vector for the agent, pointing in the direction
         of forward movement.
@@ -894,7 +895,7 @@ class MiniGridEnv(gym.Env):
         Get the vector pointing to the right of the agent.
         """
 
-        dx, dy = self.getDirVec()
+        dx, dy = self.get_dir_vec()
         return -dy, dx
 
     def get_view_coords(self, i, j):
@@ -905,7 +906,7 @@ class MiniGridEnv(gym.Env):
         """
 
         ax, ay = self.agent_pos
-        dx, dy = self.getDirVec()
+        dx, dy = self.get_dir_vec()
         rx, ry = self.get_right_vec()
 
         # Compute the absolute coordinates of the top-left view corner
@@ -964,7 +965,7 @@ class MiniGridEnv(gym.Env):
         if vx < 0 or vy < 0 or vx >= AGENT_VIEW_SIZE or vy >= AGENT_VIEW_SIZE:
             return False
 
-        obs = self._gen_obs()
+        obs = self.gen_obs()
         obs_grid = Grid.decode(obs['image'])
         obs_cell = obs_grid.get(vx, vy)
         world_cell = self.grid.get(x, y)
@@ -978,7 +979,7 @@ class MiniGridEnv(gym.Env):
         done = False
 
         # Get the position in front of the agent
-        u, v = self.getDirVec()
+        u, v = self.get_dir_vec()
         fwdPos = (self.agent_pos[0] + u, self.agent_pos[1] + v)
 
         # Get the contents of the cell in front of the agent
@@ -1030,11 +1031,11 @@ class MiniGridEnv(gym.Env):
         if self.step_count >= self.max_steps:
             done = True
 
-        obs = self._gen_obs()
+        obs = self.gen_obs()
 
         return obs, reward, done, {}
 
-    def _gen_obs(self):
+    def gen_obs(self):
         """
         Generate the agent's view (partially observable, low-resolution encoding)
         """
@@ -1075,7 +1076,7 @@ class MiniGridEnv(gym.Env):
 
         return obs
 
-    def getObsRender(self, obs):
+    def get_obs_render(self, obs):
         """
         Render an agent observation for visualization
         """