|
@@ -19,7 +19,7 @@ class FetchEnv(MiniGridEnv):
|
|
|
self.observation_space = spaces.Box(
|
|
|
low=0,
|
|
|
high=255,
|
|
|
- shape=obsSize + 2
|
|
|
+ shape=obsSize + 27 * 48
|
|
|
)
|
|
|
|
|
|
def _genGrid(self, width, height):
|
|
@@ -83,6 +83,8 @@ class FetchEnv(MiniGridEnv):
|
|
|
self.mission = 'you must fetch a %s' % descStr
|
|
|
assert hasattr(self, 'mission')
|
|
|
|
|
|
+ #self.mission = 'fetch a %s' % descStr
|
|
|
+
|
|
|
return grid
|
|
|
|
|
|
def _observation(self, obs):
|
|
@@ -98,9 +100,26 @@ class FetchEnv(MiniGridEnv):
|
|
|
}
|
|
|
"""
|
|
|
|
|
|
- typeIdx = OBJECT_TO_IDX[self.targetType]
|
|
|
- colorIdx= COLOR_TO_IDX[self.targetColor]
|
|
|
- obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
|
|
|
+ #typeIdx = OBJECT_TO_IDX[self.targetType]
|
|
|
+ #colorIdx= COLOR_TO_IDX[self.targetColor]
|
|
|
+ #obs = np.hstack((obs.flatten(), [typeIdx, colorIdx]))
|
|
|
+
|
|
|
+ NUM_CHARS = 27
|
|
|
+ maxLen = 48
|
|
|
+ assert len(self.mission) > 0 and len(self.mission) <= maxLen, len(self.mission)
|
|
|
+ mission = self.mission.lower()
|
|
|
+
|
|
|
+ strArray = np.zeros(shape=(maxLen, NUM_CHARS))
|
|
|
+
|
|
|
+ for idx, ch in enumerate(mission):
|
|
|
+ if ch >= 'a' and ch <= 'z':
|
|
|
+ chNo = ord(ch) - ord('a')
|
|
|
+ elif ch == ' ':
|
|
|
+ chNo = ord('z') - ord('a') + 1
|
|
|
+ assert chNo < NUM_CHARS, '%s : %d' % (ch, chNo)
|
|
|
+ strArray[idx, chNo] = 1
|
|
|
+
|
|
|
+ obs = np.hstack((obs.flatten(), strArray.flatten()))
|
|
|
|
|
|
return obs
|
|
|
|