|
@@ -14,6 +14,11 @@ class FetchEnv(MiniGridEnv):
|
|
|
):
|
|
|
self.numObjs = numObjs
|
|
|
super().__init__(gridSize=size, maxSteps=5*size)
|
|
|
+
|
|
|
+ self.observation_space = spaces.Dict({
|
|
|
+ 'image': self.observation_space
|
|
|
+ })
|
|
|
+
|
|
|
self.reward_range = (-1000, 1000)
|
|
|
|
|
|
def _genGrid(self, width, height):
|
|
@@ -35,7 +40,7 @@ class FetchEnv(MiniGridEnv):
|
|
|
objs = []
|
|
|
|
|
|
# For each object to be generated
|
|
|
- for i in range(0, self.numObjs):
|
|
|
+ while len(objs) < self.numObjs:
|
|
|
objType = self._randElem(types)
|
|
|
objColor = self._randElem(colors)
|
|
|
|