|
@@ -615,34 +615,6 @@ class Grid:
|
|
|
|
|
|
return mask
|
|
|
|
|
|
-class StringGymSpace(gym.spaces.space.Space):
|
|
|
- """
|
|
|
- A gym space that represents a string of characters of bounded length
|
|
|
- """
|
|
|
- def __init__(self, min_length=0, max_length=1000):
|
|
|
- super().__init__(shape=(), dtype='U')
|
|
|
- self.min_length = min_length
|
|
|
- self.max_length = max_length
|
|
|
- self.letters = string.ascii_letters + string.digits + ' .,!- '
|
|
|
-
|
|
|
- def sample(self):
|
|
|
- length = np.random.randint(self.min_length, self.max_length)
|
|
|
- string = ''.join(np.random.choice(list(self.letters), size=length))
|
|
|
- return string
|
|
|
-
|
|
|
- def contains(self, x):
|
|
|
- return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
|
|
|
-
|
|
|
- def __repr__(self):
|
|
|
- return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
|
|
|
-
|
|
|
- def __eq__(self, other):
|
|
|
- return (isinstance(other, StringGymSpace)
|
|
|
- and self.min_length == other.min_length
|
|
|
- and self.max_length == other.max_length
|
|
|
- and self.letters == other.letters
|
|
|
- )
|
|
|
-
|
|
|
class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
2D grid world game environment
|
|
@@ -708,7 +680,9 @@ class MiniGridEnv(gym.Env):
|
|
|
self.observation_space = spaces.Dict({
|
|
|
'image': self.observation_space,
|
|
|
'direction': spaces.Discrete(4),
|
|
|
- 'mission': StringGymSpace(min_length=0, max_length=200),
|
|
|
+ 'mission': spaces.Text(max_length=200,
|
|
|
+ charset=string.ascii_letters + string.digits + ' .,!- '
|
|
|
+ )
|
|
|
})
|
|
|
|
|
|
# render mode
|