ソースを参照

Added operator == definition for grids

Maxime Chevalier-Boisvert 7 年 前
コミット
e84bba090e
2 ファイル変更13 行追加5 行削除
  1. 8 0
      gym_minigrid/minigrid.py
  2. 5 5
      run_tests.py

+ 8 - 0
gym_minigrid/minigrid.py

@@ -322,6 +322,14 @@ class Grid:
                     return True
         return False
 
+    def __eq__(self, other):
+        grid1 = self.encode()
+        grid2 = other.encode()
+        return np.array_equal(grid2, grid1)
+
+    def __ne__(self, other):
+        return not self == other
+
     def copy(self):
         from copy import deepcopy
         return deepcopy(self)

+ 5 - 5
run_tests.py

@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 
 import random
-import gym
 import numpy as np
+import gym
 from gym_minigrid.register import envList
 from gym_minigrid.minigrid import Grid
 
@@ -28,10 +28,10 @@ for envName in envList:
     for i in range(0, 5):
         seed = 1337 + i
         env.seed(seed)
-        grid1 = env.grid.encode()
+        grid1 = env.grid
         env.seed(seed)
-        grid2 = env.grid.encode()
-        assert np.array_equal(grid2, grid1)
+        grid2 = env.grid
+        assert grid1 == grid2
 
     env.reset()
 
@@ -46,7 +46,7 @@ for envName in envList:
         img = obs['image']
         grid = Grid.decode(img)
         img2 = grid.encode()
-        assert np.array_equal(img2, img)
+        assert np.array_equal(img, img2)
 
         # Check that the reward is within the specified range
         assert reward >= env.reward_range[0], reward