Maxime Chevalier-Boisvert 5 gadi atpakaļ
vecāks
revīzija
48bb96e238
3 mainītis faili ar 35 papildinājumiem un 39 dzēšanām
  1. 14 37
      gym_minigrid/minigrid.py
  2. 21 0
      gym_minigrid/rendering.py
  3. 0 2
      manual_control.py

+ 14 - 37
gym_minigrid/minigrid.py

@@ -190,43 +190,20 @@ class Lava(WorldObj):
     def can_overlap(self):
         return True
 
-    def render(self, r):
-        orange = 255, 128, 0
-        r.setLineColor(*orange)
-        r.setColor(*orange)
-        r.drawPolygon([
-            (0          , TILE_PIXELS),
-            (TILE_PIXELS, TILE_PIXELS),
-            (TILE_PIXELS, 0),
-            (0          , 0)
-        ])
-
-        # drawing the waves
-        r.setLineColor(0, 0, 0)
-
-        r.drawPolyline([
-            (.1 * TILE_PIXELS, .3 * TILE_PIXELS),
-            (.3 * TILE_PIXELS, .4 * TILE_PIXELS),
-            (.5 * TILE_PIXELS, .3 * TILE_PIXELS),
-            (.7 * TILE_PIXELS, .4 * TILE_PIXELS),
-            (.9 * TILE_PIXELS, .3 * TILE_PIXELS),
-        ])
-
-        r.drawPolyline([
-            (.1 * TILE_PIXELS, .5 * TILE_PIXELS),
-            (.3 * TILE_PIXELS, .6 * TILE_PIXELS),
-            (.5 * TILE_PIXELS, .5 * TILE_PIXELS),
-            (.7 * TILE_PIXELS, .6 * TILE_PIXELS),
-            (.9 * TILE_PIXELS, .5 * TILE_PIXELS),
-        ])
-
-        r.drawPolyline([
-            (.1 * TILE_PIXELS, .7 * TILE_PIXELS),
-            (.3 * TILE_PIXELS, .8 * TILE_PIXELS),
-            (.5 * TILE_PIXELS, .7 * TILE_PIXELS),
-            (.7 * TILE_PIXELS, .8 * TILE_PIXELS),
-            (.9 * TILE_PIXELS, .7 * TILE_PIXELS),
-        ])
+    def render(self, img):
+        c = (255, 128, 0)
+
+        # Background color
+        fill_coords(img, point_in_rect(0, 1, 0, 1), c)
+
+        # Little waves
+        for i in range(3):
+            ylo = 0.3 + 0.2 * i
+            yhi = 0.4 + 0.2 * i
+            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0))
+            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0))
+            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0))
+            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0))
 
 class Wall(WorldObj):
     def __init__(self, color='grey'):

+ 21 - 0
gym_minigrid/rendering.py

@@ -28,6 +28,27 @@ def rotate_fn(fin, cx, cy, theta):
 
     return fout
 
+def point_in_line(x0, y0, x1, y1, r):
+    p0 = np.array([x0, y0])
+    p1 = np.array([x1, y1])
+    dir = p1 - p0
+    dist = np.linalg.norm(dir)
+    dir = dir / dist
+
+    def fn(x, y):
+        q = np.array([x, y])
+        pq = q - p0
+
+        # Closest point on line
+        a = np.dot(pq, dir)
+        a = np.clip(a, 0, dist)
+        p = p0 + a * dir
+
+        dist_to_line = np.linalg.norm(q - p)
+        return dist_to_line <= r
+
+    return fn
+
 def point_in_circle(cx, cy, r):
     def fn(x, y):
         return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r

+ 0 - 2
manual_control.py

@@ -96,8 +96,6 @@ fig.canvas.set_window_title('gym_minigrid - ' + args.env_name)
 ax.set_xticks([], [])
 ax.set_yticks([], [])
 
-print(args.tile_size)
-
 # Show the first image of the environment
 img = env.render('rgb_array', tile_size=args.tile_size)
 imshow_obj = ax.imshow(img, interpolation='bilinear')