Browse Source

Implemented flat, one-hot string encoding into observations

Maxime Chevalier-Boisvert 7 năm trước cách đây
mục cha
commit
7c4e96a806
2 tập tin đã thay đổi với 26 bổ sung4 xóa
  1. 23 4
      gym_minigrid/envs/fetch.py
  2. 3 0
      pytorch-rl/arguments.py

+ 23 - 4
gym_minigrid/envs/fetch.py

@@ -19,7 +19,7 @@ class FetchEnv(MiniGridEnv):
         self.observation_space = spaces.Box(
             low=0,
             high=255,
-            shape=obsSize + 2
+            shape=obsSize + 27 * 48
         )
 
     def _genGrid(self, width, height):
@@ -83,6 +83,8 @@ class FetchEnv(MiniGridEnv):
             self.mission = 'you must fetch a %s' % descStr
         assert hasattr(self, 'mission')
 
+        #self.mission = 'fetch a %s' % descStr
+
         return grid
 
     def _observation(self, obs):
@@ -98,9 +100,26 @@ class FetchEnv(MiniGridEnv):
         }
         """
 
-        typeIdx = OBJECT_TO_IDX[self.targetType]
-        colorIdx= COLOR_TO_IDX[self.targetColor]
-        obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
+        #typeIdx = OBJECT_TO_IDX[self.targetType]
+        #colorIdx= COLOR_TO_IDX[self.targetColor]
+        #obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
+
+        NUM_CHARS = 27
+        maxLen = 48
+        assert len(self.mission) > 0 and len(self.mission) <= maxLen, len(self.mission)
+        mission = self.mission.lower()
+
+        strArray = np.zeros(shape=(maxLen, NUM_CHARS))
+
+        for idx, ch in enumerate(mission):
+            if ch >= 'a' and ch <= 'z':
+                chNo = ord(ch) - ord('a')
+            elif ch == ' ':
+                chNo = ord('z') - ord('a') + 1
+            assert chNo < NUM_CHARS, '%s : %d' % (ch, chNo)
+            strArray[idx, chNo] = 1
+
+        obs = np.hstack((obs.flatten(), strArray.flatten()))
 
         return obs
 

+ 3 - 0
pytorch-rl/arguments.py

@@ -64,4 +64,7 @@ def get_args():
     args.cuda = not args.no_cuda and torch.cuda.is_available()
     args.vis = not args.no_vis
 
+    if not args.cuda:
+        print('*** CUDA DISABLED ***')
+
     return args