|
@@ -30,11 +30,12 @@ for envName in sorted(envSet):
|
|
|
obs, reward, done, info = env.step(action)
|
|
|
|
|
|
# Test observation encode/decode roundtrip
|
|
|
- if type(obs) is np.ndarray:
|
|
|
- grid = Grid.decode(obs)
|
|
|
- obs2 = grid.encode()
|
|
|
- assert np.array_equal(obs2, obs)
|
|
|
+ img = obs if type(obs) is np.ndarray else obs['image']
|
|
|
+ grid = Grid.decode(img)
|
|
|
+ img2 = grid.encode()
|
|
|
+ assert np.array_equal(img2, img)
|
|
|
|
|
|
+ # Check that the reward is within the specified range
|
|
|
assert reward >= env.reward_range[0], reward
|
|
|
assert reward <= env.reward_range[1], reward
|
|
|
|