world_object.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import numpy as np
  2. from minigrid.core.constants import (
  3. COLOR_TO_IDX,
  4. COLORS,
  5. IDX_TO_COLOR,
  6. IDX_TO_OBJECT,
  7. OBJECT_TO_IDX,
  8. )
  9. from minigrid.utils.rendering import (
  10. fill_coords,
  11. point_in_circle,
  12. point_in_line,
  13. point_in_rect,
  14. )
  class WorldObj:
      """Base class for grid world objects.

      Subclasses override the capability predicates (``can_overlap``,
      ``can_pickup``, ``can_contain``, ``see_behind``), ``toggle`` and
      ``render``. The defaults defined here describe an object the agent
      cannot walk over, cannot pick up, that contains nothing and that does
      not block the agent's view.
      """

      def __init__(self, type, color):
          # NOTE(review): these asserts are stripped under ``python -O``, so
          # unknown type/colour names are only rejected in non-optimised runs.
          assert type in OBJECT_TO_IDX, type
          assert color in COLOR_TO_IDX, color
          self.type = type
          self.color = color
          # Object held inside this one; Box assigns it, everything else
          # leaves it as None.
          self.contains = None
          # Initial position of the object
          self.init_pos = None
          # Current position of the object
          self.cur_pos = None

      def can_overlap(self):
          """Can the agent overlap with this?"""
          return False

      def can_pickup(self):
          """Can the agent pick this up?"""
          return False

      def can_contain(self):
          """Can this contain another object?"""
          return False

      def see_behind(self):
          """Can the agent see behind this object?"""
          return True

      def toggle(self, env, pos):
          """Trigger this object's toggle action.

          Returns True if the toggle had an effect; the base implementation
          does nothing and returns False.
          """
          return False

      def encode(self):
          """Encode a description of this object as a 3-tuple of integers.

          The tuple is (type index, colour index, state); the base class
          always reports state 0.
          """
          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)

      @staticmethod
      def decode(type_idx, color_idx, state):
          """Create an object from a 3-tuple state description.

          Args:
              type_idx: index looked up in ``IDX_TO_OBJECT``.
              color_idx: index looked up in ``IDX_TO_COLOR``.
              state: door state code (0: open, 1: closed, 2: locked); only
                  consulted when the decoded type is ``"door"``.

          Returns:
              A new object instance, or None for the ``"empty"`` and
              ``"unseen"`` cell types.

          Raises:
              ValueError: if the type index maps to an object type this
                  method does not know how to construct. (An explicit raise
                  instead of ``assert`` so the check survives ``python -O``.)
          """
          obj_type = IDX_TO_OBJECT[type_idx]
          color = IDX_TO_COLOR[color_idx]

          if obj_type == "empty" or obj_type == "unseen":
              return None

          if obj_type == "wall":
              return Wall(color)
          if obj_type == "floor":
              return Floor(color)
          if obj_type == "ball":
              return Ball(color)
          if obj_type == "key":
              return Key(color)
          if obj_type == "box":
              # encode() carries no contents, so the box is rebuilt empty.
              return Box(color)
          if obj_type == "door":
              # State, 0: open, 1: closed, 2: locked
              return Door(color, is_open=state == 0, is_locked=state == 2)
          if obj_type == "goal":
              # Goal and Lava construct themselves with a fixed colour, so
              # the decoded colour is not passed on.
              return Goal()
          if obj_type == "lava":
              return Lava()

          raise ValueError(f"unknown object type in decode '{obj_type}'")

      def render(self, r):
          """Draw this object with the given renderer"""
          raise NotImplementedError
  class Goal(WorldObj):
      """Green target tile that the agent is allowed to step onto."""

      def __init__(self):
          # The goal colour is fixed; callers cannot choose it.
          super().__init__("goal", "green")

      def can_overlap(self):
          return True

      def render(self, img):
          # Paint the entire tile in the goal colour.
          whole_tile = point_in_rect(0, 1, 0, 1)
          fill_coords(img, whole_tile, COLORS[self.color])
  class Floor(WorldObj):
      """Coloured floor tile that the agent can walk over."""

      def __init__(self, color="blue"):
          super().__init__("floor", color)

      def can_overlap(self):
          return True

      def render(self, img):
          # Use the base colour at half intensity so the floor looks muted.
          muted = COLORS[self.color] / 2
          # Starting at 0.031 leaves a thin unpainted margin on two sides of
          # the tile (presumably to keep grid lines visible — confirm).
          region = point_in_rect(0.031, 1, 0.031, 1)
          fill_coords(img, region, muted)
  class Lava(WorldObj):
      """Lava tile the agent can step onto; drawn orange with black waves."""

      def __init__(self):
          super().__init__("lava", "red")

      def can_overlap(self):
          return True

      def render(self, img):
          orange = (255, 128, 0)
          black = (0, 0, 0)

          # Background colour
          fill_coords(img, point_in_rect(0, 1, 0, 1), orange)

          # Three rows of zig-zag "waves". Each row is a polyline through
          # five vertices, drawn as four consecutive line segments.
          wave_xs = (0.1, 0.3, 0.5, 0.7, 0.9)
          for row in range(3):
              low = 0.3 + 0.2 * row
              high = 0.4 + 0.2 * row
              wave_ys = (low, high, low, high, low)
              for x0, y0, x1, y1 in zip(wave_xs, wave_ys, wave_xs[1:], wave_ys[1:]):
                  fill_coords(img, point_in_line(x0, y0, x1, y1, r=0.03), black)
  class Wall(WorldObj):
      """Opaque wall tile.

      It keeps the inherited ``can_overlap() == False``, so the agent cannot
      walk through it, and it also blocks the agent's view.
      """

      def __init__(self, color="grey"):
          super().__init__("wall", color)

      def see_behind(self):
          # Nothing behind a wall is visible.
          return False

      def render(self, img):
          whole_tile = point_in_rect(0, 1, 0, 1)
          fill_coords(img, whole_tile, COLORS[self.color])
  class Door(WorldObj):
      """Door that is open, closed, or locked.

      Toggling a locked door unlocks and opens it only when the agent carries
      a ``Key`` of the same colour. Toggling an unlocked door flips it
      between open and closed.
      """

      def __init__(self, color, is_open=False, is_locked=False):
          super().__init__("door", color)
          self.is_open = is_open
          self.is_locked = is_locked

      def can_overlap(self):
          """The agent can only walk over this cell when the door is open"""
          return self.is_open

      def see_behind(self):
          """The agent can only see past this cell when the door is open."""
          return self.is_open

      def toggle(self, env, pos):
          """Try to open, close, or unlock the door.

          Returns True if the door's state changed. Returns False when the
          door is locked and ``env.carrying`` is not a ``Key`` of the door's
          colour.
          """
          if self.is_locked:
              carried = env.carrying
              if isinstance(carried, Key) and carried.color == self.color:
                  # A matching key unlocks and opens the door in one action.
                  self.is_locked = False
                  self.is_open = True
                  return True
              return False

          self.is_open = not self.is_open
          return True

      def encode(self):
          """Encode a description of this door as a 3-tuple of integers.

          The third element is the door state: 0 = open, 1 = closed,
          2 = locked. ``is_open`` takes precedence, so a door flagged as both
          open and locked encodes as open.
          """
          if self.is_open:
              state = 0
          elif self.is_locked:
              state = 2
          else:
              # Closed and unlocked. The three branches above cover every
              # combination of the two flags, so no error case is needed.
              state = 1
          return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)

      def render(self, img):
          c = COLORS[self.color]

          if self.is_open:
              # Open door: a thin outlined strip along one edge of the tile.
              fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
              return

          # Door frame and door
          if self.is_locked:
              # Solid panel in a darkened (45%) shade of the door colour.
              fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
              # Draw key slot
              fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
          else:
              # Closed, unlocked door: alternating nested frames.
              fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
              fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
              fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
              fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
              # Draw door handle
              fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  class Key(WorldObj):
      """Key the agent can pick up.

      ``Door.toggle`` uses a carried key of matching colour to unlock a door.
      """

      def __init__(self, color="blue"):
          super().__init__("key", color)

      def can_pickup(self):
          return True

      def render(self, img):
          colour = COLORS[self.color]
          black = (0, 0, 0)

          # Shaft of the key
          shaft = point_in_rect(0.50, 0.63, 0.31, 0.88)
          fill_coords(img, shaft, colour)

          # Two teeth sticking out from the shaft
          upper_tooth = point_in_rect(0.38, 0.50, 0.59, 0.66)
          fill_coords(img, upper_tooth, colour)
          lower_tooth = point_in_rect(0.38, 0.50, 0.81, 0.88)
          fill_coords(img, lower_tooth, colour)

          # Ring: a solid disc, then a smaller black disc punched in its centre
          outer_ring = point_in_circle(cx=0.56, cy=0.28, r=0.190)
          fill_coords(img, outer_ring, colour)
          ring_hole = point_in_circle(cx=0.56, cy=0.28, r=0.064)
          fill_coords(img, ring_hole, black)
  class Ball(WorldObj):
      """Ball the agent can pick up."""

      def __init__(self, color="blue"):
          super().__init__("ball", color)

      def can_pickup(self):
          return True

      def render(self, img):
          # A single filled disc centred in the tile.
          disc = point_in_circle(0.5, 0.5, 0.31)
          fill_coords(img, disc, COLORS[self.color])
  class Box(WorldObj):
      """Box the agent can pick up; toggling it swaps it for its contents."""

      def __init__(self, color, contains=None):
          super().__init__("box", color)
          # Object placed in the grid when the box is toggled (may be None).
          self.contains = contains

      def can_pickup(self):
          return True

      def render(self, img):
          colour = COLORS[self.color]
          black = (0, 0, 0)

          # Hollow square: fill a square, then cover a smaller black one on top
          fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), colour)
          fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), black)

          # Horizontal slit across the middle
          fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), colour)

      def toggle(self, env, pos):
          """Replace this box's grid cell with whatever it contains."""
          x, y = pos[0], pos[1]
          env.grid.set(x, y, self.contains)
          return True