run_tests.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/env python3
  2. import random
  3. import gym
  4. import numpy as np
  5. from gym_minigrid.register import envSet
  6. from gym_minigrid.minigrid import Grid
  7. # Test specifically importing a specific environment
  8. from gym_minigrid.envs import DoorKeyEnv
  9. # Test importing wrappers
  10. from gym_minigrid.wrappers import *
  11. ##############################################################################
  12. print('%d environments registered' % len(envSet))
  13. for envName in sorted(envSet):
  14. print('testing "%s"' % envName)
  15. # Load the gym environment
  16. env = gym.make(envName)
  17. env.reset()
  18. env.render('rgb_array')
  19. # Verify that the same seed always produces the same environment
  20. for i in range(0, 5):
  21. seed = 1337 + i
  22. env.seed(seed)
  23. grid1 = env.grid.encode()
  24. env.seed(seed)
  25. grid2 = env.grid.encode()
  26. assert np.array_equal(grid2, grid1)
  27. env.reset()
  28. # Run for a few episodes
  29. for i in range(5 * env.maxSteps):
  30. # Pick a random action
  31. action = random.randint(0, env.action_space.n - 1)
  32. obs, reward, done, info = env.step(action)
  33. # Test observation encode/decode roundtrip
  34. img = obs if type(obs) is np.ndarray else obs['image']
  35. grid = Grid.decode(img)
  36. img2 = grid.encode()
  37. assert np.array_equal(img2, img)
  38. # Check that the reward is within the specified range
  39. assert reward >= env.reward_range[0], reward
  40. assert reward <= env.reward_range[1], reward
  41. if done:
  42. env.reset()
  43. # Check that the agent doesn't overlap with an object
  44. assert env.grid.get(*env.agentPos) is None
  45. env.render('rgb_array')
  46. env.close()
  47. ##############################################################################
  48. # Test the env.agentSees() function
  49. env = gym.make('MiniGrid-Empty-6x6-v0')
  50. goalPos = (env.grid.width - 2, env.grid.height - 2)
  51. def goalInObs(obs):
  52. grid = Grid.decode(obs)
  53. for j in range(0, grid.height):
  54. for i in range(0, grid.width):
  55. cell = grid.get(i, j)
  56. if cell and cell.color == 'green':
  57. return True
  58. return False
  59. env.reset()
  60. for i in range(0, 200):
  61. action = random.randint(0, env.action_space.n - 1)
  62. obs, reward, done, info = env.step(action)
  63. assert env.agentSees(*goalPos) == goalInObs(obs)
  64. if done:
  65. env.reset()