瀏覽代碼

use the new gym Text space and official git repo in setup and yml file

saleml 2 年之前
父節點
當前提交
6ddaa9fc6e
共有 3 個文件被更改,包括 5 次插入32 次删除
  1. 1 2
      .travis.yml
  2. 3 29
      gym_minigrid/minigrid.py
  3. 1 1
      setup.py

+ 1 - 2
.travis.yml

@@ -5,8 +5,7 @@ python:
 # command to install dependencies
 install:
   - pip3 install -e .
-  - pip3 install git+https://github.com/pseudo-rnd-thoughts/gym.git@fixed-env-checker
-  - pip3 install gym[other]
+  - pip3 install git+https://github.com/openai/gym.git@0.25.0
 
 # command to run tests
 script: ./run_tests.py

+ 3 - 29
gym_minigrid/minigrid.py

@@ -615,34 +615,6 @@ class Grid:
 
         return mask
 
-class StringGymSpace(gym.spaces.space.Space):
-    """
-    A gym space that represents a string of characters of bounded length
-    """
-    def __init__(self, min_length=0, max_length=1000):
-        super().__init__(shape=(), dtype='U')
-        self.min_length = min_length
-        self.max_length = max_length
-        self.letters = string.ascii_letters + string.digits + ' .,!- '
-
-    def sample(self):
-        length = np.random.randint(self.min_length, self.max_length)
-        string = ''.join(np.random.choice(list(self.letters), size=length))
-        return string
-
-    def contains(self, x):
-        return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
-
-    def __repr__(self):
-        return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
-
-    def __eq__(self, other):
-        return (isinstance(other, StringGymSpace) 
-                and self.min_length == other.min_length 
-                and self.max_length == other.max_length 
-                and self.letters == other.letters
-               )
-
 class MiniGridEnv(gym.Env):
     """
     2D grid world game environment
@@ -708,7 +680,9 @@ class MiniGridEnv(gym.Env):
         self.observation_space = spaces.Dict({
             'image': self.observation_space,
             'direction': spaces.Discrete(4),
-            'mission': StringGymSpace(min_length=0, max_length=200),
+            'mission': spaces.Text(max_length=200,
+                                   charset=string.ascii_letters + string.digits + ' .,!- '
+                                  )
         })
 
         # render mode

+ 1 - 1
setup.py

@@ -24,7 +24,7 @@ setup(
     python_requires=">=3.7, <3.11",
     long_description_content_type="text/markdown",
     install_requires=[
-        # 'gym>=0.24.1',
+        'gym>=0.25',
         "numpy>=1.18.0"
     ],
     classifiers=[