浏览代码

Started work on goto env

Maxime Chevalier-Boisvert 7 年之前
父节点
当前提交
67acc5ff18
共有 2 个文件被更改,包括 126 次插入0 次删除
  1. 1 0
      gym_minigrid/envs/__init__.py
  2. 125 0
      gym_minigrid/envs/gotoobject.py

+ 1 - 0
gym_minigrid/envs/__init__.py

@@ -2,4 +2,5 @@ from gym_minigrid.envs.empty import *
 from gym_minigrid.envs.doorkey import *
 from gym_minigrid.envs.multiroom import *
 from gym_minigrid.envs.fetch import *
+from gym_minigrid.envs.gotoobject import *
 from gym_minigrid.envs.fourroomqa import *

+ 125 - 0
gym_minigrid/envs/gotoobject.py

@@ -0,0 +1,125 @@
+from gym_minigrid.minigrid import *
+from gym_minigrid.register import register
+
+class GoToObjectEnv(MiniGridEnv):
+    """
+    Environment in which the agent is instructed to go to a given object
+    named using an English text string
+    """
+
+    def __init__(
+        self,
+        size=6,
+        numObjs=2
+    ):
+        self.numObjs = numObjs
+        super().__init__(gridSize=size, maxSteps=5*size)
+        self.reward_range = (-1000, 1000)
+
+    def _genGrid(self, width, height):
+        assert width == height
+        gridSz = width
+
+        # Create a grid surrounded by walls
+        grid = Grid(width, height)
+        for i in range(0, width):
+            grid.set(i, 0, Wall())
+            grid.set(i, height-1, Wall())
+        for j in range(0, height):
+            grid.set(0, j, Wall())
+            grid.set(width-1, j, Wall())
+
+        # Types and colors of objects we can generate
+        types = ['key', 'ball', 'box']
+        colors = list(COLORS.keys())
+
+        objs = []
+
+        # For each object to be generated
+        for i in range(0, self.numObjs):
+            objType = self._randElem(types)
+            objColor = self._randElem(colors)
+
+            if objType == 'key':
+                obj = Key(objColor)
+            elif objType == 'ball':
+                obj = Ball(objColor)
+            elif objType == 'box':
+                obj = Box(objColor)
+
+            while True:
+                pos = (
+                    self._randInt(1, gridSz - 1),
+                    self._randInt(1, gridSz - 1)
+                )
+
+                if pos != self.startPos:
+                    grid.set(*pos, obj)
+                    break
+
+            objs.append(obj)
+
+        # Choose a random object to be picked up
+        target = objs[self._randInt(0, len(objs))]
+        self.targetType = target.type
+        self.targetColor = target.color
+
+        descStr = '%s %s' % (self.targetColor, self.targetType)
+
+        """
+        # Generate the mission string
+        idx = self._randInt(0, 5)
+        if idx == 0:
+            self.mission = 'get a %s' % descStr
+        elif idx == 1:
+            self.mission = 'go get a %s' % descStr
+        elif idx == 2:
+            self.mission = 'fetch a %s' % descStr
+        elif idx == 3:
+            self.mission = 'go fetch a %s' % descStr
+        elif idx == 4:
+            self.mission = 'you must fetch a %s' % descStr
+        assert hasattr(self, 'mission')
+        """
+
+        self.mission = 'go to the %s' % descStr
+
+        return grid
+
+    def _observation(self, obs):
+        """
+        Encode observations
+        """
+
+        obs = {
+            'image': obs,
+            'mission': self.mission,
+            'advice' : ''
+        }
+
+        return obs
+
+    def _reset(self):
+        obs = MiniGridEnv._reset(self)
+        return self._observation(obs)
+
+    def _step(self, action):
+        obs, reward, done, info = MiniGridEnv._step(self, action)
+
+        #if self.carrying:
+        #    if self.carrying.color == self.targetColor and \
+        #       self.carrying.type == self.targetType:
+        #        reward = 1000 - self.stepCount
+        #        done = True
+        #    else:
+        #        reward = -1000
+        #        done = True
+
+        obs = self._observation(obs)
+
+        return obs, reward, done, info
+
+register(
+    id='MiniGrid-GoToObject-6x6-N2-v0',
+    entry_point='gym_minigrid.envs:GoToObjectEnv'
+)