| 
					
				 | 
			
			
				@@ -14,6 +14,7 @@ class GoToObjectEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.numObjs = numObjs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super().__init__(gridSize=size, maxSteps=5*size) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.reward_range = (-1000, 1000) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _genGrid(self, width, height): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -34,12 +35,17 @@ class GoToObjectEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         colors = list(COLORS.keys()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         objs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        objPos = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # For each object to be generated 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for i in range(0, self.numObjs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             objType = self._randElem(types) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             objColor = self._randElem(colors) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # If this object already exists, skip it (the loop moves on; this slot is not retried) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (objType, objColor) in objs: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if objType == 'key': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 obj = Key(objColor) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             elif objType == 'ball': 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -57,12 +63,13 @@ class GoToObjectEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     grid.set(*pos, obj) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            objs.append(obj) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            objs.append((objType, objColor)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            objPos.append(pos) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Choose a random object to be picked up 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target = objs[self._randInt(0, len(objs))] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.targetType = target.type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.targetColor = target.color 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        objIdx = self._randInt(0, len(objs)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.targetType, self.targetColor = objs[objIdx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.targetPos = objPos[objIdx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         descStr = '%s %s' % (self.targetColor, self.targetType) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -83,6 +90,7 @@ class GoToObjectEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.mission = 'go to the %s' % descStr 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #print(self.mission) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return grid 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -106,14 +114,16 @@ class GoToObjectEnv(MiniGridEnv): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _step(self, action): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         obs, reward, done, info = MiniGridEnv._step(self, action) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #if self.carrying: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #    if self.carrying.color == self.targetColor and \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #       self.carrying.type == self.targetType: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #        reward = 1000 - self.stepCount 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #        done = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #        reward = -1000 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #        done = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ax, ay = self.agentPos 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        tx, ty = self.targetPos 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Reward being next to the object 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Double the reward when the agent waits next to the object 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if abs(ax - tx) <= 1 and abs(ay - ty) <= 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if action == self.actions.wait: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                reward = 2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                reward = 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         obs = self._observation(obs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |