瀏覽代碼

Added action enumeration, started work on QA environment

Maxime Chevalier-Boisvert 7 年之前
父節點
當前提交
8af443c879
共有 4 個文件被更改，包括 39 次插入和 32 次刪除
  1. 1 1
      README.md
  2. 12 9
      gym_minigrid/envs/fourroomqa.py
  3. 22 17
      gym_minigrid/minigrid.py
  4. 4 5
      standalone.py

+ 1 - 1
README.md

@@ -64,7 +64,7 @@ python3 basicrl/main.py --env-name MiniGrid-Empty-8x8-v0 --no-vis --num-processe
 You can view the result of training using the `enjoy.py` script:
 
 ```
-python3 basicrl/enjoy.py --env-name MiniGrid-Empty-8x8-v0 --load-dir ./trained_models/acktr
+python3 basicrl/enjoy.py --env-name MiniGrid-Empty-6x6-v0 --load-dir ./trained_models/acktr
 ```
 
 ## Included Environments

+ 12 - 9
gym_minigrid/envs/fourroomqa.py

@@ -21,21 +21,24 @@ class FourRoomQAEnv(MiniGridEnv):
     https://arxiv.org/abs/1711.11543
     """
 
-    # TODO: define actions
-
-
-
+    # Enumeration of possible actions
+    class Actions(IntEnum):
+        left = 0
+        right = 1
+        forward = 2
+        toggle = 3
+        say = 4
 
     def __init__(self, size=16):
         assert size >= 8
         super(FourRoomQAEnv, self).__init__(gridSize=size, maxSteps=8*size)
 
-        # TODO: self.actions
-
-        # TODO: self.action_space
-
-
+        # Action enumeration for this environment
+        self.actions = MiniGridEnv.Actions
 
+        # TODO: dictionary action space, to include answer sentence?
+        # Actions are discrete integer values
+        self.action_space = spaces.Discrete(len(self.actions))
 
     def _genGrid(self, width, height):
         grid = super(FourRoomQAEnv, self)._genGrid(width, height)

+ 22 - 17
gym_minigrid/minigrid.py

@@ -1,8 +1,9 @@
 import math
 import gym
+from enum import IntEnum
+import numpy as np
 from gym import error, spaces, utils
 from gym.utils import seeding
-import numpy as np
 from gym_minigrid.rendering import *
 
 # Size in pixels of a cell in the full-scale human view
@@ -458,22 +459,19 @@ class MiniGridEnv(gym.Env):
         'video.frames_per_second' : 10
     }
 
-    # Possible actions
-    NUM_ACTIONS = 4
-    ACTION_LEFT = 0
-    ACTION_RIGHT = 1
-    ACTION_FORWARD = 2
-    ACTION_TOGGLE = 3
+    # Enumeration of possible actions
+    class Actions(IntEnum):
+        left = 0
+        right = 1
+        forward = 2
+        toggle = 3
 
     def __init__(self, gridSize=16, maxSteps=100):
-        # Renderer object used to render the whole grid (full-scale)
-        self.gridRender = None
-
-        # Renderer used to render observations (small-scale agent view)
-        self.obsRender = None
+        # Action enumeration for this environment
+        self.actions = MiniGridEnv.Actions
 
         # Actions are discrete integer values
-        self.action_space = spaces.Discrete(MiniGridEnv.NUM_ACTIONS)
+        self.action_space = spaces.Discrete(len(self.actions))
 
         # The observations are RGB images
         self.observation_space = spaces.Box(
@@ -482,8 +480,15 @@ class MiniGridEnv(gym.Env):
             shape=OBS_ARRAY_SIZE
         )
 
+        # Range of possible rewards
         self.reward_range = (-1, 1000)
 
+        # Renderer object used to render the whole grid (full-scale)
+        self.gridRender = None
+
+        # Renderer used to render observations (small-scale agent view)
+        self.obsRender = None
+
         # Environment configuration
         self.gridSize = gridSize
         self.maxSteps = maxSteps
@@ -615,17 +620,17 @@ class MiniGridEnv(gym.Env):
         done = False
 
         # Rotate left
-        if action == MiniGridEnv.ACTION_LEFT:
+        if action == self.actions.left:
             self.agentDir -= 1
             if self.agentDir < 0:
                 self.agentDir += 4
 
         # Rotate right
-        elif action == MiniGridEnv.ACTION_RIGHT:
+        elif action == self.actions.right:
             self.agentDir = (self.agentDir + 1) % 4
 
         # Move forward
-        elif action == MiniGridEnv.ACTION_FORWARD:
+        elif action == self.actions.forward:
             u, v = self.getDirVec()
             newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
             targetCell = self.grid.get(newPos[0], newPos[1])
@@ -636,7 +641,7 @@ class MiniGridEnv(gym.Env):
                 reward = 1000 - self.stepCount
 
         # Pick up or trigger/activate an item
-        elif action == MiniGridEnv.ACTION_TOGGLE:
+        elif action == self.actions.toggle:
             u, v = self.getDirVec()
             cell = self.grid.get(self.agentPos[0] + u, self.agentPos[1] + v)
             if cell and cell.canPickup() and self.carrying is None:

+ 4 - 5
standalone.py

@@ -9,7 +9,6 @@ import time
 from optparse import OptionParser
 
 import gym_minigrid
-from gym_minigrid.envs import MiniGridEnv
 
 def main():
     parser = OptionParser()
@@ -32,13 +31,13 @@ def main():
     def keyDownCb(keyName):
         action = 0
         if keyName == 'LEFT':
-            action = MiniGridEnv.ACTION_LEFT
+            action = env.actions.left
         elif keyName == 'RIGHT':
-            action = MiniGridEnv.ACTION_RIGHT
+            action = env.actions.right
         elif keyName == 'UP':
-            action = MiniGridEnv.ACTION_FORWARD
+            action = env.actions.forward
         elif keyName == 'SPACE':
-            action = MiniGridEnv.ACTION_TOGGLE
+            action = env.actions.toggle
         elif keyName == 'RETURN':
             env.reset()
         elif keyName == 'ESCAPE':