浏览代码

Add a gym Space to handle textual observations, and update the environment observation_space accordingly

saleml 2 年之前
父节点
当前提交
c366eb2a0e
共有 1 个文件被更改,包括 26 次插入1 次删除
  1. 26 1
      gym_minigrid/minigrid.py

+ 26 - 1
gym_minigrid/minigrid.py

@@ -1,5 +1,6 @@
 import math
 import hashlib
+import string
 import gym
 from enum import IntEnum
 import numpy as np
@@ -615,6 +616,28 @@ class Grid:
 
         return mask
 
+class StringGymSpace(gym.spaces.space.Space):
+    """
+    A gym space that represents a string of characters of bounded length
+    """
+    def __init__(self, min_length=0, max_length=1000):
+        self.min_length = min_length
+        self.max_length = max_length
+        self.letters = string.ascii_letters + string.digits + ' .,!- '
+        self._shape = ()
+        self.dtype = np.dtype('U')
+
+    def sample(self):
+        length = np.random.randint(self.min_length, self.max_length)
+        string = ''.join(np.random.choice(self.letters, size=length))
+        return string
+
+    def contains(self, x):
+        return isinstance(x, str) and len(x) >= self.min_length and len(x) <= self.max_length
+
+    def __repr__(self):
+        return "StringGymSpace(min_length={}, max_length={})".format(self.min_length, self.max_length)
+
 class MiniGridEnv(gym.Env):
     """
     2D grid world game environment
@@ -678,7 +701,9 @@ class MiniGridEnv(gym.Env):
             dtype='uint8'
         )
         self.observation_space = spaces.Dict({
-            'image': self.observation_space
+            'image': self.observation_space,
+            'direction': spaces.Discrete(4),
+            'mission': StringGymSpace(min_length=0, max_length=200),
         })
 
         # Range of possible rewards