|
@@ -0,0 +1,125 @@
|
|
|
+from gym_minigrid.minigrid import *
|
|
|
+from gym_minigrid.register import register
|
|
|
+
|
|
|
class GoToObjectEnv(MiniGridEnv):
    """
    Environment in which the agent is instructed to go to a given object
    named using an English text string.

    Each generated grid is a square room enclosed by walls and populated
    with ``numObjs`` randomly chosen objects (key, ball or box, in a random
    color). One of them is selected as the target and described in
    ``self.mission`` as ``'go to the <color> <type>'``.
    """

    def __init__(
        self,
        size=6,
        numObjs=2
    ):
        """
        Create the environment.

        size: side length of the square grid, outer walls included.
        numObjs: number of objects placed in the room.
        """
        # Stored before calling the parent constructor; _genGrid reads it.
        # NOTE(review): presumably MiniGridEnv.__init__ generates the first
        # grid, so this ordering matters -- confirm before moving it.
        self.numObjs = numObjs
        super().__init__(gridSize=size, maxSteps=5*size)
        self.reward_range = (-1000, 1000)

    def _genGrid(self, width, height):
        """
        Build a walled square grid, scatter the objects and pick the target.

        Sets ``self.targetType``, ``self.targetColor`` and ``self.mission``
        as side effects and returns the populated grid.

        Raises ValueError if ``self.numObjs`` objects cannot each be given
        their own free interior cell.
        """
        assert width == height
        gridSz = width

        # Create a grid surrounded by walls
        grid = Grid(width, height)
        for i in range(0, width):
            grid.set(i, 0, Wall())
            grid.set(i, height-1, Wall())
        for j in range(0, height):
            grid.set(0, j, Wall())
            grid.set(width-1, j, Wall())

        # Interior cells that may hold an object (everything inside the
        # walls except the agent start cell). Objects never share a cell,
        # so asking for more objects than cells would make the placement
        # loop below spin forever; fail loudly instead.
        numFree = sum(
            1
            for x in range(1, gridSz - 1)
            for y in range(1, gridSz - 1)
            if (x, y) != self.startPos
        )
        if not 1 <= self.numObjs <= numFree:
            raise ValueError(
                'numObjs must be between 1 and %d for a grid of size %d, got %r'
                % (numFree, gridSz, self.numObjs)
            )

        # Types and colors of objects we can generate
        makers = {'key': Key, 'ball': Ball, 'box': Box}
        types = ['key', 'ball', 'box']
        colors = list(COLORS.keys())

        objs = []

        # For each object to be generated
        for _ in range(0, self.numObjs):
            objType = self._randElem(types)
            objColor = self._randElem(colors)
            obj = makers[objType](objColor)

            # Re-draw positions until we hit a cell that is neither the
            # agent start cell nor already occupied. Without the occupancy
            # check a later object could overwrite an earlier one, and the
            # mission could then name a target that is not on the grid.
            while True:
                pos = (
                    self._randInt(1, gridSz - 1),
                    self._randInt(1, gridSz - 1)
                )

                if pos != self.startPos and grid.get(*pos) is None:
                    grid.set(*pos, obj)
                    break

            objs.append(obj)

        # Choose a random object to be the target of the mission
        target = objs[self._randInt(0, len(objs))]
        self.targetType = target.type
        self.targetColor = target.color

        descStr = '%s %s' % (self.targetColor, self.targetType)
        self.mission = 'go to the %s' % descStr

        return grid

    def _observation(self, obs):
        """
        Encode observations

        Wraps the raw observation in a dict with the keys 'image' (the
        observation as received), 'mission' (the current mission string)
        and 'advice' (always an empty string here).
        """

        obs = {
            'image': obs,
            'mission': self.mission,
            'advice' : ''
        }

        return obs

    def _reset(self):
        """
        Reset through the parent class and return the encoded observation.
        """
        obs = MiniGridEnv._reset(self)
        return self._observation(obs)

    def _step(self, action):
        """
        Advance the parent environment by one action.

        Returns the (obs, reward, done, info) tuple from MiniGridEnv._step
        with only the observation re-encoded by _observation.
        """
        obs, reward, done, info = MiniGridEnv._step(self, action)

        # NOTE: no goal-reached reward or termination is applied here;
        # reward and done are passed through from the parent unchanged.
        obs = self._observation(obs)

        return obs, reward, done, info
|
|
|
+
|
|
|
# Expose the default configuration (6x6 grid, 2 objects) under a Gym id.
register(
    entry_point='gym_minigrid.envs:GoToObjectEnv',
    id='MiniGrid-GoToObject-6x6-N2-v0',
)
|