123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- from gym_minigrid.minigrid import *
- from gym_minigrid.register import register
class Room:
    """Rectangular room in the four-room layout.

    Attributes:
        top: (x, y) grid coordinates of the room's top-left corner.
        size: (width, height) of the room in cells; FourRoomQAEnv draws the
            room's walls on its outermost rows/columns.
        color: color name FourRoomQAEnv uses for the room's walls.
        objects: list of objects placed inside the room.
    """

    def __init__(self, top, size, color, objects):
        self.top, self.size, self.color, self.objects = top, size, color, objects
class FourRoomQAEnv(MiniGridEnv):
    """
    Environment to experiment with embodied question answering
    https://arxiv.org/abs/1711.11543

    The grid is split into four walled rooms (red, purple, blue, yellow),
    with one doorway through each wall shared by two rooms. Random keys,
    balls and boxes are scattered in the rooms. The agent is asked whether a
    given room contains any object of a given type. Observations are dicts
    holding the base MiniGridEnv observation under 'image' and the question
    string under 'question'.
    """

    # Enumeration of possible actions.
    # NOTE(review): __init__ uses MiniGridEnv.Actions, not this enum, so the
    # `say` action is never exposed; answers are passed through dict actions
    # in _step instead. Confirm which enum is intended before switching.
    class Actions(IntEnum):
        left = 0
        right = 1
        forward = 2
        toggle = 3
        say = 4

    def __init__(self, size=16):
        """Create the environment on a ``size`` x ``size`` grid.

        ``size`` must be at least 10 (presumably so that all four rooms fit).
        Episodes are capped at ``8 * size`` steps.
        """
        assert size >= 10
        super().__init__(gridSize=size, maxSteps=8*size)

        self.actions = MiniGridEnv.Actions
        self.action_space = spaces.Discrete(len(self.actions))

    def _randPos(self, room, border=1):
        """Return a random (x, y) cell inside ``room``.

        ``border`` is how many cells to stay in from the room's outer edge.
        Assuming ``_randInt`` excludes its upper bound, the default of 1
        excludes the room's walls.
        """
        return (
            self._randInt(
                room.top[0] + border,
                room.top[0] + room.size[0] - border
            ),
            self._randInt(
                room.top[1] + border,
                room.top[1] + room.size[1] - border
            ),
        )

    def _genGrid(self, width, height):
        """Build the four-room grid, place the agent and objects, pose a question.

        Sets ``self.rooms``, ``self.startDir``, ``self.startPos``,
        ``self.question`` and ``self.answer``, and returns the new Grid.
        The order of random draws is significant for seeded reproducibility.
        """
        grid = Grid(width, height)

        # Random split points between the left/right and top/bottom rooms.
        vSplitIdx = self._randInt(5, width-4)
        hSplitIdx = self._randInt(5, height-4)

        self._addRooms(grid, width, height, vSplitIdx, hSplitIdx)
        self._addDoors(grid, width, height, vSplitIdx, hSplitIdx)

        # Random starting direction, and a start cell inside a random room.
        self.startDir = self._randInt(0, 4)
        room = self._randElem(self.rooms)
        self.startPos = self._randPos(room)

        types = ['key', 'ball', 'box']
        self._addObjects(grid, types)
        self._askQuestion(types)

        return grid

    def _addRooms(self, grid, width, height, vSplitIdx, hSplitIdx):
        """Create the four colored rooms and draw each room's walls on ``grid``."""
        self.rooms = [
            Room((0, 0), (vSplitIdx, hSplitIdx), 'red', []),
            Room((vSplitIdx, 0), (width - vSplitIdx, hSplitIdx), 'purple', []),
            Room((0, hSplitIdx), (vSplitIdx, height - hSplitIdx), 'blue', []),
            Room(
                (vSplitIdx, hSplitIdx),
                (width - vSplitIdx, height - hSplitIdx),
                'yellow',
                []
            ),
        ]

        for room in self.rooms:
            x, y = room.top
            w, h = room.size

            # Top and bottom walls.
            for i in range(w):
                grid.set(x + i, y, Wall(room.color))
                grid.set(x + i, y + h - 1, Wall(room.color))

            # Left and right walls.
            for j in range(h):
                grid.set(x, y + j, Wall(room.color))
                grid.set(x + w - 1, y + j, Wall(room.color))

    def _addDoors(self, grid, width, height, vSplitIdx, hSplitIdx):
        """Open one doorway through each wall shared by two adjacent rooms.

        Adjacent rooms each draw their own wall, so every doorway clears two
        cells.
        """
        # Through the vertical split: once in the top half, once in the bottom.
        for lo, hi in ((1, hSplitIdx-1), (hSplitIdx+1, height-1)):
            hIdx = self._randInt(lo, hi)
            grid.set(vSplitIdx, hIdx, None)
            grid.set(vSplitIdx-1, hIdx, None)

        # Through the horizontal split: once in the left half, once in the right.
        for lo, hi in ((1, vSplitIdx-1), (vSplitIdx+1, width-1)):
            vIdx = self._randInt(lo, hi)
            grid.set(vIdx, hSplitIdx, None)
            grid.set(vIdx, hSplitIdx-1, None)

    def _numFreeObjCells(self):
        """Count cells ``_randPos(room, border=2)`` can return, excluding the start cell."""
        sx, sy = self.startPos
        total = 0
        for room in self.rooms:
            x, y = room.top
            w, h = room.size
            # _randInt excludes its upper bound, so each axis offers size - 4 cells.
            total += max(w - 4, 0) * max(h - 4, 0)
            if x + 2 <= sx < x + w - 2 and y + 2 <= sy < y + h - 2:
                total -= 1
        return total

    def _addObjects(self, grid, types):
        """Scatter random objects in the rooms, recording each in its room's list.

        Objects are placed ``border=2`` cells in from a room's edge, never on
        the agent's start cell and never on top of another object.
        """
        colors = list(COLORS.keys())
        makers = {'key': Key, 'ball': Ball, 'box': Box}

        # On small grids (e.g. size=10 gives four 5x5 rooms with one candidate
        # cell each) there can be fewer free cells than requested objects.
        # Without this cap, the rejection loop below would never terminate.
        numObjs = min(self._randInt(1, 10), self._numFreeObjCells())

        for _ in range(numObjs):
            objType = self._randElem(types)
            objColor = self._randElem(colors)
            obj = makers[objType](objColor)

            # Rejection-sample a free cell in a random room.
            while True:
                room = self._randElem(self.rooms)
                pos = self._randPos(room, border=2)
                if pos != self.startPos and grid.get(*pos) is None:
                    break

            grid.set(*pos, obj)
            room.objects.append(obj)

    def _askQuestion(self, types):
        """Pick a random room and object type; set ``self.question`` and ``self.answer``."""
        room = self._randElem(self.rooms)
        objType = self._randElem(types)

        present = any(o.type == objType for o in room.objects)

        self.question = "Are there any %ss in the %s room?" % (objType, room.color)
        self.answer = "yes" if present else "no"

    def _qaObs(self, obs):
        """Pair a base MiniGridEnv observation with the current question."""
        return {
            'image': obs,
            'question': self.question
        }

    def _reset(self):
        """Reset the episode and return a ``{'image', 'question'}`` observation."""
        return self._qaObs(MiniGridEnv._reset(self))

    def _step(self, action):
        """Advance the environment by one step.

        ``action`` is either a plain action, or a dict with an ``'action'``
        key and an optional ``'answer'`` string. A correct answer ends the
        episode with reward ``1000 - self.stepCount``. A missing or wrong
        answer leaves the base environment's reward and ``done`` unchanged.
        Returns ``(obs, reward, done, info)``, where ``obs`` is a
        ``{'image', 'question'}`` dict.
        """
        if isinstance(action, dict):
            # A dict without an answer is treated as "no answer given".
            answer = action.get('answer', '')
            action = action['action']
        else:
            answer = ''

        obs, reward, done, info = MiniGridEnv._step(self, action)

        if answer == self.answer:
            reward = 1000 - self.stepCount
            done = True

        return self._qaObs(obs), reward, done, info
# Make the environment constructible through gym's registry by its versioned id.
register(id='MiniGrid-FourRoomQA-v0', entry_point='gym_minigrid.envs:FourRoomQAEnv')
|