arguments.py

import argparse

import torch


def get_args():
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument('--algo', default='a2c',
                        help='algorithm to use: a2c | ppo | acktr')
    parser.add_argument('--lr', type=float, default=7e-4,
                        help='learning rate (default: 7e-4)')
    parser.add_argument('--eps', type=float, default=1e-5,
                        help='RMSprop optimizer epsilon (default: 1e-5)')
    parser.add_argument('--alpha', type=float, default=0.99,
                        help='RMSprop optimizer alpha (default: 0.99)')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='discount factor for rewards (default: 0.99)')
    parser.add_argument('--use-gae', action='store_true', default=False,
                        help='use generalized advantage estimation')
    parser.add_argument('--tau', type=float, default=0.95,
                        help='GAE parameter (default: 0.95)')
    parser.add_argument('--entropy-coef', type=float, default=0.01,
                        help='entropy term coefficient (default: 0.01)')
    parser.add_argument('--value-loss-coef', type=float, default=0.5,
                        help='value loss coefficient (default: 0.5)')
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='max norm of gradients (default: 0.5)')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--num-processes', type=int, default=32,
                        help='how many training CPU processes to use (default: 32)')
    parser.add_argument('--num-steps', type=int, default=5,
                        help='number of forward steps in A2C (default: 5)')
    parser.add_argument('--ppo-epoch', type=int, default=4,
                        help='number of PPO epochs (default: 4)')
    parser.add_argument('--num-mini-batch', type=int, default=32,
                        help='number of mini-batches for PPO (default: 32)')
    parser.add_argument('--clip-param', type=float, default=0.2,
                        help='PPO clip parameter (default: 0.2)')
    parser.add_argument('--num-stack', type=int, default=1,
                        help='number of frames to stack (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10,
                        help='log interval, one log per n updates (default: 10)')
    parser.add_argument('--save-interval', type=int, default=100,
                        help='save interval, one save per n updates (default: 100)')
    parser.add_argument('--vis-interval', type=int, default=100,
                        help='vis interval, one log per n updates (default: 100)')
    parser.add_argument('--num-frames', type=int, default=10e7,
                        help='number of frames to train (default: 10e7)')
    parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                        help='environment to train on (default: PongNoFrameskip-v4)')
    parser.add_argument('--log-dir', default='/tmp/gym/',
                        help='directory to save agent logs (default: /tmp/gym)')
    parser.add_argument('--save-dir', default='./trained_models/',
                        help='directory to save trained models (default: ./trained_models/)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    # Note: with action='store_true' and default=True, this flag is effectively
    # always on; passing --recurrent-policy does not change anything.
    parser.add_argument('--recurrent-policy', action='store_true', default=True,
                        help='use a recurrent policy')
    parser.add_argument('--no-vis', action='store_true', default=False,
                        help='disables visdom visualization')
    args = parser.parse_args()

    # Derived flags: CUDA is used only when available and not explicitly disabled.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis
    if not args.cuda:
        print('*** WARNING: CUDA NOT ENABLED ***')
    return args
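
Usage: a minimal sketch of how get_args() might be consumed by a training entry point. The main.py file, the seeding steps, and the update-count arithmetic shown here are illustrative assumptions; only get_args() comes from arguments.py.

# main.py (hypothetical) — consume the parsed arguments.
import torch

from arguments import get_args


def main():
    args = get_args()

    # Seed CPU and, if enabled, GPU RNGs for reproducibility with --seed.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    device = torch.device('cuda' if args.cuda else 'cpu')

    # One update consumes num_steps frames from each of num_processes workers.
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes
    print(f'algo={args.algo}, device={device}, updates={num_updates}')


if __name__ == '__main__':
    main()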