world_object.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Tuple
  3. import numpy as np
  4. from minigrid.core.constants import (
  5. COLOR_TO_IDX,
  6. COLORS,
  7. IDX_TO_COLOR,
  8. IDX_TO_OBJECT,
  9. OBJECT_TO_IDX,
  10. )
  11. from minigrid.utils.rendering import (
  12. fill_coords,
  13. point_in_circle,
  14. point_in_line,
  15. point_in_rect,
  16. )
  17. if TYPE_CHECKING:
  18. from minigrid.minigrid_env import MiniGridEnv
  19. Point = Tuple[int, int]
  20. class WorldObj:
  21. """
  22. Base class for grid world objects
  23. """
  24. def __init__(self, type: str, color: str):
  25. assert type in OBJECT_TO_IDX, type
  26. assert color in COLOR_TO_IDX, color
  27. self.type = type
  28. self.color = color
  29. self.contains = None
  30. # Initial position of the object
  31. self.init_pos: Point | None = None
  32. # Current position of the object
  33. self.cur_pos: Point | None = None
  34. def can_overlap(self) -> bool:
  35. """Can the agent overlap with this?"""
  36. return False
  37. def can_pickup(self) -> bool:
  38. """Can the agent pick this up?"""
  39. return False
  40. def can_contain(self) -> bool:
  41. """Can this contain another object?"""
  42. return False
  43. def see_behind(self) -> bool:
  44. """Can the agent see behind this object?"""
  45. return True
  46. def toggle(self, env: MiniGridEnv, pos: tuple[int, int]) -> bool:
  47. """Method to trigger/toggle an action this object performs"""
  48. return False
  49. def encode(self) -> tuple[int, int, int]:
  50. """Encode the a description of this object as a 3-tuple of integers"""
  51. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  52. @staticmethod
  53. def decode(type_idx: int, color_idx: int, state: int) -> WorldObj | None:
  54. """Create an object from a 3-tuple state description"""
  55. obj_type = IDX_TO_OBJECT[type_idx]
  56. color = IDX_TO_COLOR[color_idx]
  57. if obj_type == "empty" or obj_type == "unseen":
  58. return None
  59. # State, 0: open, 1: closed, 2: locked
  60. is_open = state == 0
  61. is_locked = state == 2
  62. if obj_type == "wall":
  63. v = Wall(color)
  64. elif obj_type == "floor":
  65. v = Floor(color)
  66. elif obj_type == "ball":
  67. v = Ball(color)
  68. elif obj_type == "key":
  69. v = Key(color)
  70. elif obj_type == "box":
  71. v = Box(color)
  72. elif obj_type == "door":
  73. v = Door(color, is_open, is_locked)
  74. elif obj_type == "goal":
  75. v = Goal()
  76. elif obj_type == "lava":
  77. v = Lava()
  78. else:
  79. assert False, "unknown object type in decode '%s'" % obj_type
  80. return v
  81. def render(self, r: np.ndarray) -> np.ndarray:
  82. """Draw this object with the given renderer"""
  83. raise NotImplementedError
  84. class Goal(WorldObj):
  85. def __init__(self):
  86. super().__init__("goal", "green")
  87. def can_overlap(self):
  88. return True
  89. def render(self, img):
  90. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  91. class Floor(WorldObj):
  92. """
  93. Colored floor tile the agent can walk over
  94. """
  95. def __init__(self, color: str = "blue"):
  96. super().__init__("floor", color)
  97. def can_overlap(self):
  98. return True
  99. def render(self, img):
  100. # Give the floor a pale color
  101. color = COLORS[self.color] / 2
  102. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  103. class Lava(WorldObj):
  104. def __init__(self):
  105. super().__init__("lava", "red")
  106. def can_overlap(self):
  107. return True
  108. def render(self, img):
  109. c = (255, 128, 0)
  110. # Background color
  111. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  112. # Little waves
  113. for i in range(3):
  114. ylo = 0.3 + 0.2 * i
  115. yhi = 0.4 + 0.2 * i
  116. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  117. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  118. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  119. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  120. class Wall(WorldObj):
  121. def __init__(self, color: str = "grey"):
  122. super().__init__("wall", color)
  123. def see_behind(self):
  124. return False
  125. def render(self, img):
  126. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  127. class Door(WorldObj):
  128. def __init__(self, color: str, is_open: bool = False, is_locked: bool = False):
  129. super().__init__("door", color)
  130. self.is_open = is_open
  131. self.is_locked = is_locked
  132. def can_overlap(self):
  133. """The agent can only walk over this cell when the door is open"""
  134. return self.is_open
  135. def see_behind(self):
  136. return self.is_open
  137. def toggle(self, env, pos):
  138. # If the player has the right key to open the door
  139. if self.is_locked:
  140. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  141. self.is_locked = False
  142. self.is_open = True
  143. return True
  144. return False
  145. self.is_open = not self.is_open
  146. return True
  147. def encode(self):
  148. """Encode the a description of this object as a 3-tuple of integers"""
  149. # State, 0: open, 1: closed, 2: locked
  150. if self.is_open:
  151. state = 0
  152. elif self.is_locked:
  153. state = 2
  154. # if door is closed and unlocked
  155. elif not self.is_open:
  156. state = 1
  157. else:
  158. raise ValueError(
  159. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  160. )
  161. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  162. def render(self, img):
  163. c = COLORS[self.color]
  164. if self.is_open:
  165. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  166. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  167. return
  168. # Door frame and door
  169. if self.is_locked:
  170. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  171. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  172. # Draw key slot
  173. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  174. else:
  175. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  176. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  177. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  178. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  179. # Draw door handle
  180. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  181. class Key(WorldObj):
  182. def __init__(self, color: str = "blue"):
  183. super().__init__("key", color)
  184. def can_pickup(self):
  185. return True
  186. def render(self, img):
  187. c = COLORS[self.color]
  188. # Vertical quad
  189. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  190. # Teeth
  191. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  192. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  193. # Ring
  194. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  195. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  196. class Ball(WorldObj):
  197. def __init__(self, color="blue"):
  198. super().__init__("ball", color)
  199. def can_pickup(self):
  200. return True
  201. def render(self, img):
  202. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  203. class Box(WorldObj):
  204. def __init__(self, color, contains: WorldObj | None = None):
  205. super().__init__("box", color)
  206. self.contains = contains
  207. def can_pickup(self):
  208. return True
  209. def render(self, img):
  210. c = COLORS[self.color]
  211. # Outline
  212. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  213. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  214. # Horizontal slit
  215. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  216. def toggle(self, env, pos):
  217. # Replace the box by its contents
  218. env.grid.set(pos[0], pos[1], self.contains)
  219. return True