putnear.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from gym_minigrid.minigrid import (
  2. COLOR_NAMES,
  3. Ball,
  4. Box,
  5. Grid,
  6. Key,
  7. MiniGridEnv,
  8. MissionSpace,
  9. )
  10. class PutNearEnv(MiniGridEnv):
  11. """
  12. ### Description
  13. The agent is instructed through a textual string to pick up an object and
  14. place it next to another object. This environment is easy to solve with two
  15. objects, but difficult to solve with more, as it involves both textual
  16. understanding and spatial reasoning involving multiple objects.
  17. ### Mission Space
  18. "put the {move_color} {move_type} near the {target_color} {target_type}"
  19. {move_color} and {target_color} can be "red", "green", "blue", "purple",
  20. "yellow" or "grey".
  21. {move_type} and {target_type} Can be "box", "ball" or "key".
  22. ### Action Space
  23. | Num | Name | Action |
  24. |-----|--------------|-------------------|
  25. | 0 | left | Turn left |
  26. | 1 | right | Turn right |
  27. | 2 | forward | Move forward |
  28. | 3 | pickup | Pick up an object |
  29. | 4 | drop | Drop an object |
  30. | 5 | toggle | Unused |
  31. | 6 | done | Unused |
  32. ### Observation Encoding
  33. - Each tile is encoded as a 3 dimensional tuple:
  34. `(OBJECT_IDX, COLOR_IDX, STATE)`
  35. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  36. [gym_minigrid/minigrid.py](gym_minigrid/minigrid.py)
  37. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  38. ### Rewards
  39. A reward of '1' is given for success, and '0' for failure.
  40. ### Termination
  41. The episode ends if any one of the following conditions is met:
  42. 1. The agent picks up the wrong object.
  43. 2. The agent drop the correct object near the target.
  44. 3. Timeout (see `max_steps`).
  45. ### Registered Configurations
  46. N: number of objects.
  47. - `MiniGrid-PutNear-6x6-N2-v0`
  48. - `MiniGrid-PutNear-8x8-N3-v0`
  49. """
  50. def __init__(self, size=6, numObjs=2, **kwargs):
  51. self.size = size
  52. self.numObjs = numObjs
  53. self.obj_types = ["key", "ball", "box"]
  54. mission_space = MissionSpace(
  55. mission_func=lambda move_color, move_type, target_color, target_type: f"put the {move_color} {move_type} near the {target_color} {target_type}",
  56. ordered_placeholders=[
  57. COLOR_NAMES,
  58. self.obj_types,
  59. COLOR_NAMES,
  60. self.obj_types,
  61. ],
  62. )
  63. super().__init__(
  64. mission_space=mission_space,
  65. width=size,
  66. height=size,
  67. max_steps=5 * size,
  68. # Set this to True for maximum speed
  69. see_through_walls=True,
  70. **kwargs,
  71. )
  72. def _gen_grid(self, width, height):
  73. self.grid = Grid(width, height)
  74. # Generate the surrounding walls
  75. self.grid.horz_wall(0, 0)
  76. self.grid.horz_wall(0, height - 1)
  77. self.grid.vert_wall(0, 0)
  78. self.grid.vert_wall(width - 1, 0)
  79. # Types and colors of objects we can generate
  80. types = ["key", "ball", "box"]
  81. objs = []
  82. objPos = []
  83. def near_obj(env, p1):
  84. for p2 in objPos:
  85. dx = p1[0] - p2[0]
  86. dy = p1[1] - p2[1]
  87. if abs(dx) <= 1 and abs(dy) <= 1:
  88. return True
  89. return False
  90. # Until we have generated all the objects
  91. while len(objs) < self.numObjs:
  92. objType = self._rand_elem(types)
  93. objColor = self._rand_elem(COLOR_NAMES)
  94. # If this object already exists, try again
  95. if (objType, objColor) in objs:
  96. continue
  97. if objType == "key":
  98. obj = Key(objColor)
  99. elif objType == "ball":
  100. obj = Ball(objColor)
  101. elif objType == "box":
  102. obj = Box(objColor)
  103. else:
  104. raise ValueError(
  105. "{} object type given. Object type can only be of values key, ball and box.".format(
  106. objType
  107. )
  108. )
  109. pos = self.place_obj(obj, reject_fn=near_obj)
  110. objs.append((objType, objColor))
  111. objPos.append(pos)
  112. # Randomize the agent start position and orientation
  113. self.place_agent()
  114. # Choose a random object to be moved
  115. objIdx = self._rand_int(0, len(objs))
  116. self.move_type, self.moveColor = objs[objIdx]
  117. self.move_pos = objPos[objIdx]
  118. # Choose a target object (to put the first object next to)
  119. while True:
  120. targetIdx = self._rand_int(0, len(objs))
  121. if targetIdx != objIdx:
  122. break
  123. self.target_type, self.target_color = objs[targetIdx]
  124. self.target_pos = objPos[targetIdx]
  125. self.mission = "put the {} {} near the {} {}".format(
  126. self.moveColor,
  127. self.move_type,
  128. self.target_color,
  129. self.target_type,
  130. )
  131. def step(self, action):
  132. preCarrying = self.carrying
  133. obs, reward, terminated, truncated, info = super().step(action)
  134. u, v = self.dir_vec
  135. ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  136. tx, ty = self.target_pos
  137. # If we picked up the wrong object, terminate the episode
  138. if action == self.actions.pickup and self.carrying:
  139. if (
  140. self.carrying.type != self.move_type
  141. or self.carrying.color != self.moveColor
  142. ):
  143. terminated = True
  144. # If successfully dropping an object near the target
  145. if action == self.actions.drop and preCarrying:
  146. if self.grid.get(ox, oy) is preCarrying:
  147. if abs(ox - tx) <= 1 and abs(oy - ty) <= 1:
  148. reward = self._reward()
  149. terminated = True
  150. return obs, reward, terminated, truncated, info