putnear.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from typing import Optional
  2. from minigrid.core.constants import COLOR_NAMES
  3. from minigrid.core.grid import Grid
  4. from minigrid.core.mission import MissionSpace
  5. from minigrid.core.world_object import Ball, Box, Key
  6. from minigrid.minigrid_env import MiniGridEnv
  7. class PutNearEnv(MiniGridEnv):
  8. """
  9. ### Description
  10. The agent is instructed through a textual string to pick up an object and
  11. place it next to another object. This environment is easy to solve with two
  12. objects, but difficult to solve with more, as it involves both textual
  13. understanding and spatial reasoning involving multiple objects.
  14. ### Mission Space
  15. "put the {move_color} {move_type} near the {target_color} {target_type}"
  16. {move_color} and {target_color} can be "red", "green", "blue", "purple",
  17. "yellow" or "grey".
  18. {move_type} and {target_type} Can be "box", "ball" or "key".
  19. ### Action Space
  20. | Num | Name | Action |
  21. |-----|--------------|-------------------|
  22. | 0 | left | Turn left |
  23. | 1 | right | Turn right |
  24. | 2 | forward | Move forward |
  25. | 3 | pickup | Pick up an object |
  26. | 4 | drop | Drop an object |
  27. | 5 | toggle | Unused |
  28. | 6 | done | Unused |
  29. ### Observation Encoding
  30. - Each tile is encoded as a 3 dimensional tuple:
  31. `(OBJECT_IDX, COLOR_IDX, STATE)`
  32. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  33. [minigrid/minigrid.py](minigrid/minigrid.py)
  34. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  35. ### Rewards
  36. A reward of '1' is given for success, and '0' for failure.
  37. ### Termination
  38. The episode ends if any one of the following conditions is met:
  39. 1. The agent picks up the wrong object.
  40. 2. The agent drop the correct object near the target.
  41. 3. Timeout (see `max_steps`).
  42. ### Registered Configurations
  43. N: number of objects.
  44. - `MiniGrid-PutNear-6x6-N2-v0`
  45. - `MiniGrid-PutNear-8x8-N3-v0`
  46. """
  47. def __init__(self, size=6, numObjs=2, max_steps: Optional[int] = None, **kwargs):
  48. self.size = size
  49. self.numObjs = numObjs
  50. self.obj_types = ["key", "ball", "box"]
  51. mission_space = MissionSpace(
  52. mission_func=lambda move_color, move_type, target_color, target_type: f"put the {move_color} {move_type} near the {target_color} {target_type}",
  53. ordered_placeholders=[
  54. COLOR_NAMES,
  55. self.obj_types,
  56. COLOR_NAMES,
  57. self.obj_types,
  58. ],
  59. )
  60. if max_steps is None:
  61. max_steps = 5 * size
  62. super().__init__(
  63. mission_space=mission_space,
  64. width=size,
  65. height=size,
  66. # Set this to True for maximum speed
  67. see_through_walls=True,
  68. max_steps=max_steps,
  69. **kwargs,
  70. )
  71. def _gen_grid(self, width, height):
  72. self.grid = Grid(width, height)
  73. # Generate the surrounding walls
  74. self.grid.horz_wall(0, 0)
  75. self.grid.horz_wall(0, height - 1)
  76. self.grid.vert_wall(0, 0)
  77. self.grid.vert_wall(width - 1, 0)
  78. # Types and colors of objects we can generate
  79. types = ["key", "ball", "box"]
  80. objs = []
  81. objPos = []
  82. def near_obj(env, p1):
  83. for p2 in objPos:
  84. dx = p1[0] - p2[0]
  85. dy = p1[1] - p2[1]
  86. if abs(dx) <= 1 and abs(dy) <= 1:
  87. return True
  88. return False
  89. # Until we have generated all the objects
  90. while len(objs) < self.numObjs:
  91. objType = self._rand_elem(types)
  92. objColor = self._rand_elem(COLOR_NAMES)
  93. # If this object already exists, try again
  94. if (objType, objColor) in objs:
  95. continue
  96. if objType == "key":
  97. obj = Key(objColor)
  98. elif objType == "ball":
  99. obj = Ball(objColor)
  100. elif objType == "box":
  101. obj = Box(objColor)
  102. else:
  103. raise ValueError(
  104. "{} object type given. Object type can only be of values key, ball and box.".format(
  105. objType
  106. )
  107. )
  108. pos = self.place_obj(obj, reject_fn=near_obj)
  109. objs.append((objType, objColor))
  110. objPos.append(pos)
  111. # Randomize the agent start position and orientation
  112. self.place_agent()
  113. # Choose a random object to be moved
  114. objIdx = self._rand_int(0, len(objs))
  115. self.move_type, self.moveColor = objs[objIdx]
  116. self.move_pos = objPos[objIdx]
  117. # Choose a target object (to put the first object next to)
  118. while True:
  119. targetIdx = self._rand_int(0, len(objs))
  120. if targetIdx != objIdx:
  121. break
  122. self.target_type, self.target_color = objs[targetIdx]
  123. self.target_pos = objPos[targetIdx]
  124. self.mission = "put the {} {} near the {} {}".format(
  125. self.moveColor,
  126. self.move_type,
  127. self.target_color,
  128. self.target_type,
  129. )
  130. def step(self, action):
  131. preCarrying = self.carrying
  132. obs, reward, terminated, truncated, info = super().step(action)
  133. u, v = self.dir_vec
  134. ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
  135. tx, ty = self.target_pos
  136. # If we picked up the wrong object, terminate the episode
  137. if action == self.actions.pickup and self.carrying:
  138. if (
  139. self.carrying.type != self.move_type
  140. or self.carrying.color != self.moveColor
  141. ):
  142. terminated = True
  143. # If successfully dropping an object near the target
  144. if action == self.actions.drop and preCarrying:
  145. if self.grid.get(ox, oy) is preCarrying:
  146. if abs(ox - tx) <= 1 and abs(oy - ty) <= 1:
  147. reward = self._reward()
  148. terminated = True
  149. return obs, reward, terminated, truncated, info