world_object.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. from typing import TYPE_CHECKING, Optional, Tuple
  2. import numpy as np
  3. from minigrid.core.constants import (
  4. COLOR_TO_IDX,
  5. COLORS,
  6. IDX_TO_COLOR,
  7. IDX_TO_OBJECT,
  8. OBJECT_TO_IDX,
  9. )
  10. from minigrid.utils.rendering import (
  11. fill_coords,
  12. point_in_circle,
  13. point_in_line,
  14. point_in_rect,
  15. )
  16. if TYPE_CHECKING:
  17. from minigrid.minigrid_env import MiniGridEnv
  18. class WorldObj:
  19. """
  20. Base class for grid world objects
  21. """
  22. def __init__(self, type: str, color: str):
  23. assert type in OBJECT_TO_IDX, type
  24. assert color in COLOR_TO_IDX, color
  25. self.type = type
  26. self.color = color
  27. self.contains = None
  28. # Initial position of the object
  29. self.init_pos = None
  30. # Current position of the object
  31. self.cur_pos = None
  32. def can_overlap(self) -> bool:
  33. """Can the agent overlap with this?"""
  34. return False
  35. def can_pickup(self) -> bool:
  36. """Can the agent pick this up?"""
  37. return False
  38. def can_contain(self) -> bool:
  39. """Can this contain another object?"""
  40. return False
  41. def see_behind(self) -> bool:
  42. """Can the agent see behind this object?"""
  43. return True
  44. def toggle(self, env: "MiniGridEnv", pos: Tuple[int, int]) -> bool:
  45. """Method to trigger/toggle an action this object performs"""
  46. return False
  47. def encode(self) -> Tuple[int, int, int]:
  48. """Encode the a description of this object as a 3-tuple of integers"""
  49. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  50. @staticmethod
  51. def decode(type_idx: int, color_idx: int, state: int) -> Optional["WorldObj"]:
  52. """Create an object from a 3-tuple state description"""
  53. obj_type = IDX_TO_OBJECT[type_idx]
  54. color = IDX_TO_COLOR[color_idx]
  55. if obj_type == "empty" or obj_type == "unseen":
  56. return None
  57. # State, 0: open, 1: closed, 2: locked
  58. is_open = state == 0
  59. is_locked = state == 2
  60. if obj_type == "wall":
  61. v = Wall(color)
  62. elif obj_type == "floor":
  63. v = Floor(color)
  64. elif obj_type == "ball":
  65. v = Ball(color)
  66. elif obj_type == "key":
  67. v = Key(color)
  68. elif obj_type == "box":
  69. v = Box(color)
  70. elif obj_type == "door":
  71. v = Door(color, is_open, is_locked)
  72. elif obj_type == "goal":
  73. v = Goal()
  74. elif obj_type == "lava":
  75. v = Lava()
  76. else:
  77. assert False, "unknown object type in decode '%s'" % obj_type
  78. return v
  79. def render(self, r: np.ndarray) -> np.ndarray:
  80. """Draw this object with the given renderer"""
  81. raise NotImplementedError
  82. class Goal(WorldObj):
  83. def __init__(self):
  84. super().__init__("goal", "green")
  85. def can_overlap(self):
  86. return True
  87. def render(self, img):
  88. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  89. class Floor(WorldObj):
  90. """
  91. Colored floor tile the agent can walk over
  92. """
  93. def __init__(self, color: str = "blue"):
  94. super().__init__("floor", color)
  95. def can_overlap(self):
  96. return True
  97. def render(self, img):
  98. # Give the floor a pale color
  99. color = COLORS[self.color] / 2
  100. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  101. class Lava(WorldObj):
  102. def __init__(self):
  103. super().__init__("lava", "red")
  104. def can_overlap(self):
  105. return True
  106. def render(self, img):
  107. c = (255, 128, 0)
  108. # Background color
  109. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  110. # Little waves
  111. for i in range(3):
  112. ylo = 0.3 + 0.2 * i
  113. yhi = 0.4 + 0.2 * i
  114. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  115. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  116. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  117. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  118. class Wall(WorldObj):
  119. def __init__(self, color: str = "grey"):
  120. super().__init__("wall", color)
  121. def see_behind(self):
  122. return False
  123. def render(self, img):
  124. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  125. class Door(WorldObj):
  126. def __init__(self, color: str, is_open: bool = False, is_locked: bool = False):
  127. super().__init__("door", color)
  128. self.is_open = is_open
  129. self.is_locked = is_locked
  130. def can_overlap(self):
  131. """The agent can only walk over this cell when the door is open"""
  132. return self.is_open
  133. def see_behind(self):
  134. return self.is_open
  135. def toggle(self, env, pos):
  136. # If the player has the right key to open the door
  137. if self.is_locked:
  138. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  139. self.is_locked = False
  140. self.is_open = True
  141. return True
  142. return False
  143. self.is_open = not self.is_open
  144. return True
  145. def encode(self):
  146. """Encode the a description of this object as a 3-tuple of integers"""
  147. # State, 0: open, 1: closed, 2: locked
  148. if self.is_open:
  149. state = 0
  150. elif self.is_locked:
  151. state = 2
  152. # if door is closed and unlocked
  153. elif not self.is_open:
  154. state = 1
  155. else:
  156. raise ValueError(
  157. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  158. )
  159. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  160. def render(self, img):
  161. c = COLORS[self.color]
  162. if self.is_open:
  163. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  164. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  165. return
  166. # Door frame and door
  167. if self.is_locked:
  168. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  169. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  170. # Draw key slot
  171. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  172. else:
  173. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  174. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  175. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  176. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  177. # Draw door handle
  178. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  179. class Key(WorldObj):
  180. def __init__(self, color: str = "blue"):
  181. super().__init__("key", color)
  182. def can_pickup(self):
  183. return True
  184. def render(self, img):
  185. c = COLORS[self.color]
  186. # Vertical quad
  187. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  188. # Teeth
  189. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  190. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  191. # Ring
  192. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  193. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  194. class Ball(WorldObj):
  195. def __init__(self, color="blue"):
  196. super().__init__("ball", color)
  197. def can_pickup(self):
  198. return True
  199. def render(self, img):
  200. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  201. class Box(WorldObj):
  202. def __init__(self, color, contains: Optional[WorldObj] = None):
  203. super().__init__("box", color)
  204. self.contains = contains
  205. def can_pickup(self):
  206. return True
  207. def render(self, img):
  208. c = COLORS[self.color]
  209. # Outline
  210. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  211. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  212. # Horizontal slit
  213. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  214. def toggle(self, env, pos):
  215. # Replace the box by its contents
  216. env.grid.set(pos[0], pos[1], self.contains)
  217. return True