fetch.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from gym_minigrid.minigrid import COLOR_NAMES, Ball, Grid, Key, MiniGridEnv
  2. class FetchEnv(MiniGridEnv):
  3. """
  4. Environment in which the agent has to fetch a random object
  5. named using English text strings
  6. """
  7. def __init__(self, size=8, numObjs=3, **kwargs):
  8. self.numObjs = numObjs
  9. super().__init__(
  10. grid_size=size,
  11. max_steps=5 * size**2,
  12. # Set this to True for maximum speed
  13. see_through_walls=True,
  14. **kwargs,
  15. )
  16. def _gen_grid(self, width, height):
  17. self.grid = Grid(width, height)
  18. # Generate the surrounding walls
  19. self.grid.horz_wall(0, 0)
  20. self.grid.horz_wall(0, height - 1)
  21. self.grid.vert_wall(0, 0)
  22. self.grid.vert_wall(width - 1, 0)
  23. types = ["key", "ball"]
  24. objs = []
  25. # For each object to be generated
  26. while len(objs) < self.numObjs:
  27. objType = self._rand_elem(types)
  28. objColor = self._rand_elem(COLOR_NAMES)
  29. if objType == "key":
  30. obj = Key(objColor)
  31. elif objType == "ball":
  32. obj = Ball(objColor)
  33. else:
  34. raise ValueError(
  35. "{} object type given. Object type can only be of values key and ball.".format(
  36. objType
  37. )
  38. )
  39. self.place_obj(obj)
  40. objs.append(obj)
  41. # Randomize the player start position and orientation
  42. self.place_agent()
  43. # Choose a random object to be picked up
  44. target = objs[self._rand_int(0, len(objs))]
  45. self.targetType = target.type
  46. self.targetColor = target.color
  47. descStr = f"{self.targetColor} {self.targetType}"
  48. # Generate the mission string
  49. idx = self._rand_int(0, 5)
  50. if idx == 0:
  51. self.mission = "get a %s" % descStr
  52. elif idx == 1:
  53. self.mission = "go get a %s" % descStr
  54. elif idx == 2:
  55. self.mission = "fetch a %s" % descStr
  56. elif idx == 3:
  57. self.mission = "go fetch a %s" % descStr
  58. elif idx == 4:
  59. self.mission = "you must fetch a %s" % descStr
  60. assert hasattr(self, "mission")
  61. def step(self, action):
  62. obs, reward, done, info = MiniGridEnv.step(self, action)
  63. if self.carrying:
  64. if (
  65. self.carrying.color == self.targetColor
  66. and self.carrying.type == self.targetType
  67. ):
  68. reward = self._reward()
  69. done = True
  70. else:
  71. reward = 0
  72. done = True
  73. return obs, reward, done, info