|
@@ -44,8 +44,6 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
# Actions are discrete integer values
|
|
|
self.action_space = spaces.Discrete(len(self.actions))
|
|
|
|
|
|
- # TODO: dictionary observation_space, to include question?
|
|
|
-
|
|
|
self.reward_range = (-1000, 1000)
|
|
|
|
|
|
def _randPos(self, room, border=1):
|
|
@@ -179,26 +177,16 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
|
|
|
# TODO: identify unique objects
|
|
|
|
|
|
- self.question = "Are there any %ss in the %s room?" % (objType, room.color)
|
|
|
+ self.mission = "Are there any %ss in the %s room?" % (objType, room.color)
|
|
|
self.answer = "yes" if count > 0 else "no"
|
|
|
|
|
|
# TODO: how many X in the Y room question type
|
|
|
|
|
|
- #print(self.question)
|
|
|
+ #print(self.mission)
|
|
|
#print(self.answer)
|
|
|
|
|
|
return grid
|
|
|
|
|
|
- def reset(self):
|
|
|
- obs = MiniGridEnv.reset(self)
|
|
|
-
|
|
|
- obs = {
|
|
|
- 'image': obs,
|
|
|
- 'question': self.question
|
|
|
- }
|
|
|
-
|
|
|
- return obs
|
|
|
-
|
|
|
def step(self, action):
|
|
|
if isinstance(action, dict):
|
|
|
answer = action['answer']
|
|
@@ -211,7 +199,7 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
obs, reward, done, info = MiniGridEnv.step(self, self.actions.wait)
|
|
|
done = True
|
|
|
|
|
|
if answer == self.answer:
|
|
|
reward = 1000 - self.stepCount
|
|
|
else:
|
|
|
reward = -1000
|
|
@@ -220,11 +208,6 @@ class FourRoomQAEnv(MiniGridEnv):
|
|
|
# Let the superclass handle the action
|
|
|
obs, reward, done, info = MiniGridEnv.step(self, action)
|
|
|
|
|
|
- obs = {
|
|
|
- 'image': obs,
|
|
|
- 'question': self.question
|
|
|
- }
|
|
|
-
|
|
|
return obs, reward, done, info
|
|
|
|
|
|
register(
|