|
@@ -14,7 +14,7 @@ class FetchEnv(MiniGridEnv):
|
|
|
):
|
|
|
self.numObjs = numObjs
|
|
|
super().__init__(gridSize=size, maxSteps=5*size)
|
|
|
- self.reward_range = (-1000, 1000)
|
|
|
+ self.reward_range = (0, 1)
|
|
|
|
|
|
def _genGrid(self, width, height):
|
|
|
assert width == height
|
|
@@ -85,10 +85,10 @@ class FetchEnv(MiniGridEnv):
|
|
|
if self.carrying:
|
|
|
if self.carrying.color == self.targetColor and \
|
|
|
self.carrying.type == self.targetType:
|
|
|
- reward = 1000 - self.stepCount
|
|
|
+ reward = 1
|
|
|
done = True
|
|
|
else:
|
|
|
- reward = -1000
|
|
|
+ reward = 0
|
|
|
done = True
|
|
|
|
|
|
return obs, reward, done, info
|
|
@@ -97,12 +97,21 @@ class FetchEnv5x5N2(FetchEnv):
|
|
|
def __init__(self):
|
|
|
super().__init__(size=5, numObjs=2)
|
|
|
|
|
|
+class FetchEnv6x6N2(FetchEnv):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__(size=6, numObjs=2)
|
|
|
+
|
|
|
register(
|
|
|
id='MiniGrid-Fetch-5x5-N2-v0',
|
|
|
entry_point='gym_minigrid.envs:FetchEnv5x5N2'
|
|
|
)
|
|
|
|
|
|
register(
|
|
|
+ id='MiniGrid-Fetch-6x6-N2-v0',
|
|
|
+ entry_point='gym_minigrid.envs:FetchEnv6x6N2'
|
|
|
+)
|
|
|
+
|
|
|
+register(
|
|
|
id='MiniGrid-Fetch-8x8-N3-v0',
|
|
|
entry_point='gym_minigrid.envs:FetchEnv'
|
|
|
)
|