|
@@ -1,5 +1,6 @@
|
|
|
import math
|
|
|
import hashlib
|
|
|
+import string
|
|
|
import gym
|
|
|
from enum import IntEnum
|
|
|
import numpy as np
|
|
@@ -615,6 +616,28 @@ class Grid:
|
|
|
|
|
|
return mask
|
|
|
|
|
|
+class StringGymSpace(gym.spaces.space.Space):
|
|
|
+ """
|
|
|
+ A gym space that represents a string of characters of bounded length
|
|
|
+ """
|
|
|
+ def __init__(self, min_length=0, max_length=1000):
|
|
|
+ self.min_length = min_length
|
|
|
+ self.max_length = max_length
|
|
|
+ self.letters = string.ascii_letters + string.digits + ' .,!- '
|
|
|
+ self._shape = ()
|
|
|
+ self.dtype = np.dtype('U')
|
|
|
+
|
|
|
+ def sample(self):
|
|
|
+ length = np.random.randint(self.min_length, self.max_length)
|
|
|
+ string = ''.join(np.random.choice(self.letters, size=length))
|
|
|
+ return string
|
|
|
+
|
|
|
+ def contains(self, x):
|
|
|
+ return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
|
|
|
+
|
|
|
class MiniGridEnv(gym.Env):
|
|
|
"""
|
|
|
2D grid world game environment
|
|
@@ -678,7 +701,9 @@ class MiniGridEnv(gym.Env):
|
|
|
dtype='uint8'
|
|
|
)
|
|
|
self.observation_space = spaces.Dict({
|
|
|
- 'image': self.observation_space
|
|
|
+ 'image': self.observation_space,
|
|
|
+ 'direction': spaces.Discrete(4),
|
|
|
+ 'mission': StringGymSpace(min_length=0, max_length=200),
|
|
|
})
|
|
|
|
|
|
# Range of possible rewards
|