123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- from gym_minigrid.minigrid import *
- from gym_minigrid.register import register
class Room:
    """Rectangular room of the grid together with the objects placed in it.

    Attributes are stored exactly as passed in:
        top:     (x, y) coordinates of the room's top-left corner
        size:    (width, height) of the room, walls included
        color:   color name used for the room's walls
        objects: list of the objects contained in the room
    """

    def __init__(self, top, size, color, objects):
        # Geometry of the room
        self.top, self.size = top, size

        # Wall color and the (caller-owned, mutable) list of contents
        self.color, self.objects = color, objects
class FourRoomQAEnv(MiniGridEnv):
    """
    Environment to experiment with embodied question answering
    https://arxiv.org/abs/1711.11543

    The grid is split into four colored rooms connected by openings in the
    walls. Random objects (keys, balls, boxes) are scattered in the rooms and
    a yes/no question about the contents of one room is generated. The
    episode ends with a positive reward once the agent gives the right answer.
    """

    # Enumeration of possible actions
    # NOTE(review): __init__ installs MiniGridEnv.Actions rather than this
    # enumeration, so `say` is not part of the action space; answers are
    # passed through the dict form of the action instead (see _step).
    # Confirm which enumeration is intended before switching.
    class Actions(IntEnum):
        left = 0
        right = 1
        forward = 2
        toggle = 3
        say = 4

    def __init__(self, size=16):
        """Create the environment on a square grid of `size` x `size` cells.

        The size must be at least 10 so that each of the four rooms is large
        enough to hold objects away from its walls.
        """
        assert size >= 10
        super().__init__(gridSize=size, maxSteps=8*size)

        # Action enumeration for this environment
        self.actions = MiniGridEnv.Actions

        # TODO: dictionary action_space, to include answer sentence?
        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(self.actions))

        # TODO: dictionary observation_space, to include question?

    def _randPos(self, room, border=1):
        """Pick a random (x, y) position inside `room`.

        `border` is the number of cells kept free along each edge of the
        room (1 excludes only the walls themselves).
        Assumes self._randInt(low, high) excludes `high` -- TODO confirm
        against the base class.
        """
        left, top = room.top
        width, height = room.size

        # Draw x before y to keep the random sequence stable
        x = self._randInt(left + border, left + width - border)
        y = self._randInt(top + border, top + height - border)
        return (x, y)

    def _genGrid(self, width, height):
        """Generate a new random grid, and with it a new question/answer pair.

        Sets self.rooms, self.startDir, self.startPos, self.question and
        self.answer, and returns the generated Grid.
        """
        grid = Grid(width, height)

        # Horizontal and vertical split indices
        vSplitIdx = self._randInt(5, width-4)
        hSplitIdx = self._randInt(5, height-4)

        # Create the four rooms: (top-left corner, size, wall color)
        layout = [
            ((0, 0), (vSplitIdx, hSplitIdx), 'red'),
            ((vSplitIdx, 0), (width - vSplitIdx, hSplitIdx), 'purple'),
            ((0, hSplitIdx), (vSplitIdx, height - hSplitIdx), 'blue'),
            (
                (vSplitIdx, hSplitIdx),
                (width - vSplitIdx, height - hSplitIdx),
                'yellow'
            ),
        ]
        self.rooms = [Room(top, size, color, []) for top, size, color in layout]

        # Place the room walls
        for room in self.rooms:
            x, y = room.top
            w, h = room.size

            # Horizontal walls
            for i in range(w):
                grid.set(x + i, y, Wall(room.color))
                grid.set(x + i, y + h - 1, Wall(room.color))

            # Vertical walls
            for j in range(h):
                grid.set(x, y + j, Wall(room.color))
                grid.set(x + w - 1, y + j, Wall(room.color))

        # Place wall openings connecting the rooms. Adjacent rooms each have
        # their own wall, so every opening clears two cells.
        hIdx = self._randInt(1, hSplitIdx-1)
        grid.set(vSplitIdx, hIdx, None)
        grid.set(vSplitIdx-1, hIdx, None)
        hIdx = self._randInt(hSplitIdx+1, height-1)
        grid.set(vSplitIdx, hIdx, None)
        grid.set(vSplitIdx-1, hIdx, None)

        vIdx = self._randInt(1, vSplitIdx-1)
        grid.set(vIdx, hSplitIdx, None)
        grid.set(vIdx, hSplitIdx-1, None)
        vIdx = self._randInt(vSplitIdx+1, width-1)
        grid.set(vIdx, hSplitIdx, None)
        grid.set(vIdx, hSplitIdx-1, None)

        # Select a random position for the agent to start at
        self.startDir = self._randInt(0, 4)
        room = self._randElem(self.rooms)
        self.startPos = self._randPos(room)

        # Possible object types and colors
        makers = {'key': Key, 'ball': Ball, 'box': Box}
        types = ['key', 'ball', 'box']
        colors = list(COLORS.keys())

        # Count the cells objects are allowed to occupy. On small grids
        # there can be fewer such cells than requested objects, in which
        # case the rejection sampling below would never terminate.
        objBorder = 2
        freeCells = 0
        for room in self.rooms:
            left, top = room.top
            w, h = room.size
            for x in range(left + objBorder, left + w - objBorder):
                for y in range(top + objBorder, top + h - objBorder):
                    if (x, y) != self.startPos and grid.get(x, y) is None:
                        freeCells += 1

        # Place a number of random objects, never more than fit
        numObjs = min(self._randInt(1, 10), freeCells)
        for _ in range(numObjs):
            # Generate a random object
            objType = self._randElem(types)
            objColor = self._randElem(colors)
            obj = makers[objType](objColor)

            # Pick a random position that doesn't overlap with anything
            while True:
                room = self._randElem(self.rooms)
                pos = self._randPos(room, border=objBorder)

                if pos == self.startPos:
                    continue
                if grid.get(*pos) is not None:
                    continue

                grid.set(*pos, obj)
                break

            room.objects.append(obj)

        # Question examples:
        # - What color is the X?
        # - What color is the X in the ROOM?
        # - What room is the X located in?
        # - What color is the X in the blue room?
        # - How many rooms contain chairs?
        # - How many keys are there in the yellow room?
        # - How many <OBJs> in the <ROOM>?

        # Pick a random room to be the subject of the question
        room = self._randElem(self.rooms)

        # Pick a random object type
        objType = self._randElem(types)

        # Count the number of objects of this type in the room
        count = sum(1 for o in room.objects if o.type == objType)

        # TODO: identify unique objects

        self.question = "Are there any %ss in the %s room?" % (objType, room.color)
        self.answer = "yes" if count > 0 else "no"

        # TODO: how many X in the Y room question type

        # Debug output of the generated question and its expected answer
        print(self.question)
        print(self.answer)

        return grid

    def _reset(self):
        """Reset the environment and return the first observation.

        The observation is a dict holding the base environment's
        observation under 'image' and the question text under 'question'.
        """
        image = MiniGridEnv._reset(self)

        return {
            'image': image,
            'question': self.question
        }

    def _step(self, action):
        """Perform one step, optionally submitting an answer.

        `action` is either a plain action value, or a dict with the keys
        'action' (the action value) and 'answer' (the answer string).
        Giving the correct answer ends the episode with a reward of
        1000 minus the current step count.
        """
        if isinstance(action, dict):
            answer = action['answer']
            action = action['action']
        else:
            # A bare action never carries an answer
            answer = ''

        image, reward, done, info = MiniGridEnv._step(self, action)

        if answer == self.answer:
            reward = 1000 - self.stepCount
            done = True

        obs = {
            'image': image,
            'question': self.question
        }

        return obs, reward, done, info
# Make the environment available by name through the registry
register(id='MiniGrid-FourRoomQA-v0', entry_point='gym_minigrid.envs:FourRoomQAEnv')
|