run_tests.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #!/usr/bin/env python3
  2. import random
  3. import numpy as np
  4. import gym
  5. from gym_minigrid.register import envList
  6. from gym_minigrid.minigrid import Grid
  7. # Test specifically importing a specific environment
  8. from gym_minigrid.envs import DoorKeyEnv
  9. # Test importing wrappers
  10. from gym_minigrid.wrappers import *
  ##############################################################################

# Smoke-test every registered environment: construction, rendering, seeding
# determinism, the observation encode/decode roundtrip and the reward range.
print('%d environments registered' % len(envList))

for envName in envList:
    print('testing "%s"' % envName)

    # Load the gym environment
    env = gym.make(envName)
    env.reset()
    env.render('rgb_array')

    # Verify that the same seed always produces the same environment.
    # Reset after each seed so the grid is actually rebuilt from that seed.
    # Reading env.grid right after env.seed() with nothing in between can
    # just compare the same grid object with itself, which always passes.
    # NOTE(review): assumes env.seed() only reseeds the RNG and the grid is
    # generated in reset(). If seed() also regenerates the grid, the extra
    # reset is harmless. Encodings are compared so the check does not depend
    # on Grid defining __eq__.
    for seed in range(1337, 1342):
        env.seed(seed)
        env.reset()
        grid1 = env.grid.encode()
        env.seed(seed)
        env.reset()
        grid2 = env.grid.encode()
        assert np.array_equal(grid1, grid2), seed

    # Run for a few episodes with uniformly random actions
    env.reset()
    num_episodes = 0
    while num_episodes < 5:
        action = random.randint(0, env.action_space.n - 1)
        obs, reward, done, info = env.step(action)

        # Test observation encode/decode roundtrip
        img = obs['image']
        img2 = Grid.decode(img).encode()
        assert np.array_equal(img, img2)

        # Check that the reward is within the specified range
        reward_lo, reward_hi = env.reward_range
        assert reward_lo <= reward <= reward_hi, reward

        if done:
            num_episodes += 1
            env.reset()

        env.render('rgb_array')

    env.close()
  ##############################################################################

# Focused checks on the empty 6x6 room.
env = gym.make('MiniGrid-Empty-6x6-v0')
# NOTE(review): presumably the goal tile sits in the bottom-right interior
# corner of this room. Confirm against the environment definition.
goalPos = (env.grid.width - 2, env.grid.height - 2)

# Grid objects support "in": this room holds a green goal and no blue key.
assert ('green', 'goal') in env.grid
assert ('blue', 'key') not in env.grid

# env.agentSees(x, y) must agree with the agent's own observation: the goal
# cell counts as seen exactly when a green goal shows up in the decoded image.
env.reset()
for _step in range(200):
    action = random.randrange(env.action_space.n)
    obs, reward, done, info = env.step(action)
    observedGrid = Grid.decode(obs['image'])
    goalVisible = ('green', 'goal') in observedGrid
    assert env.agentSees(*goalPos) == goalVisible
    if done:
        env.reset()