
Initial commit

Maxime Chevalier-Boisvert, 7 years ago
Current commit
51a5d9079d

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+*.pyc
+*__pycache__

+ 29 - 0
LICENSE

@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2017, Maxime Chevalier-Boisvert
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 46 - 0
README.md

@@ -0,0 +1,46 @@
+# Minimalistic Grid World Environment (MiniGrid)
+
+A simple and minimalistic grid world environment for OpenAI Gym.
+
+## Installation
+
+Requirements:
+- Python 3
+- OpenAI gym
+- numpy
+- PyQt5
+- PyTorch (if using the supplied `basicrl` training code)
+- matplotlib (if using the supplied `basicrl` training code)
+
+Start by manually installing [PyTorch](http://pytorch.org/).
+
+Then, clone the repository and install the other dependencies with `pip3`:
+
+```
+git clone https://github.com/maximecb/gym-minigrid.git
+cd gym-minigrid
+pip3 install -e .
+```
+
+## Usage
+
+To run the standalone UI application:
+
+```
+./standalone.py
+```
+
+The environment to run can be selected with the `--env-name` option, e.g.:
+
+```
+./standalone.py --env-name MiniGrid-Fetch-8x8-v0
+```
+
+To see available environments and their implementation, look at [simple_envs.py](gym_minigrid/envs/simple_envs.py).
+
+Basic reinforcement learning code is provided in the `basicrl` subdirectory.
+You can perform training using the ACKTR algorithm with:
+
+```
+python3 basicrl/main.py --env-name MiniGrid-Empty-8x8-v0 --no-vis --num-processes 32 --algo acktr
+```
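
For programmatic use outside the standalone UI, the environments can also be driven directly through the Gym API. Here is a minimal sketch (not part of this commit), assuming that importing `gym_minigrid` registers the `MiniGrid-*` environments (as `basicrl/envs.py` relies on) and reusing the `MiniGrid-Empty-8x8-v0` id from the training command above; the random-action loop is purely illustrative:

```python
import gym
import gym_minigrid  # assumed to register the MiniGrid-* environments on import

env = gym.make('MiniGrid-Empty-8x8-v0')
obs = env.reset()

for _ in range(100):
    action = env.action_space.sample()          # random action, for illustration only
    obs, reward, done, info = env.step(action)  # classic Gym step API (2017-era)
    if done:
        obs = env.reset()
```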

+ 21 - 0
basicrl/LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Ilya Kostrikov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 153 - 0
basicrl/README.md

@@ -0,0 +1,153 @@
+# pytorch-a2c-ppo-acktr
+
+## Update 10/06/2017: added enjoy.py and a link to pretrained models!
+## Update 09/27/2017: now supports both Atari and MuJoCo/Roboschool!
+
+This is a PyTorch implementation of
+* Advantage Actor Critic (A2C), a synchronous deterministic version of [A3C](https://arxiv.org/pdf/1602.01783v1.pdf)
+* Proximal Policy Optimization [PPO](https://arxiv.org/pdf/1707.06347.pdf)
+* Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation [ACKTR](https://arxiv.org/abs/1708.05144)
+
+Also see the OpenAI posts: [A2C/ACKTR](https://blog.openai.com/baselines-acktr-a2c/) and [PPO](https://blog.openai.com/openai-baselines-ppo/) for more information.
+
+This implementation is inspired by the OpenAI baselines for [A2C](https://github.com/openai/baselines/tree/master/baselines/a2c), [ACKTR](https://github.com/openai/baselines/tree/master/baselines/acktr) and [PPO](https://github.com/openai/baselines/tree/master/baselines/ppo1). It uses the same hyperparameters and model architecture, since they were well tuned for Atari games.
+
+## Supported (and tested) environments (via [OpenAI Gym](https://gym.openai.com))
+* [Atari Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment)
+* [MuJoCo](http://mujoco.org)
+* [PyBullet](http://pybullet.org) (including Racecar, Minitaur and Kuka)
+
+I highly recommend PyBullet as a free open source alternative to MuJoCo for continuous control tasks.
+
+All environments are operated through exactly the same Gym interface. See their documentation for a comprehensive list of supported tasks.
+
+## Requirements
+
+* Python 3 (it might work with Python 2, but I didn't test it)
+* [PyTorch](http://pytorch.org/)
+* [Visdom](https://github.com/facebookresearch/visdom)
+* [OpenAI baselines](https://github.com/openai/baselines)
+
+To install the requirements:
+
+```bash
+# PyTorch
+conda install pytorch torchvision -c soumith
+
+# Baselines for Atari preprocessing
+git clone https://github.com/openai/baselines.git
+cd baselines
+pip install -e .
+
+# Other requirements
+pip install -r requirements.txt
+```
+
+## Contributions
+
+Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request. Also see the todo list below.
+
+I'm also looking for volunteers to run all experiments on Atari and MuJoCo (with multiple random seeds).
+
+## Disclaimer
+
+It's extremely difficult to reproduce results for reinforcement learning methods; see ["Deep Reinforcement Learning that Matters"](https://arxiv.org/abs/1709.06560) for more information. I tried to reproduce the OpenAI results as closely as possible. However, major differences in performance can be caused even by minor differences between the TensorFlow and PyTorch libraries.
+
+### TODO
+* Improve this README file. Rearrange images.
+* Improve performance of KFAC, see kfac.py for more information
+* Run evaluation for all games and algorithms
+
+## Training
+
+Start a `Visdom` server with `python -m visdom.server`; it serves `http://localhost:8097/` by default.
+
+### Atari
+#### A2C
+
+```bash
+python main.py --env-name "PongNoFrameskip-v4"
+```
+
+#### PPO
+
+```bash
+python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --num-processes 8 --num-steps 128 --num-mini-batch 4 --vis-interval 1 --log-interval 1
+```
+
+#### ACKTR
+
+```bash
+python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 --num-steps 20
+```
+
+### MuJoCo
+#### A2C
+
+```bash
+python main.py --env-name "Reacher-v1" --num-stack 1 --num-frames 1000000
+```
+
+#### PPO
+
+```bash
+python main.py --env-name "Reacher-v1" --algo ppo --use-gae --vis-interval 1  --log-interval 1 --num-stack 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-frames 1000000
+```
+
+#### ACKTR
+
+ACKTR requires some modifications specific to MuJoCo, but for now I want to keep this code as unified as possible, so I'm still looking for a better way to integrate those changes into the codebase.
+
+## Enjoy
+
+Load a pretrained model from [my Google Drive](https://drive.google.com/open?id=0Bw49qC_cgohKS3k2OWpyMWdzYkk).
+
+Pretrained models for other games are also available on request. Send me an email or create an issue and I will upload them.
+
+Disclaimer: I might have used different hyper-parameters to train these models.
+
+### Atari
+
+```bash
+python enjoy.py --load-dir trained_models/a2c --env-name "PongNoFrameskip-v4" --num-stack 4
+```
+
+### MuJoCo
+
+```bash
+python enjoy.py --load-dir trained_models/ppo --env-name "Reacher-v1" --num-stack 1
+```
+
+## Results
+
+### A2C
+
+![BreakoutNoFrameskip-v4](imgs/a2c_breakout.png)
+
+![SeaquestNoFrameskip-v4](imgs/a2c_seaquest.png)
+
+![QbertNoFrameskip-v4](imgs/a2c_qbert.png)
+
+![BeamRiderNoFrameskip-v4](imgs/a2c_beamrider.png)
+
+### PPO
+
+
+![HalfCheetah](imgs/ppo_halfcheetah.png)
+
+![Hopper](imgs/ppo_hopper.png)
+
+![Reacher](imgs/ppo_reacher.png)
+
+![Walker](imgs/ppo_walker.png)
+
+
+### ACKTR
+
+![BreakoutNoFrameskip-v4](imgs/acktr_breakout.png)
+
+![SeaquestNoFrameskip-v4](imgs/acktr_seaquest.png)
+
+![QbertNoFrameskip-v4](imgs/acktr_qbert.png)
+
+![BeamRiderNoFrameskip-v4](imgs/acktr_beamrider.png)

+ 67 - 0
basicrl/arguments.py

@@ -0,0 +1,67 @@
+import argparse
+
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='RL')
+    parser.add_argument('--algo', default='a2c',
+                        help='algorithm to use: a2c | ppo | acktr')
+    parser.add_argument('--lr', type=float, default=7e-4,
+                        help='learning rate (default: 7e-4)')
+    parser.add_argument('--eps', type=float, default=1e-5,
+                        help='RMSprop optimizer epsilon (default: 1e-5)')
+    parser.add_argument('--alpha', type=float, default=0.99,
+                        help='RMSprop optimizer alpha (default: 0.99)')
+    parser.add_argument('--gamma', type=float, default=0.99,
+                        help='discount factor for rewards (default: 0.99)')
+    parser.add_argument('--use-gae', action='store_true', default=False,
+                        help='use generalized advantage estimation')
+    parser.add_argument('--tau', type=float, default=0.95,
+                        help='gae parameter (default: 0.95)')
+    parser.add_argument('--entropy-coef', type=float, default=0.01,
+                        help='entropy term coefficient (default: 0.01)')
+    parser.add_argument('--value-loss-coef', type=float, default=0.5,
+                        help='value loss coefficient (default: 0.5)')
+    parser.add_argument('--max-grad-norm', type=float, default=0.5,
+                        help='max norm of gradients (default: 0.5)')
+    parser.add_argument('--seed', type=int, default=1,
+                        help='random seed (default: 1)')
+    parser.add_argument('--num-processes', type=int, default=16,
+                        help='how many training CPU processes to use (default: 16)')
+    parser.add_argument('--num-steps', type=int, default=5,
+                        help='number of forward steps in A2C (default: 5)')
+    parser.add_argument('--ppo-epoch', type=int, default=4,
+                        help='number of ppo epochs (default: 4)')
+    parser.add_argument('--num-mini-batch', type=int, default=32,
+                        help='number of batches for ppo (default: 32)')
+    parser.add_argument('--clip-param', type=float, default=0.2,
+                        help='ppo clip parameter (default: 0.2)')
+    parser.add_argument('--num-stack', type=int, default=4,
+                        help='number of frames to stack (default: 4)')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='log interval, one log per n updates (default: 10)')
+    parser.add_argument('--save-interval', type=int, default=100,
+                        help='save interval, one save per n updates (default: 100)')
+    parser.add_argument('--vis-interval', type=int, default=100,
+                        help='vis interval, one log per n updates (default: 100)')
+    parser.add_argument('--num-frames', type=int, default=10e6,
+                        help='number of frames to train (default: 10e6)')
+    parser.add_argument('--env-name', default='PongNoFrameskip-v4',
+                        help='environment to train on (default: PongNoFrameskip-v4)')
+    parser.add_argument('--log-dir', default='/tmp/gym/',
+                        help='directory to save agent logs (default: /tmp/gym)')
+    parser.add_argument('--save-dir', default='./trained_models/',
+                        help='directory to save agent logs (default: ./trained_models/)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--recurrent-policy', action='store_true', default=False,
+                        help='use a recurrent policy')
+    parser.add_argument('--no-vis', action='store_true', default=False,
+                        help='disables visdom visualization')
+    args = parser.parse_args()
+
+    args.cuda = not args.no_cuda and torch.cuda.is_available()
+    args.vis = not args.no_vis
+
+    return args
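
A minimal sketch (not part of the commit) of how `get_args` is consumed, mirroring the import at the top of `main.py`; the printed fields are only illustrative:

```python
from arguments import get_args

args = get_args()

# args.cuda and args.vis are derived just before the return above:
# cuda is enabled only when --no-cuda is absent and a GPU is available,
# and vis is disabled by passing --no-vis.
print(args.algo, args.lr, args.cuda, args.vis)
```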

+ 81 - 0
basicrl/distributions.py

@@ -0,0 +1,81 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from utils import AddBias
+
+
+class Categorical(nn.Module):
+    def __init__(self, num_inputs, num_outputs):
+        super(Categorical, self).__init__()
+        self.linear = nn.Linear(num_inputs, num_outputs)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+    def sample(self, x, deterministic):
+        x = self(x)
+
+        probs = F.softmax(x)
+        if deterministic is False:
+            action = probs.multinomial()
+        else:
+            action = probs.max(1, keepdim=True)[1]
+        return action
+
+    def logprobs_and_entropy(self, x, actions):
+        x = self(x)
+
+        log_probs = F.log_softmax(x)
+        probs = F.softmax(x)
+
+        action_log_probs = log_probs.gather(1, actions)
+
+        dist_entropy = -(log_probs * probs).sum(-1).mean()
+        return action_log_probs, dist_entropy
+
+
+class DiagGaussian(nn.Module):
+    def __init__(self, num_inputs, num_outputs):
+        super(DiagGaussian, self).__init__()
+        self.fc_mean = nn.Linear(num_inputs, num_outputs)
+        self.logstd = AddBias(torch.zeros(num_outputs))
+
+    def forward(self, x):
+        action_mean = self.fc_mean(x)
+
+        #  An ugly hack for my KFAC implementation.
+        zeros = Variable(torch.zeros(action_mean.size()), volatile=x.volatile)
+        if x.is_cuda:
+            zeros = zeros.cuda()
+
+        action_logstd = self.logstd(zeros)
+        return action_mean, action_logstd
+
+    def sample(self, x, deterministic):
+        action_mean, action_logstd = self(x)
+
+        action_std = action_logstd.exp()
+
+        if deterministic is False:
+            noise = Variable(torch.randn(action_std.size()))
+            if action_std.is_cuda:
+                noise = noise.cuda()
+            action = action_mean + action_std * noise
+        else:
+            action = action_mean
+        return action
+
+    def logprobs_and_entropy(self, x, actions):
+        action_mean, action_logstd = self(x)
+
+        action_std = action_logstd.exp()
+
+        action_log_probs = -0.5 * ((actions - action_mean) / action_std).pow(2) - 0.5 * math.log(2 * math.pi) - action_logstd
+        action_log_probs = action_log_probs.sum(-1, keepdim=True)
+        dist_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + action_logstd
+        dist_entropy = dist_entropy.sum(-1).mean()
+        return action_log_probs, dist_entropy
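
As a sanity check on the closed-form expressions in `DiagGaussian.logprobs_and_entropy`, here is a minimal sketch (not part of the commit) evaluating the per-dimension Gaussian log-density and differential entropy with plain `math`; the numbers are arbitrary:

```python
import math

mean, logstd, action = 0.5, -1.0, 0.3
std = math.exp(logstd)

# log N(action; mean, std) = -(action - mean)^2 / (2 std^2) - 0.5 log(2 pi) - log std
log_prob = -0.5 * ((action - mean) / std) ** 2 - 0.5 * math.log(2 * math.pi) - logstd

# Differential entropy of a univariate Gaussian: 0.5 + 0.5 log(2 pi) + log std
entropy = 0.5 + 0.5 * math.log(2 * math.pi) + logstd

print(log_prob, entropy)
```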

+ 110 - 0
basicrl/enjoy.py

@@ -0,0 +1,110 @@
+import argparse
+import os
+import sys
+import types
+import time
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.common.vec_env.vec_normalize import VecNormalize
+
+from envs import make_env
+
+
+parser = argparse.ArgumentParser(description='RL')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed (default: 1)')
+parser.add_argument('--num-stack', type=int, default=4,
+                    help='number of frames to stack (default: 4)')
+parser.add_argument('--log-interval', type=int, default=10,
+                    help='log interval, one log per n updates (default: 10)')
+parser.add_argument('--env-name', default='PongNoFrameskip-v4',
+                    help='environment to train on (default: PongNoFrameskip-v4)')
+parser.add_argument('--load-dir', default='./trained_models/',
+                    help='directory to save agent logs (default: ./trained_models/)')
+args = parser.parse_args()
+
+
+env = make_env(args.env_name, args.seed, 0, None)
+env = DummyVecEnv([env])
+
+actor_critic, ob_rms = \
+            torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))
+
+
+if len(env.observation_space.shape) == 1:
+    env = VecNormalize(env, ret=False)
+    env.ob_rms = ob_rms
+
+    # An ugly hack to remove updates
+    def _obfilt(self, obs):
+        if self.ob_rms:
+            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
+            return obs
+        else:
+            return obs
+    env._obfilt = types.MethodType(_obfilt, env)
+    render_func = env.venv.envs[0].render
+else:
+    render_func = env.envs[0].render
+
+obs_shape = env.observation_space.shape
+obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
+current_obs = torch.zeros(1, *obs_shape)
+states = torch.zeros(1, actor_critic.state_size)
+masks = torch.zeros(1, 1)
+
+
+def update_current_obs(obs):
+    shape_dim0 = env.observation_space.shape[0]
+    obs = torch.from_numpy(obs).float()
+    if args.num_stack > 1:
+        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
+    current_obs[:, -shape_dim0:] = obs
+
+
+render_func('human')
+obs = env.reset()
+update_current_obs(obs)
+
+if args.env_name.find('Bullet') > -1:
+    import pybullet as p
+
+    torsoId = -1
+    for i in range(p.getNumBodies()):
+        if (p.getBodyInfo(i)[0].decode() == "torso"):
+            torsoId = i
+
+while True:
+    value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
+                                                Variable(states, volatile=True),
+                                                Variable(masks, volatile=True),
+                                                deterministic=True)
+    states = states.data
+    cpu_actions = action.data.squeeze(1).cpu().numpy()
+    # Observe reward and next obs
+    obs, reward, done, _ = env.step(cpu_actions)
+
+    time.sleep(0.05)
+
+    masks.fill_(0.0 if done else 1.0)
+
+    if current_obs.dim() == 4:
+        current_obs *= masks.unsqueeze(2).unsqueeze(2)
+    else:
+        current_obs *= masks
+    update_current_obs(obs)
+
+    if args.env_name.find('Bullet') > -1:
+        if torsoId > -1:
+            distance = 5
+            yaw = 0
+            humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
+            p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)
+
+    renderer = render_func('human')
+
+    if not renderer.window:
+        sys.exit(0)

+ 62 - 0
basicrl/envs.py

@@ -0,0 +1,62 @@
+import os
+import numpy
+import gym
+
+from gym.spaces.box import Box
+
+from baselines import bench
+from baselines.common.atari_wrappers import make_atari, wrap_deepmind
+
+try:
+    import pybullet_envs
+except ImportError:
+    pass
+
+try:
+    import gym_minigrid
+except ImportError:
+    pass
+
+class ScaleActions(gym.ActionWrapper):
+    def __init__(self, env=None):
+        super(ScaleActions, self).__init__(env)
+
+    def _step(self, action):
+        action = (numpy.tanh(action) + 1) / 2 * (self.action_space.high - self.action_space.low) + self.action_space.low
+        return self.env.step(action)
+
+def make_env(env_id, seed, rank, log_dir):
+    def _thunk():
+        env = gym.make(env_id)
+        is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
+        if is_atari:
+            env = make_atari(env_id)
+        env.seed(seed + rank)
+        if log_dir is not None:
+            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
+        if is_atari:
+            env = wrap_deepmind(env)
+        # If the input has shape (W,H,3), wrap for PyTorch convolutions
+        obs_shape = env.observation_space.shape
+        if len(obs_shape) == 3 and obs_shape[2] == 3:
+            env = WrapPyTorch(env)
+
+        #env = ScaleActions(env)
+
+        return env
+
+    return _thunk
+
+
+class WrapPyTorch(gym.ObservationWrapper):
+    def __init__(self, env=None):
+        super(WrapPyTorch, self).__init__(env)
+        obs_shape = self.observation_space.shape
+        self.observation_space = Box(
+            self.observation_space.low[0,0,0],
+            self.observation_space.high[0,0,0],
+            [obs_shape[2], obs_shape[1], obs_shape[0]]
+        )
+
+    def _observation(self, observation):
+        return observation.transpose(2, 0, 1)
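
`WrapPyTorch` only reorders image observations from HWC (as Gym returns them) to CHW (as `nn.Conv2d` expects). A minimal sketch (not part of the commit) of the same transpose in plain NumPy, using a typical 84x84 Atari frame as the assumed shape:

```python
import numpy as np

obs_hwc = np.zeros((84, 84, 3), dtype=np.uint8)  # (height, width, channels)
obs_chw = obs_hwc.transpose(2, 0, 1)             # (channels, height, width)
print(obs_chw.shape)  # (3, 84, 84)
```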

Binary
basicrl/imgs/a2c_beamrider.png


Binary
basicrl/imgs/a2c_breakout.png


Binary
basicrl/imgs/a2c_qbert.png


Binary
basicrl/imgs/a2c_seaquest.png


Binary
basicrl/imgs/acktr_beamrider.png


Binary
basicrl/imgs/acktr_breakout.png


Binary
basicrl/imgs/acktr_qbert.png


Binary
basicrl/imgs/acktr_seaquest.png


Binary
basicrl/imgs/ppo_halfcheetah.png


Binary
basicrl/imgs/ppo_hopper.png


Binary
basicrl/imgs/ppo_reacher.png


Binary
basicrl/imgs/ppo_walker.png


+ 244 - 0
basicrl/kfac.py

@@ -0,0 +1,244 @@
+import math
+
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.nn.functional as F
+from utils import AddBias
+
+# TODO: In order to make this code faster:
+# 1) Implement _extract_patches as a single cuda kernel
+# 2) Compute QR decomposition in a separate process
+# 3) Actually make a general KFAC optimizer so it fits PyTorch
+
+
+def _extract_patches(x, kernel_size, stride, padding):
+    if padding[0] + padding[1] > 0:
+        x = F.pad(x, (padding[1], padding[1], padding[0],
+                      padding[0])).data  # Actually check dims
+    x = x.unfold(2, kernel_size[0], stride[0])
+    x = x.unfold(3, kernel_size[1], stride[1])
+    x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
+    x = x.view(
+        x.size(0), x.size(1), x.size(2), x.size(3) * x.size(4) * x.size(5))
+    return x
+
+
+def compute_cov_a(a, classname, layer_info, fast_cnn):
+    batch_size = a.size(0)
+
+    if classname == 'Conv2d':
+        if fast_cnn:
+            a = _extract_patches(a, *layer_info)
+            a = a.view(a.size(0), -1, a.size(-1))
+            a = a.mean(1)
+        else:
+            a = _extract_patches(a, *layer_info)
+            a = a.view(-1, a.size(-1)).div_(a.size(1)).div_(a.size(2))
+    elif classname == 'AddBias':
+        is_cuda = a.is_cuda
+        a = torch.ones(a.size(0), 1)
+        if is_cuda:
+            a = a.cuda()
+
+    return a.t() @ (a / batch_size)
+
+
+def compute_cov_g(g, classname, layer_info, fast_cnn):
+    batch_size = g.size(0)
+
+    if classname == 'Conv2d':
+        if fast_cnn:
+            g = g.view(g.size(0), g.size(1), -1)
+            g = g.sum(-1)
+        else:
+            g = g.transpose(1, 2).transpose(2, 3).contiguous()
+            g = g.view(-1, g.size(-1)).mul_(g.size(1)).mul_(g.size(2))
+    elif classname == 'AddBias':
+        g = g.view(g.size(0), g.size(1), -1)
+        g = g.sum(-1)
+
+    g_ = g * batch_size
+    return g_.t() @ (g_ / g.size(0))
+
+
+def update_running_stat(aa, m_aa, momentum):
+    # Do the trick to keep aa unchanged and not create any additional tensors
+    m_aa *= momentum / (1 - momentum)
+    m_aa += aa
+    m_aa *= (1 - momentum)
+
+
+class SplitBias(nn.Module):
+    def __init__(self, module):
+        super(SplitBias, self).__init__()
+        self.module = module
+        self.add_bias = AddBias(module.bias.data)
+        self.module.bias = None
+
+    def forward(self, input):
+        x = self.module(input)
+        x = self.add_bias(x)
+        return x
+
+
+class KFACOptimizer(optim.Optimizer):
+    def __init__(self,
+                 model,
+                 lr=0.25,
+                 momentum=0.9,
+                 stat_decay=0.99,
+                 kl_clip=0.001,
+                 damping=1e-2,
+                 weight_decay=0,
+                 fast_cnn=False,
+                 Ts=1,
+                 Tf=10):
+        defaults = dict()
+
+        def split_bias(module):
+            for mname, child in module.named_children():
+                if hasattr(child, 'bias'):
+                    module._modules[mname] = SplitBias(child)
+                else:
+                    split_bias(child)
+
+        split_bias(model)
+            
+        super(KFACOptimizer, self).__init__(model.parameters(), defaults)
+
+        self.known_modules = {'Linear', 'Conv2d', 'AddBias'}
+
+        self.modules = []
+        self.grad_outputs = {}
+
+        self.model = model
+        self._prepare_model()
+
+        self.steps = 0
+
+        self.m_aa, self.m_gg = {}, {}
+        self.Q_a, self.Q_g = {}, {}
+        self.d_a, self.d_g = {}, {}
+
+        self.momentum = momentum
+        self.stat_decay = stat_decay
+
+        self.lr = lr
+        self.kl_clip = kl_clip
+        self.damping = damping
+        self.weight_decay = weight_decay
+
+        self.fast_cnn = fast_cnn
+
+        self.Ts = Ts
+        self.Tf = Tf
+
+        self.optim = optim.SGD(
+            model.parameters(),
+            lr=self.lr * (1 - self.momentum),
+            momentum=self.momentum)
+
+    def _save_input(self, module, input):
+        if input[0].volatile == False and self.steps % self.Ts == 0:
+            classname = module.__class__.__name__
+            layer_info = None
+            if classname == 'Conv2d':
+                layer_info = (module.kernel_size, module.stride,
+                              module.padding)
+
+            aa = compute_cov_a(input[0].data, classname, layer_info,
+                               self.fast_cnn)
+
+            # Initialize buffers
+            if self.steps == 0:
+                self.m_aa[module] = aa.clone()
+
+            update_running_stat(aa, self.m_aa[module], self.stat_decay)
+
+    def _save_grad_output(self, module, grad_input, grad_output):
+        if self.acc_stats:
+            classname = module.__class__.__name__
+            layer_info = None
+            if classname == 'Conv2d':
+                layer_info = (module.kernel_size, module.stride,
+                              module.padding)
+
+            gg = compute_cov_g(grad_output[0].data, classname,
+                               layer_info, self.fast_cnn)
+
+            # Initialize buffers
+            if self.steps == 0:
+                self.m_gg[module] = gg.clone()
+
+            update_running_stat(gg, self.m_gg[module], self.stat_decay)
+
+    def _prepare_model(self):
+        for module in self.model.modules():
+            classname = module.__class__.__name__
+            if classname in self.known_modules:
+                assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \
+                                    "You must have a bias as a separate layer"
+
+                self.modules.append(module)
+                module.register_forward_pre_hook(self._save_input)
+                module.register_backward_hook(self._save_grad_output)
+
+    def step(self):
+        # Add weight decay
+        if self.weight_decay > 0:
+            for p in self.model.parameters():
+                p.grad.data.add_(self.weight_decay, p.data)
+
+        updates = {}
+        for i, m in enumerate(self.modules):
+            assert len(list(m.parameters())
+                       ) == 1, "Can handle only one parameter at the moment"
+            classname = m.__class__.__name__
+            p = next(m.parameters())
+
+            la = self.damping + self.weight_decay
+
+            if self.steps % self.Tf == 0:
+                # My asynchronous implementation exists, I will add it later.
+                # Experimenting with different ways to do this in PyTorch.
+                self.d_a[m], self.Q_a[m] = torch.symeig(
+                    self.m_aa[m].cpu().double(), eigenvectors=True)
+                self.d_g[m], self.Q_g[m] = torch.symeig(
+                    self.m_gg[m].cpu().double(), eigenvectors=True)
+                self.d_a[m], self.Q_a[m] = self.d_a[m].float(), self.Q_a[m].float()
+                self.d_g[m], self.Q_g[m] = self.d_g[m].float(), self.Q_g[m].float()
+                if self.m_aa[m].is_cuda:
+                    self.d_a[m], self.Q_a[m] = self.d_a[m].cuda(), self.Q_a[m].cuda()
+                    self.d_g[m], self.Q_g[m] = self.d_g[m].cuda(), self.Q_g[m].cuda()
+
+                self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
+                self.d_g[m].mul_((self.d_g[m] > 1e-6).float())
+
+            if classname == 'Conv2d':
+                p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
+            else:
+                p_grad_mat = p.grad.data
+
+            v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
+            v2 = v1 / (
+                self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
+            v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
+
+            v = v.view(p.grad.data.size())
+            updates[p] = v
+
+        vg_sum = 0
+        for p in self.model.parameters():
+            v = updates[p]
+            vg_sum += (v * p.grad.data * self.lr * self.lr).sum()
+
+        nu = min(1, math.sqrt(self.kl_clip / vg_sum))
+
+        for p in self.model.parameters():
+            v = updates[p]
+            p.grad.data.copy_(v)
+            p.grad.data.mul_(nu)
+
+        self.optim.step()
+        self.steps += 1
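
`update_running_stat` above is an in-place exponential moving average. A minimal sketch (not part of the commit) checking that it matches the out-of-place form `momentum * m_aa + (1 - momentum) * aa`; the function is copied locally so the snippet is self-contained:

```python
import torch

def update_running_stat(aa, m_aa, momentum):
    # Same in-place trick as in kfac.py: no new tensors are allocated
    m_aa *= momentum / (1 - momentum)
    m_aa += aa
    m_aa *= (1 - momentum)

momentum = 0.99
aa = torch.rand(4, 4)
m_aa = torch.rand(4, 4)
expected = momentum * m_aa + (1 - momentum) * aa

update_running_stat(aa, m_aa, momentum)
print((m_aa - expected).abs().max())  # ~0, up to floating-point error
```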

+ 269 - 0
basicrl/main.py

@@ -0,0 +1,269 @@
+import copy
+import glob
+import os
+import time
+import operator
+from functools import reduce
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.autograd import Variable
+
+from arguments import get_args
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
+from baselines.common.vec_env.vec_normalize import VecNormalize
+from envs import make_env
+from kfac import KFACOptimizer
+from model import CNNPolicy, MLPPolicy
+from storage import RolloutStorage
+from visualize import visdom_plot
+
+args = get_args()
+
+assert args.algo in ['a2c', 'ppo', 'acktr']
+if args.recurrent_policy:
+    assert args.algo in ['a2c', 'ppo'], \
+        'Recurrent policy is not implemented for ACKTR'
+
+num_updates = int(args.num_frames) // args.num_steps // args.num_processes
+
+torch.manual_seed(args.seed)
+if args.cuda:
+    torch.cuda.manual_seed(args.seed)
+
+try:
+    os.makedirs(args.log_dir)
+except OSError:
+    files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
+    for f in files:
+        os.remove(f)
+
+
+def main():
+    print("#######")
+    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
+    print("#######")
+
+    os.environ['OMP_NUM_THREADS'] = '1'
+
+    if args.vis:
+        from visdom import Visdom
+        viz = Visdom()
+        win = None
+
+    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
+                for i in range(args.num_processes)]
+
+    if args.num_processes > 1:
+        envs = SubprocVecEnv(envs)
+    else:
+        envs = DummyVecEnv(envs)
+
+    if len(envs.observation_space.shape) == 1:
+        envs = VecNormalize(envs)
+
+    obs_shape = envs.observation_space.shape
+    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
+
+    obs_numel = reduce(operator.mul, obs_shape, 1)
+
+    if len(obs_shape) == 3 and obs_numel > 1024:
+        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
+    else:
+        assert not args.recurrent_policy, \
+            "Recurrent policy is not implemented for the MLP controller"
+        actor_critic = MLPPolicy(obs_numel, envs.action_space)
+
+    if envs.action_space.__class__.__name__ == "Discrete":
+        action_shape = 1
+    else:
+        action_shape = envs.action_space.shape[0]
+
+    if args.cuda:
+        actor_critic.cuda()
+
+    if args.algo == 'a2c':
+        optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)
+    elif args.algo == 'ppo':
+        optimizer = optim.Adam(actor_critic.parameters(), args.lr, eps=args.eps)
+    elif args.algo == 'acktr':
+        optimizer = KFACOptimizer(actor_critic)
+
+    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size)
+    current_obs = torch.zeros(args.num_processes, *obs_shape)
+
+    def update_current_obs(obs):
+        shape_dim0 = envs.observation_space.shape[0]
+        obs = torch.from_numpy(obs).float()
+        if args.num_stack > 1:
+            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
+        current_obs[:, -shape_dim0:] = obs
+
+    obs = envs.reset()
+    update_current_obs(obs)
+
+    rollouts.observations[0].copy_(current_obs)
+
+    # These variables are used to compute average rewards for all processes.
+    episode_rewards = torch.zeros([args.num_processes, 1])
+    final_rewards = torch.zeros([args.num_processes, 1])
+
+    if args.cuda:
+        current_obs = current_obs.cuda()
+        rollouts.cuda()
+
+    start = time.time()
+    for j in range(num_updates):
+        for step in range(args.num_steps):
+            # Sample actions
+            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
+                                                                      Variable(rollouts.states[step], volatile=True),
+                                                                      Variable(rollouts.masks[step], volatile=True))
+            cpu_actions = action.data.squeeze(1).cpu().numpy()
+
+            # Observe reward and next obs
+            obs, reward, done, info = envs.step(cpu_actions)
+            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
+            episode_rewards += reward
+
+            # If done then clean the history of observations.
+            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
+            final_rewards *= masks
+            final_rewards += (1 - masks) * episode_rewards
+            episode_rewards *= masks
+
+            if args.cuda:
+                masks = masks.cuda()
+
+            if current_obs.dim() == 4:
+                current_obs *= masks.unsqueeze(2).unsqueeze(2)
+            else:
+                current_obs *= masks
+
+            update_current_obs(obs)
+            rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)
+
+        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
+                                  Variable(rollouts.states[-1], volatile=True),
+                                  Variable(rollouts.masks[-1], volatile=True))[0].data
+
+        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
+
+        if args.algo in ['a2c', 'acktr']:
+            values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
+                                                                                           Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
+                                                                                           Variable(rollouts.masks[:-1].view(-1, 1)),
+                                                                                           Variable(rollouts.actions.view(-1, action_shape)))
+
+            values = values.view(args.num_steps, args.num_processes, 1)
+            action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)
+
+            advantages = Variable(rollouts.returns[:-1]) - values
+            value_loss = advantages.pow(2).mean()
+
+            action_loss = -(Variable(advantages.data) * action_log_probs).mean()
+
+            if args.algo == 'acktr' and optimizer.steps % optimizer.Ts == 0:
+                # Sampled fisher, see Martens 2014
+                actor_critic.zero_grad()
+                pg_fisher_loss = -action_log_probs.mean()
+
+                value_noise = Variable(torch.randn(values.size()))
+                if args.cuda:
+                    value_noise = value_noise.cuda()
+
+                sample_values = values + value_noise
+                vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()
+
+                fisher_loss = pg_fisher_loss + vf_fisher_loss
+                optimizer.acc_stats = True
+                fisher_loss.backward(retain_graph=True)
+                optimizer.acc_stats = False
+
+            optimizer.zero_grad()
+            (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()
+
+            if args.algo == 'a2c':
+                nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)
+
+            optimizer.step()
+        elif args.algo == 'ppo':
+            advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
+            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
+
+            for e in range(args.ppo_epoch):
+                if args.recurrent_policy:
+                    data_generator = rollouts.recurrent_generator(advantages,
+                                                            args.num_mini_batch)
+                else:
+                    data_generator = rollouts.feed_forward_generator(advantages,
+                                                            args.num_mini_batch)
+
+                for sample in data_generator:
+                    observations_batch, states_batch, actions_batch, \
+                       return_batch, masks_batch, old_action_log_probs_batch, \
+                            adv_targ = sample
+
+                    # Reshape to do in a single forward pass for all steps
+                    values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(observations_batch),
+                                                                                                   Variable(states_batch),
+                                                                                                   Variable(masks_batch),
+                                                                                                   Variable(actions_batch))
+
+                    adv_targ = Variable(adv_targ)
+                    ratio = torch.exp(action_log_probs - Variable(old_action_log_probs_batch))
+                    surr1 = ratio * adv_targ
+                    surr2 = torch.clamp(ratio, 1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
+                    action_loss = -torch.min(surr1, surr2).mean() # PPO's pessimistic surrogate (L^CLIP)
+
+                    value_loss = (Variable(return_batch) - values).pow(2).mean()
+
+                    optimizer.zero_grad()
+                    (value_loss + action_loss - dist_entropy * args.entropy_coef).backward()
+                    nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)
+                    optimizer.step()
+
+        rollouts.after_update()
+
+        if j % args.save_interval == 0 and args.save_dir != "":
+            save_path = os.path.join(args.save_dir, args.algo)
+            try:
+                os.makedirs(save_path)
+            except OSError:
+                pass
+
+            # A really ugly way to save a model to CPU
+            save_model = actor_critic
+            if args.cuda:
+                save_model = copy.deepcopy(actor_critic).cpu()
+
+            save_model = [save_model,
+                            hasattr(envs, 'ob_rms') and envs.ob_rms or None]
+
+            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
+
+        if j % args.log_interval == 0:
+            end = time.time()
+            total_num_steps = (j + 1) * args.num_processes * args.num_steps
+            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
+                format(j, total_num_steps,
+                       int(total_num_steps / (end - start)),
+                       final_rewards.mean(),
+                       final_rewards.median(),
+                       final_rewards.min(),
+                       final_rewards.max(), dist_entropy.data[0],
+                       value_loss.data[0], action_loss.data[0]))
+        if args.vis and j % args.vis_interval == 0:
+            try:
+                # Sometimes monitor doesn't properly flush the outputs
+                win = visdom_plot(viz, win, args.log_dir, args.env_name, args.algo)
+            except IOError:
+                pass
+
+if __name__ == "__main__":
+    main()

+ 186 - 0
basicrl/model.py

@@ -0,0 +1,186 @@
+import operator
+from functools import reduce
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from distributions import Categorical, DiagGaussian
+from utils import orthogonal
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
+        orthogonal(m.weight.data)
+        if m.bias is not None:
+            m.bias.data.fill_(0)
+
+
+class FFPolicy(nn.Module):
+    def __init__(self):
+        super(FFPolicy, self).__init__()
+
+    def forward(self, inputs, states, masks):
+        raise NotImplementedError
+
+    def act(self, inputs, states, masks, deterministic=False):
+        value, x, states = self(inputs, states, masks)
+        action = self.dist.sample(x, deterministic=deterministic)
+        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
+        return value, action, action_log_probs, states
+
+    def evaluate_actions(self, inputs, states, masks, actions):
+        value, x, states = self(inputs, states, masks)
+        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
+        return value, action_log_probs, dist_entropy, states
+
+
+class CNNPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space, use_gru):
+        super(CNNPolicy, self).__init__()
+        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
+        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
+        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)
+
+        self.linear1 = nn.Linear(32 * 7 * 7, 512)
+
+        if use_gru:
+            self.gru = nn.GRUCell(512, 512)
+
+        self.critic_linear = nn.Linear(512, 1)
+
+        if action_space.__class__.__name__ == "Discrete":
+            num_outputs = action_space.n
+            self.dist = Categorical(512, num_outputs)
+        elif action_space.__class__.__name__ == "Box":
+            num_outputs = action_space.shape[0]
+            self.dist = DiagGaussian(512, num_outputs)
+        else:
+            raise NotImplementedError
+
+        self.train()
+        self.reset_parameters()
+
+    @property
+    def state_size(self):
+        if hasattr(self, 'gru'):
+            return 512
+        else:
+            return 1
+
+    def reset_parameters(self):
+        self.apply(weights_init)
+
+        relu_gain = nn.init.calculate_gain('relu')
+        self.conv1.weight.data.mul_(relu_gain)
+        self.conv2.weight.data.mul_(relu_gain)
+        self.conv3.weight.data.mul_(relu_gain)
+        self.linear1.weight.data.mul_(relu_gain)
+
+        if hasattr(self, 'gru'):
+            orthogonal(self.gru.weight_ih.data)
+            orthogonal(self.gru.weight_hh.data)
+            self.gru.bias_ih.data.fill_(0)
+            self.gru.bias_hh.data.fill_(0)
+
+        if self.dist.__class__.__name__ == "DiagGaussian":
+            self.dist.fc_mean.weight.data.mul_(0.01)
+
+    def forward(self, inputs, states, masks):
+        x = self.conv1(inputs / 255.0)
+        x = F.relu(x)
+
+        x = self.conv2(x)
+        x = F.relu(x)
+
+        x = self.conv3(x)
+        x = F.relu(x)
+
+        x = x.view(-1, 32 * 7 * 7)
+        x = self.linear1(x)
+        x = F.relu(x)
+
+        if hasattr(self, 'gru'):
+            if inputs.size(0) == states.size(0):
+                x = states = self.gru(x, states * masks)
+            else:
+                x = x.view(-1, states.size(0), x.size(1))
+                masks = masks.view(-1, states.size(0), 1)
+                outputs = []
+                for i in range(x.size(0)):
+                    hx = states = self.gru(x[i], states * masks[i])
+                    outputs.append(hx)
+                x = torch.cat(outputs, 0)
+        return self.critic_linear(x), x, states
+
+
+def weights_init_mlp(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        m.weight.data.normal_(0, 1)
+        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
+        if m.bias is not None:
+            m.bias.data.fill_(0)
+
+
+class MLPPolicy(FFPolicy):
+    def __init__(self, num_inputs, action_space):
+        super(MLPPolicy, self).__init__()
+
+        self.action_space = action_space
+
+        self.a_fc1 = nn.Linear(num_inputs, 64)
+        self.a_fc2 = nn.Linear(64, 64)
+
+        self.v_fc1 = nn.Linear(num_inputs, 64)
+        self.v_fc2 = nn.Linear(64, 64)
+        self.v_fc3 = nn.Linear(64, 1)
+
+        if action_space.__class__.__name__ == "Discrete":
+            num_outputs = action_space.n
+            self.dist = Categorical(64, num_outputs)
+        elif action_space.__class__.__name__ == "Box":
+            num_outputs = action_space.shape[0]
+            self.dist = DiagGaussian(64, num_outputs)
+        else:
+            raise NotImplementedError
+
+        self.train()
+        self.reset_parameters()
+
+    @property
+    def state_size(self):
+        return 1
+
+    def reset_parameters(self):
+        self.apply(weights_init_mlp)
+
+        """
+        tanh_gain = nn.init.calculate_gain('tanh')
+        self.a_fc1.weight.data.mul_(tanh_gain)
+        self.a_fc2.weight.data.mul_(tanh_gain)
+        self.v_fc1.weight.data.mul_(tanh_gain)
+        self.v_fc2.weight.data.mul_(tanh_gain)
+        """
+
+        if self.dist.__class__.__name__ == "DiagGaussian":
+            self.dist.fc_mean.weight.data.mul_(0.01)
+
+    def forward(self, inputs, states, masks):
+        batch_numel = reduce(operator.mul, inputs.size()[1:], 1)
+        inputs = inputs.view(-1, batch_numel)
+
+        x = self.v_fc1(inputs)
+        x = F.tanh(x)
+
+        x = self.v_fc2(x)
+        x = F.tanh(x)
+
+        x = self.v_fc3(x)
+        value = x
+
+        x = self.a_fc1(inputs)
+        x = F.tanh(x)
+
+        x = self.a_fc2(x)
+        x = F.tanh(x)
+
+        return value, x, states
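
The hard-coded `32 * 7 * 7` in `CNNPolicy` assumes 84x84 inputs, the size produced by the DeepMind-style Atari preprocessing (`wrap_deepmind` in `envs.py`). A minimal sketch (not part of the commit) of the spatial-size arithmetic behind that constant:

```python
def conv_out(size, kernel, stride):
    # Output size of an unpadded convolution: floor((size - kernel) / stride) + 1
    return (size - kernel) // stride + 1

size = 84
size = conv_out(size, 8, 4)  # conv1 -> 20
size = conv_out(size, 4, 2)  # conv2 -> 9
size = conv_out(size, 3, 1)  # conv3 -> 7
print(size, 32 * size * size)  # 7 1568, i.e. the flattened 32 * 7 * 7 feature size
```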

+ 3 - 0
basicrl/requirements.txt

@@ -0,0 +1,3 @@
+gym
+matplotlib
+pybullet

+ 116 - 0
basicrl/storage.py

@@ -0,0 +1,116 @@
+import torch
+from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
+
+
+class RolloutStorage(object):
+    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
+        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
+        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
+        self.rewards = torch.zeros(num_steps, num_processes, 1)
+        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
+        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
+        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
+        if action_space.__class__.__name__ == 'Discrete':
+            action_shape = 1
+        else:
+            action_shape = action_space.shape[0]
+        self.actions = torch.zeros(num_steps, num_processes, action_shape)
+        if action_space.__class__.__name__ == 'Discrete':
+            self.actions = self.actions.long()
+        self.masks = torch.ones(num_steps + 1, num_processes, 1)
+
+    def cuda(self):
+        self.observations = self.observations.cuda()
+        self.states = self.states.cuda()
+        self.rewards = self.rewards.cuda()
+        self.value_preds = self.value_preds.cuda()
+        self.returns = self.returns.cuda()
+        self.action_log_probs = self.action_log_probs.cuda()
+        self.actions = self.actions.cuda()
+        self.masks = self.masks.cuda()
+
+    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
+        self.observations[step + 1].copy_(current_obs)
+        self.states[step + 1].copy_(state)
+        self.actions[step].copy_(action)
+        self.action_log_probs[step].copy_(action_log_prob)
+        self.value_preds[step].copy_(value_pred)
+        self.rewards[step].copy_(reward)
+        self.masks[step + 1].copy_(mask)
+
+    def after_update(self):
+        self.observations[0].copy_(self.observations[-1])
+        self.states[0].copy_(self.states[-1])
+        self.masks[0].copy_(self.masks[-1])
+
+    def compute_returns(self, next_value, use_gae, gamma, tau):
+        if use_gae:
+            self.value_preds[-1] = next_value
+            gae = 0
+            for step in reversed(range(self.rewards.size(0))):
+                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
+                gae = delta + gamma * tau * self.masks[step + 1] * gae
+                self.returns[step] = gae + self.value_preds[step]
+        else:
+            self.returns[-1] = next_value
+            for step in reversed(range(self.rewards.size(0))):
+                self.returns[step] = self.returns[step + 1] * \
+                    gamma * self.masks[step + 1] + self.rewards[step]
+
+
+    def feed_forward_generator(self, advantages, num_mini_batch):
+        num_steps, num_processes = self.rewards.size()[0:2]
+        batch_size = num_processes * num_steps
+        mini_batch_size = batch_size // num_mini_batch
+        sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)
+        for indices in sampler:
+            indices = torch.LongTensor(indices)
+
+            if advantages.is_cuda:
+                indices = indices.cuda()
+
+            observations_batch = self.observations[:-1].view(-1,
+                                        *self.observations.size()[2:])[indices]
+            states_batch = self.states[:-1].view(-1, self.states.size(-1))[indices]
+            actions_batch = self.actions.view(-1, self.actions.size(-1))[indices]
+            return_batch = self.returns[:-1].view(-1, 1)[indices]
+            masks_batch = self.masks[:-1].view(-1, 1)[indices]
+            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
+            adv_targ = advantages.view(-1, 1)[indices]
+
+            yield observations_batch, states_batch, actions_batch, \
+                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
+
+    def recurrent_generator(self, advantages, num_mini_batch):
+        num_processes = self.rewards.size(1)
+        num_envs_per_batch = num_processes // num_mini_batch
+        perm = torch.randperm(num_processes)
+        for start_ind in range(0, num_processes, num_envs_per_batch):
+            observations_batch = []
+            states_batch = []
+            actions_batch = []
+            return_batch = []
+            masks_batch = []
+            old_action_log_probs_batch = []
+            adv_targ = []
+
+            for offset in range(num_envs_per_batch):
+                ind = perm[start_ind + offset]
+                observations_batch.append(self.observations[:-1, ind])
+                states_batch.append(self.states[0:1, ind])
+                actions_batch.append(self.actions[:, ind])
+                return_batch.append(self.returns[:-1, ind])
+                masks_batch.append(self.masks[:-1, ind])
+                old_action_log_probs_batch.append(self.action_log_probs[:, ind])
+                adv_targ.append(advantages[:, ind])
+
+            observations_batch = torch.cat(observations_batch, 0)
+            states_batch = torch.cat(states_batch, 0)
+            actions_batch = torch.cat(actions_batch, 0)
+            return_batch = torch.cat(return_batch, 0)
+            masks_batch = torch.cat(masks_batch, 0)
+            old_action_log_probs_batch = torch.cat(old_action_log_probs_batch, 0)
+            adv_targ = torch.cat(adv_targ, 0)
+
+            yield observations_batch, states_batch, actions_batch, \
+                return_batch, masks_batch, old_action_log_probs_batch, adv_targ
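
`compute_returns` with `use_gae=True` implements the generalized advantage estimator as a backward recursion. A minimal sketch (not part of the commit) of the same recursion on plain Python lists, with made-up rewards, value predictions and masks:

```python
gamma, tau = 0.99, 0.95
rewards = [1.0, 0.0, 0.5]       # r_0 .. r_{T-1}
values  = [0.2, 0.3, 0.4, 0.1]  # V(s_0) .. V(s_T); the last entry is the bootstrap value
masks   = [1.0, 1.0, 1.0, 1.0]  # 0.0 wherever an episode ended

returns = [0.0] * len(rewards)
gae = 0.0
for step in reversed(range(len(rewards))):
    delta = rewards[step] + gamma * values[step + 1] * masks[step + 1] - values[step]
    gae = delta + gamma * tau * masks[step + 1] * gae
    returns[step] = gae + values[step]

print(returns)
```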

+ 45 - 0
basicrl/utils.py

@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+
+
+# Necessary for my KFAC implementation.
+class AddBias(nn.Module):
+    def __init__(self, bias):
+        super(AddBias, self).__init__()
+        self._bias = nn.Parameter(bias.unsqueeze(1))
+
+    def forward(self, x):
+        if x.dim() == 2:
+            bias = self._bias.t().view(1, -1)
+        else:
+            bias = self._bias.t().view(1, -1, 1, 1)
+
+        return x + bias
+
+# A temporary solution from the master branch.
+# https://github.com/pytorch/pytorch/blob/7752fe5d4e50052b3b0bbc9109e599f8157febc0/torch/nn/init.py#L312
+# Remove after the next version of PyTorch gets released.
+def orthogonal(tensor, gain=1):
+    if tensor.ndimension() < 2:
+        raise ValueError("Only tensors with 2 or more dimensions are supported")
+
+    rows = tensor.size(0)
+    cols = tensor[0].numel()
+    flattened = torch.Tensor(rows, cols).normal_(0, 1)
+
+    if rows < cols:
+        flattened.t_()
+
+    # Compute the qr factorization
+    q, r = torch.qr(flattened)
+    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
+    d = torch.diag(r, 0)
+    ph = d.sign()
+    q *= ph.expand_as(q)
+
+    if rows < cols:
+        q.t_()
+
+    tensor.view_as(q).copy_(q)
+    tensor.mul_(gain)
+    return tensor
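
A minimal sketch (not part of the commit) of how the backported `orthogonal` initializer is applied, mirroring `weights_init` in `model.py`; the 32x32 layer size is arbitrary:

```python
import torch
import torch.nn as nn
from utils import orthogonal  # the backported initializer above

layer = nn.Linear(32, 32)
orthogonal(layer.weight.data, gain=1)
layer.bias.data.fill_(0)

# For a square orthogonal matrix, W @ W.t() is (numerically) the identity
w = layer.weight.data
print((w @ w.t() - torch.eye(32)).abs().max())
```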

+ 142 - 0
basicrl/visualize.py

@@ -0,0 +1,142 @@
+# Copied from https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/visualize_atari.py
+# and https://github.com/emansim/baselines-mansimov/blob/master/baselines/a2c/load.py
+# Thanks to the author and OpenAI team!
+
+import glob
+import json
+import os
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.signal import medfilt
+matplotlib.rcParams.update({'font.size': 8})
+
+
+def smooth_reward_curve(x, y):
+    # Halfwidth of our smoothing convolution
+    halfwidth = min(31, int(np.ceil(len(x) / 30)))
+    k = halfwidth
+    xsmoo = x[k:-k]
+    ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='valid') / \
+        np.convolve(np.ones_like(y), np.ones(2 * k + 1), mode='valid')
+    downsample = max(int(np.floor(len(xsmoo) / 1e3)), 1)
+    return xsmoo[::downsample], ysmoo[::downsample]
+
+
+def fix_point(x, y, interval):
+    x = np.insert(x, 0, 0)
+    y = np.insert(y, 0, 0)
+
+    fx, fy = [], []
+    pointer = 0
+
+    ninterval = int(max(x) / interval + 1)
+
+    for i in range(ninterval):
+        tmpx = interval * i
+
+        while pointer + 1 < len(x) and tmpx > x[pointer + 1]:
+            pointer += 1
+
+        if pointer + 1 < len(x):
+            alpha = (y[pointer + 1] - y[pointer]) / \
+                (x[pointer + 1] - x[pointer])
+            tmpy = y[pointer] + alpha * (tmpx - x[pointer])
+            fx.append(tmpx)
+            fy.append(tmpy)
+
+    return fx, fy
+
+
+def load_data(indir, smooth, bin_size):
+    datas = []
+    infiles = glob.glob(os.path.join(indir, '*.monitor.csv'))
+
+    for inf in infiles:
+        with open(inf, 'r') as f:
+            f.readline()
+            f.readline()
+            for line in f:
+                tmp = line.split(',')
+                t_time = float(tmp[2])
+                tmp = [t_time, int(tmp[1]), float(tmp[0])]
+                datas.append(tmp)
+
+    datas = sorted(datas, key=lambda d_entry: d_entry[0])
+    result = []
+    timesteps = 0
+    for i in range(len(datas)):
+        result.append([timesteps, datas[i][-1]])
+        timesteps += datas[i][1]
+
+    if len(result) < bin_size:
+        return [None, None]
+
+    x, y = np.array(result)[:, 0], np.array(result)[:, 1]
+
+    if smooth == 1:
+        x, y = smooth_reward_curve(x, y)
+
+    if smooth == 2:
+        y = medfilt(y, kernel_size=9)
+
+    x, y = fix_point(x, y, bin_size)
+    return [x, y]
+
+
+color_defaults = [
+    '#1f77b4',  # muted blue
+    '#ff7f0e',  # safety orange
+    '#2ca02c',  # cooked asparagus green
+    '#d62728',  # brick red
+    '#9467bd',  # muted purple
+    '#8c564b',  # chestnut brown
+    '#e377c2',  # raspberry yogurt pink
+    '#7f7f7f',  # middle gray
+    '#bcbd22',  # curry yellow-green
+    '#17becf'  # blue-teal
+]
+
+
+def visdom_plot(viz, win, folder, game, name, bin_size=100, smooth=1):
+    tx, ty = load_data(folder, smooth, bin_size)
+    if tx is None or ty is None:
+        return win
+
+    fig = plt.figure()
+    plt.plot(tx, ty, label="{}".format(name))
+
+    # Ugly hack to detect atari
+    if game.find('NoFrameskip') > -1:
+        plt.xticks([1e6, 2e6, 4e6, 6e6, 8e6, 10e6],
+                   ["1M", "2M", "4M", "6M", "8M", "10M"])
+        plt.xlim(0, 10e6)
+    else:
+        plt.xticks([1e5, 2e5, 4e5, 6e5, 8e5, 1e6],
+                   ["0.1M", "0.2M", "0.4M", "0.6M", "0.8M", "1M"])
+        plt.xlim(0, 1e6)
+
+    plt.xlabel('Number of Timesteps')
+    plt.ylabel('Rewards')
+
+
+    plt.title(game)
+    plt.legend(loc=4)
+    plt.show()
+    plt.draw()
+
+    image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    plt.close(fig)
+
+    # Show it in visdom
+    image = np.transpose(image, (2, 0, 1))
+    return viz.image(image, win=win)
+
+
+if __name__ == "__main__":
+    from visdom import Visdom
+    viz = Visdom()
+    visdom_plot(viz, None, '/tmp/gym/', 'BreakOut', 'a2c', bin_size=100, smooth=1)

+ 1 - 0
gym_minigrid/__init__.py

@@ -0,0 +1 @@
+import gym_minigrid.envs.simple_envs

+ 2 - 0
gym_minigrid/envs/__init__.py

@@ -0,0 +1,2 @@
+from gym_minigrid.envs.minigrid_env import MiniGridEnv
+from gym_minigrid.envs.simple_envs import *

+ 780 - 0
gym_minigrid/envs/minigrid_env.py

@@ -0,0 +1,780 @@
+import math
+import gym
+from gym import error, spaces, utils
+from gym.utils import seeding
+import numpy as np
+from gym_minigrid.envs.rendering import *
+
+# Size in pixels of a cell in the full-scale human view
+CELL_PIXELS = 32
+
+# Number of cells (width and height) in the agent view
+AGENT_VIEW_SIZE = 7
+
+# Size of the array given as an observation to the agent
+OBS_ARRAY_SIZE = (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
+
+COLORS = {
+    'red'   : (255, 0, 0),
+    'green' : (0, 255, 0),
+    'blue'  : (0, 0, 255),
+    'purple': (112, 39, 195),
+    'yellow': (255, 255, 0),
+    'grey'  : (100, 100, 100)
+}
+
+# Used to map colors to integers
+COLOR_TO_IDX = {
+    'red'   : 0,
+    'green' : 1,
+    'blue'  : 2,
+    'purple': 3,
+    'yellow': 4,
+    'grey'  : 5
+}
+
+IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
+
+# Map of object type to integers
+OBJECT_TO_IDX = {
+    'empty'         : 0,
+    'wall'          : 1,
+    'door'          : 2,
+    'locked_door'   : 3,
+    'ball'          : 4,
+    'key'           : 5,
+    'goal'          : 6
+}
+
+IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
+
+class WorldObj:
+    """
+    Base class for grid world objects
+    """
+
+    def __init__(self, type, color):
+        assert type in OBJECT_TO_IDX, type
+        assert color in COLOR_TO_IDX, color
+        self.type = type
+        self.color = color
+        self.contains = None
+
+    def canOverlap(self):
+        """Can the agent overlap with this?"""
+        return False
+
+    def canPickup(self):
+        """Can the agent pick this up?"""
+        return False
+
+    def canContain(self):
+        """Can this contain another object?"""
+        return False
+
+    def toggle(self, env):
+        """Method to trigger/toggle an action this object performs"""
+        return False
+
+    def render(self, r):
+        assert False
+
+    def _setColor(self, r):
+        c = COLORS[self.color]
+        r.setLineColor(c[0], c[1], c[2])
+        r.setColor(c[0], c[1], c[2])
+
+class Goal(WorldObj):
+    def __init__(self):
+        super(Goal, self).__init__('goal', 'green')
+
+    def render(self, r):
+        self._setColor(r)
+        r.drawPolygon([
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS,           0),
+            (0          ,           0)
+        ])
+
+class Wall(WorldObj):
+    def __init__(self):
+        super(Wall, self).__init__('wall', 'grey')
+
+    def render(self, r):
+        self._setColor(r)
+        r.drawPolygon([
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS,           0),
+            (0          ,           0)
+        ])
+
+class Door(WorldObj):
+    def __init__(self, color, isOpen=False):
+        super(Door, self).__init__('door', color)
+        self.isOpen = isOpen
+
+    def render(self, r):
+        c = COLORS[self.color]
+        r.setLineColor(c[0], c[1], c[2])
+        r.setColor(0, 0, 0)
+
+        if self.isOpen:
+            r.drawPolygon([
+                (CELL_PIXELS-2, CELL_PIXELS),
+                (CELL_PIXELS  , CELL_PIXELS),
+                (CELL_PIXELS  ,           0),
+                (CELL_PIXELS-2,           0)
+            ])
+            return
+
+        r.drawPolygon([
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS,           0),
+            (0          ,           0)
+        ])
+        r.drawPolygon([
+            (2          , CELL_PIXELS-2),
+            (CELL_PIXELS-2, CELL_PIXELS-2),
+            (CELL_PIXELS-2,           2),
+            (2          ,           2)
+        ])
+        r.drawCircle(CELL_PIXELS * 0.75, CELL_PIXELS * 0.5, 2)
+
+    def toggle(self, env):
+        if not self.isOpen:
+            self.isOpen = True
+            return True
+        return False
+
+    def canOverlap(self):
+        """The agent can only walk over this cell when the door is open"""
+        return self.isOpen
+
+class LockedDoor(WorldObj):
+    def __init__(self, color, isOpen=False):
+        super(LockedDoor, self).__init__('locked_door', color)
+        self.isOpen = isOpen
+
+    def render(self, r):
+        c = COLORS[self.color]
+        r.setLineColor(c[0], c[1], c[2])
+        r.setColor(0, 0, 0)
+
+        if self.isOpen:
+            r.drawPolygon([
+                (CELL_PIXELS-2, CELL_PIXELS),
+                (CELL_PIXELS  , CELL_PIXELS),
+                (CELL_PIXELS  ,           0),
+                (CELL_PIXELS-2,           0)
+            ])
+            return
+
+        r.drawPolygon([
+            (0          , CELL_PIXELS),
+            (CELL_PIXELS, CELL_PIXELS),
+            (CELL_PIXELS,           0),
+            (0          ,           0)
+        ])
+        r.drawPolygon([
+            (2          , CELL_PIXELS-2),
+            (CELL_PIXELS-2, CELL_PIXELS-2),
+            (CELL_PIXELS-2,           2),
+            (2          ,           2)
+        ])
+        r.drawLine(
+            CELL_PIXELS * 0.75,
+            CELL_PIXELS * 0.45,
+            CELL_PIXELS * 0.75,
+            CELL_PIXELS * 0.60
+        )
+
+    def toggle(self, env):
+        # If the player has the right key to open the door
+        if isinstance(env.carrying, Key) and env.carrying.color == self.color:
+            self.isOpen = True
+            # The key has been used, remove it from the agent
+            env.carrying = None
+            return True
+        return False
+
+    def canOverlap(self):
+        """The agent can only walk over this cell when the door is open"""
+        return self.isOpen
+
+class Ball(WorldObj):
+    def __init__(self, color='blue'):
+        super(Ball, self).__init__('ball', color)
+
+    def canPickup(self):
+        return True
+
+    def render(self, r):
+        self._setColor(r)
+        r.drawCircle(CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, 10)
+
+class Key(WorldObj):
+    def __init__(self, color='blue'):
+        super(Key, self).__init__('key', color)
+
+    def canPickup(self):
+        return True
+
+    def render(self, r):
+        self._setColor(r)
+
+        # Vertical quad
+        r.drawPolygon([
+            (16, 10),
+            (20, 10),
+            (20, 28),
+            (16, 28)
+        ])
+
+        # Teeth
+        r.drawPolygon([
+            (12, 19),
+            (16, 19),
+            (16, 21),
+            (12, 21)
+        ])
+        r.drawPolygon([
+            (12, 26),
+            (16, 26),
+            (16, 28),
+            (12, 28)
+        ])
+
+        r.drawCircle(18, 9, 6)
+        r.setLineColor(0, 0, 0)
+        r.setColor(0, 0, 0)
+        r.drawCircle(18, 9, 2)
+
+class Grid:
+    """
+    Represent a grid and operations on it
+    """
+
+    def __init__(self, width, height):
+        assert width >= 4
+        assert height >= 4
+
+        self.width = width
+        self.height = height
+
+        self.grid = [None] * width * height
+
+    def copy(self):
+        from copy import deepcopy
+        return deepcopy(self)
+
+    def set(self, i, j, v):
+        assert i >= 0 and i < self.width
+        assert j >= 0 and j < self.height
+        self.grid[j * self.width + i] = v
+
+    def get(self, i, j):
+        assert i >= 0 and i < self.width
+        assert j >= 0 and j < self.height
+        return self.grid[j * self.width + i]
+
+    def rotateLeft(self):
+        """
+        Rotate the grid to the left (counter-clockwise)
+        """
+
+        grid = Grid(self.width, self.height)
+
+        for j in range(0, self.height):
+            for i in range(0, self.width):
+                v = self.get(self.width - 1 - j, i)
+                grid.set(i, j, v)
+
+        return grid
+
+    def slice(self, topX, topY, width, height):
+        """
+        Get a subset of the grid
+        """
+
+        grid = Grid(width, height)
+
+        for j in range(0, height):
+            for i in range(0, width):
+                x = topX + i
+                y = topY + j
+
+                if x >= 0 and x < self.width and \
+                   y >= 0 and y < self.height:
+                    v = self.get(x, y)
+                else:
+                    v = Wall()
+
+                grid.set(i, j, v)
+
+        return grid
+
+    def render(self, r, tileSize):
+        """
+        Render this grid at a given scale
+        :param r: target renderer object
+        :param tileSize: tile size in pixels
+        """
+
+        assert r.width == self.width * tileSize
+        assert r.height == self.height * tileSize
+
+        # Total grid size at native scale
+        widthPx = self.width * CELL_PIXELS
+        heightPx = self.height * CELL_PIXELS
+
+        # Draw background (out-of-world) tiles the same color as walls
+        # so the agent understands these areas are not reachable
+        c = COLORS['grey']
+        r.setLineColor(c[0], c[1], c[2])
+        r.setColor(c[0], c[1], c[2])
+        r.drawPolygon([
+            (0    , heightPx),
+            (widthPx, heightPx),
+            (widthPx,      0),
+            (0    ,      0)
+        ])
+
+        r.push()
+
+        # Internally, we draw at the "large" full-grid resolution, but we
+        # use the renderer to scale back to the desired size
+        r.scale(tileSize / CELL_PIXELS, tileSize / CELL_PIXELS)
+
+        # Draw the background of the in-world cells black
+        r.fillRect(
+            0,
+            0,
+            widthPx,
+            heightPx,
+            0, 0, 0
+        )
+
+        # Draw grid lines
+        r.setLineColor(100, 100, 100)
+        for rowIdx in range(0, self.height):
+            y = CELL_PIXELS * rowIdx
+            r.drawLine(0, y, widthPx, y)
+        for colIdx in range(0, self.width):
+            x = CELL_PIXELS * colIdx
+            r.drawLine(x, 0, x, heightPx)
+
+        # Render the grid
+        for j in range(0, self.height):
+            for i in range(0, self.width):
+                cell = self.get(i, j)
+                if cell == None:
+                    continue
+                r.push()
+                r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
+                cell.render(r)
+                r.pop()
+
+        r.pop()
+
+    def encode(self):
+        """
+        Produce a compact numpy encoding of the grid
+        """
+
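+        # Channels: [0] object type index, [1] color index,
+        # [2] open/closed flag for doors (1 = open)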
+        codeSize = self.width * self.height * 3
+
+        array = np.zeros(shape=(self.width, self.height, 3), dtype='uint8')
+
+        for j in range(0, self.height):
+            for i in range(0, self.width):
+
+                v = self.get(i, j)
+
+                if v == None:
+                    continue
+
+                array[i, j, 0] = OBJECT_TO_IDX[v.type]
+                array[i, j, 1] = COLOR_TO_IDX[v.color]
+
+                if hasattr(v, 'isOpen') and v.isOpen:
+                    array[i, j, 2] = 1
+
+        return array
+
+    def decode(array):
+        """
+        Decode an array grid encoding back into a grid
+        """
+
+        width = array.shape[0]
+        height = array.shape[1]
+        assert array.shape[2] == 3
+
+        grid = Grid(width, height)
+
+        for j in range(0, height):
+            for i in range(0, width):
+
+                typeIdx  = array[i, j, 0]
+                colorIdx = array[i, j, 1]
+                openIdx  = array[i, j, 2]
+
+                if typeIdx == 0:
+                    continue
+
+                objType = IDX_TO_OBJECT[typeIdx]
+                color = IDX_TO_COLOR[colorIdx]
+                isOpen = (openIdx == 1)
+
+                if objType == 'wall':
+                    v = Wall()
+                elif objType == 'ball':
+                    v = Ball(color)
+                elif objType == 'key':
+                    v = Key(color)
+                elif objType == 'door':
+                    v = Door(color, isOpen)
+                elif objType == 'locked_door':
+                    v = LockedDoor(color, isOpen)
+                elif objType == 'goal':
+                    v = Goal()
+                else:
+                    assert False, "unknown obj type in decode '%s'" % objType
+
+                grid.set(i, j, v)
+
+        return grid
+
+class MiniGridEnv(gym.Env):
+    """
+    2D grid world game environment
+    """
+
+    metadata = {
+        'render.modes': ['human', 'rgb_array', 'pixmap'],
+        'video.frames_per_second' : 10
+    }
+
+    # Possible actions
+    NUM_ACTIONS = 4
+    ACTION_LEFT = 0
+    ACTION_RIGHT = 1
+    ACTION_FORWARD = 2
+    ACTION_TOGGLE = 3
+
+    def __init__(self, gridSize=16, maxSteps=100):
+        # Renderer object used to render the whole grid (full-scale)
+        self.gridRender = None
+
+        # Renderer used to render observations (small-scale agent view)
+        self.obsRender = None
+
+        # Actions are discrete integer values
+        self.action_space = spaces.Discrete(MiniGridEnv.NUM_ACTIONS)
+
+        # The observations are RGB images
+        self.observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=OBS_ARRAY_SIZE
+        )
+
+        self.reward_range = (-1, 1000)
+
+        # Environment configuration
+        self.gridSize = gridSize
+        self.maxSteps = maxSteps
+        self.startPos = (1, 1)
+        self.startDir = 0
+
+        # Initialize the state
+        self.seed()
+        self.reset()
+
+    def _genGrid(self, width, height):
+        """
+        Generate a new grid
+        """
+
+        # Initialize the grid
+        grid = Grid(width, height)
+
+        # Place walls around the edges
+        for i in range(0, width):
+            grid.set(i, 0, Wall())
+            grid.set(i, height - 1, Wall())
+        for j in range(0, height):
+            grid.set(0, j, Wall())
+            grid.set(width - 1, j, Wall())
+
+        # Place a goal in the bottom-right corner
+        grid.set(width - 2, height - 2, Goal())
+
+        return grid
+
+    def _reset(self):
+        # Place the agent in the starting position and direction
+        self.agentPos = self.startPos
+        self.agentDir = self.startDir
+
+        # Item picked up, being carried, initially nothing
+        self.carrying = None
+
+        # Step count since episode start
+        self.stepCount = 0
+
+        # Restore the initial grid
+        self.grid = self.seedGrid.copy()
+
+        # Return first observation
+        obs = self._genObs()
+        return obs
+
+    def _seed(self, seed=None):
+        """
+        The seed function sets the random elements of the environment,
+        and initializes the world.
+        """
+
+        # By default, make things deterministic, always
+        # produce the same environment
+        if seed == None:
+            seed = 1337
+
+        # Seed the random number generator
+        self.np_random, _ = seeding.np_random(seed)
+
+        self.grid = self._genGrid(self.gridSize, self.gridSize)
+
+        # Store a copy of the grid so we can restore it on reset
+        self.seedGrid = self.grid.copy()
+
+        return [seed]
+
+    def _randInt(self, low, high):
+        return self.np_random.randint(low, high)
+
+    def _randElem(self, iterable):
+        lst = list(iterable)
+        idx = self._randInt(0, len(lst))
+        return lst[idx]
+
+    def getStepsRemaining(self):
+        return self.maxSteps - self.stepCount
+
+    def getDirVec(self):
+        """
+        Get the direction vector for the agent, pointing in the direction
+        of forward movement.
+        """
+
+        # Pointing right
+        if self.agentDir == 0:
+            return (1, 0)
+        # Down (positive Y)
+        elif self.agentDir == 1:
+            return (0, 1)
+        # Pointing left
+        elif self.agentDir == 2:
+            return (-1, 0)
+        # Up (negative Y)
+        elif self.agentDir == 3:
+            return (0, -1)
+        else:
+            assert False
+
+    def getViewExts(self):
+        """
+        Get the extents of the square set of tiles visible to the agent
+        Note: the bottom extent indices are not included in the set
+        """
+
+        # Facing right
+        if self.agentDir == 0:
+            topX = self.agentPos[0]
+            topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
+        # Facing down
+        elif self.agentDir == 1:
+            topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
+            topY = self.agentPos[1]
+        # Facing left
+        elif self.agentDir == 2:
+            topX = self.agentPos[0] - AGENT_VIEW_SIZE + 1
+            topY = self.agentPos[1] - AGENT_VIEW_SIZE // 2
+        # Facing up
+        elif self.agentDir == 3:
+            topX = self.agentPos[0] - AGENT_VIEW_SIZE // 2
+            topY = self.agentPos[1] - AGENT_VIEW_SIZE + 1
+        else:
+            assert False
+
+        botX = topX + AGENT_VIEW_SIZE
+        botY = topY + AGENT_VIEW_SIZE
+
+        return (topX, topY, botX, botY)
+
+    def _step(self, action):
+        self.stepCount += 1
+
+        reward = 0
+        done = False
+
+        # Rotate left
+        if action == MiniGridEnv.ACTION_LEFT:
+            self.agentDir -= 1
+            if self.agentDir < 0:
+                self.agentDir += 4
+
+        # Rotate right
+        elif action == MiniGridEnv.ACTION_RIGHT:
+            self.agentDir = (self.agentDir + 1) % 4
+
+        # Move forward
+        elif action == MiniGridEnv.ACTION_FORWARD:
+            u, v = self.getDirVec()
+            newPos = (self.agentPos[0] + u, self.agentPos[1] + v)
+            targetCell = self.grid.get(newPos[0], newPos[1])
+            if targetCell == None or targetCell.canOverlap():
+                self.agentPos = newPos
+            elif targetCell.type == 'goal':
+                done = True
+                reward = 1000 - self.stepCount
+
+        # Pick up or trigger/activate an item
+        elif action == MiniGridEnv.ACTION_TOGGLE:
+            u, v = self.getDirVec()
+            cell = self.grid.get(self.agentPos[0] + u, self.agentPos[1] + v)
+            if cell and cell.canPickup() and self.carrying is None:
+                self.carrying = cell
+                self.grid.set(self.agentPos[0] + u, self.agentPos[1] + v, None)
+            elif cell:
+                cell.toggle(self)
+
+        else:
+            assert False, "unknown action"
+
+        if self.stepCount >= self.maxSteps:
+            done = True
+
+        obs = self._genObs()
+
+        return obs, reward, done, {}
+
+    def _genObs(self):
+        """
+        Generate the agent's view (partially observable, low-resolution encoding)
+        """
+
+        topX, topY, botX, botY = self.getViewExts()
+
+        grid = self.grid.slice(topX, topY, AGENT_VIEW_SIZE, AGENT_VIEW_SIZE)
+
+        for i in range(self.agentDir + 1):
+            grid = grid.rotateLeft()
+
+        obs = grid.encode()
+
+        return obs
+
+    def getObsRender(self, obs):
+        """
+        Render an agent observation for visualization
+        """
+
+        if self.obsRender == None:
+            self.obsRender = Renderer(
+                AGENT_VIEW_SIZE * CELL_PIXELS // 2,
+                AGENT_VIEW_SIZE * CELL_PIXELS // 2
+            )
+
+        r = self.obsRender
+
+        r.beginFrame()
+
+        grid = Grid.decode(obs)
+
+        # Render the whole grid
+        grid.render(r, CELL_PIXELS // 2)
+
+        # Draw the agent
+        r.push()
+        r.scale(0.5, 0.5)
+        r.translate(
+            CELL_PIXELS * (0.5 + AGENT_VIEW_SIZE // 2),
+            CELL_PIXELS * (AGENT_VIEW_SIZE - 0.5)
+        )
+        r.rotate(3 * 90)
+        r.setLineColor(255, 0, 0)
+        r.setColor(255, 0, 0)
+        r.drawPolygon([
+            (-12, 10),
+            ( 12,  0),
+            (-12, -10)
+        ])
+        r.pop()
+
+        r.endFrame()
+
+        return r.getPixmap()
+
+    def _render(self, mode='human', close=False):
+        """
+        Render the whole-grid human view
+        """
+
+        if close:
+            if self.gridRender:
+                self.gridRender.close()
+            return
+
+        if self.gridRender is None:
+            self.gridRender = Renderer(
+                self.gridSize * CELL_PIXELS,
+                self.gridSize * CELL_PIXELS,
+                True if mode == 'human' else False
+            )
+
+        r = self.gridRender
+
+        r.beginFrame()
+
+        # Render the whole grid
+        self.grid.render(r, CELL_PIXELS)
+
+        # Draw the agent
+        r.push()
+        r.translate(
+            CELL_PIXELS * (self.agentPos[0] + 0.5),
+            CELL_PIXELS * (self.agentPos[1] + 0.5)
+        )
+        r.rotate(self.agentDir * 90)
+        r.setLineColor(255, 0, 0)
+        r.setColor(255, 0, 0)
+        r.drawPolygon([
+            (-12, 10),
+            ( 12,  0),
+            (-12, -10)
+        ])
+        r.pop()
+
+        # Highlight what the agent can see
+        topX, topY, botX, botY = self.getViewExts()
+        r.fillRect(
+            topX * CELL_PIXELS,
+            topY * CELL_PIXELS,
+            AGENT_VIEW_SIZE * CELL_PIXELS,
+            AGENT_VIEW_SIZE * CELL_PIXELS,
+            200, 200, 200, 75
+        )
+
+        r.endFrame()
+
+        if mode == 'rgb_array':
+            return r.getArray()
+        elif mode == 'pixmap':
+            return r.getPixmap()
+
+        return r
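+
+# Minimal interaction sketch (editor's addition, not part of the original
+# commit): instantiate the base environment directly, take a few random
+# actions, and decode the compact observation back into a Grid. Only names
+# defined above are used; the grid size and step budget are arbitrary.
+if __name__ == "__main__":
+    env = MiniGridEnv(gridSize=8, maxSteps=50)
+    obs = env.reset()
+    print('observation shape:', obs.shape)   # (AGENT_VIEW_SIZE, AGENT_VIEW_SIZE, 3)
+
+    for _ in range(5):
+        obs, reward, done, info = env.step(env.action_space.sample())
+        if done:
+            obs = env.reset()
+
+    # Round-trip the compact encoding produced by Grid.encode()
+    view = Grid.decode(obs)
+    print('agent view is %dx%d cells' % (view.width, view.height))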

+ 135 - 0
gym_minigrid/envs/rendering.py

@@ -0,0 +1,135 @@
+import numpy as np
+from PyQt5.QtCore import Qt
+from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPolygon
+from PyQt5.QtCore import QPoint, QSize, QRect
+from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget
+from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QLabel, QFrame
+
+class Window(QMainWindow):
+    """
+    Simple application window to render the environment into
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.setWindowTitle('MiniGrid Gym Environment')
+
+        self.imgLabel = QLabel()
+        self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken)
+
+        # Arrange widgets horizontally
+        hbox = QHBoxLayout()
+        hbox.addStretch(1)
+        hbox.addWidget(self.imgLabel)
+        hbox.addStretch(1)
+
+        # Create a main widget for the window
+        mainWidget = QWidget(self)
+        self.setCentralWidget(mainWidget)
+        mainWidget.setLayout(hbox)
+
+        # Show the application window
+        self.show()
+        self.setFocus()
+
+        self.closed = False
+
+    def closeEvent(self, event):
+        self.closed = True
+
+    def setPixmap(self, pixmap):
+        self.imgLabel.setPixmap(pixmap)
+
+class Renderer:
+    def __init__(self, width, height, ownWindow=False):
+        self.width = width
+        self.height = height
+
+        self.img = QImage(width, height, QImage.Format_RGB888)
+        self.painter = QPainter()
+
+        self.window = None
+        if ownWindow:
+            self.app = QApplication([])
+            self.window = Window()
+
+    def close(self):
+        """
+        Deallocate resources used
+        """
+        pass
+
+    def beginFrame(self):
+        self.painter.begin(self.img)
+        self.painter.setRenderHint(QPainter.Antialiasing, False)
+
+        # Clear the background
+        self.painter.setBrush(QColor(0, 0, 0))
+        self.painter.drawRect(0, 0, self.width - 1, self.height - 1)
+
+    def endFrame(self):
+        self.painter.end()
+
+        if self.window:
+            if self.window.closed:
+                self.window = None
+            else:
+                self.window.setPixmap(self.getPixmap())
+                self.app.processEvents()
+
+    def getPixmap(self):
+        return QPixmap.fromImage(self.img)
+
+    def getArray(self):
+        """
+        Get a numpy array of RGB pixel values.
+        The returned array has shape (height, width, 3).
+        """
+
+        width = self.width
+        height = self.height
+        shape = (height, width, 3)
+
+        numBytes = self.width * self.height * 3
+        buf = self.img.bits().asstring(numBytes)
+        output = np.frombuffer(buf, dtype='uint8')
+        output = output.reshape(shape)
+
+        return output
+
+    def push(self):
+        self.painter.save()
+
+    def pop(self):
+        self.painter.restore()
+
+    def rotate(self, degrees):
+        self.painter.rotate(degrees)
+
+    def translate(self, x, y):
+        self.painter.translate(x, y)
+
+    def scale(self, x, y):
+        self.painter.scale(x, y)
+
+    def setLineColor(self, r, g, b, a=255):
+        self.painter.setPen(QColor(r, g, b, a))
+
+    def setColor(self, r, g, b, a=255):
+        self.painter.setBrush(QColor(r, g, b, a))
+
+    def drawLine(self, x0, y0, x1, y1):
+        self.painter.drawLine(x0, y0, x1, y1)
+
+    def drawCircle(self, x, y, r):
+        center = QPoint(x, y)
+        self.painter.drawEllipse(center, r, r)
+
+    def drawPolygon(self, points):
+        """Takes a list of points (tuples) as input"""
+        points = [QPoint(p[0], p[1]) for p in points]
+        self.painter.drawPolygon(QPolygon(points))
+
+    def fillRect(self, x, y, width, height, r, g, b, a=255):
+        self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
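+
+# Off-screen drawing sketch (editor's addition): render one frame without a
+# window and read the pixels back as a numpy array. A QApplication is created
+# explicitly since ownWindow is not used; the 64x64 size is arbitrary.
+if __name__ == "__main__":
+    app = QApplication([])
+    r = Renderer(64, 64)
+    r.beginFrame()
+    r.setLineColor(255, 0, 0)
+    r.setColor(255, 0, 0)
+    r.drawCircle(32, 32, 10)
+    r.endFrame()
+    pixels = r.getArray()
+    print(pixels.shape, pixels.dtype)   # (64, 64, 3) uint8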

+ 444 - 0
gym_minigrid/envs/simple_envs.py

@@ -0,0 +1,444 @@
+from gym.envs.registration import register
+from gym_minigrid.envs.minigrid_env import *
+
+class EmptyEnv(MiniGridEnv):
+    """
+    Empty grid environment, no obstacles, sparse reward
+    """
+
+    def __init__(self, size=8):
+        super(EmptyEnv, self).__init__(gridSize=size, maxSteps=2 * size)
+
+class EmptyEnv6x6(EmptyEnv):
+    def __init__(self):
+        super(EmptyEnv6x6, self).__init__(size=6)
+
+register(
+    id='MiniGrid-Empty-8x8-v0',
+    entry_point='gym_minigrid.envs:EmptyEnv',
+    reward_threshold=1000.0
+)
+
+register(
+    id='MiniGrid-Empty-6x6-v0',
+    entry_point='gym_minigrid.envs:EmptyEnv6x6',
+    reward_threshold=1000.0
+)
+
+class DoorKeyEnv(MiniGridEnv):
+    """
+    Environment with a door and key, sparse reward
+    """
+
+    def __init__(self, size=8):
+        super(DoorKeyEnv, self).__init__(gridSize=size, maxSteps=4 * size)
+
+    def _genGrid(self, width, height):
+        grid = super(DoorKeyEnv, self)._genGrid(width, height)
+        assert width == height
+        gridSz = width
+
+        # Create a vertical splitting wall
+        splitIdx = self._randInt(2, gridSz-3)
+        for i in range(0, gridSz):
+            grid.set(splitIdx, i, Wall())
+
+        # Place a door in the wall
+        doorIdx = self._randInt(1, gridSz-2)
+        grid.set(splitIdx, doorIdx, Door('yellow'))
+
+        # Place a key on the left side
+        #keyIdx = self._randInt(1 + gridSz // 2, gridSz-2)
+        keyIdx = gridSz-2
+        grid.set(1, keyIdx, Key('yellow'))
+
+        return grid
+
+class DoorKeyEnv16x16(DoorKeyEnv):
+    def __init__(self):
+        super(DoorKeyEnv16x16, self).__init__(size=16)
+
+register(
+    id='MiniGrid-Door-Key-8x8-v0',
+    entry_point='gym_minigrid.envs:DoorKeyEnv',
+    reward_threshold=1000.0
+)
+
+register(
+    id='MiniGrid-Door-Key-16x16-v0',
+    entry_point='gym_minigrid.envs:DoorKeyEnv16x16',
+    reward_threshold=1000.0
+)
+
+class Room:
+    def __init__(self,
+        top,
+        size,
+        entryDoorPos,
+        exitDoorPos
+    ):
+        self.top = top
+        self.size = size
+        self.entryDoorPos = entryDoorPos
+        self.exitDoorPos = exitDoorPos
+
+class MultiRoomEnv(MiniGridEnv):
+    """
+    Environment with multiple rooms (subgoals)
+    """
+
+    def __init__(self,
+        minNumRooms,
+        maxNumRooms,
+        maxRoomSize=10
+    ):
+        assert minNumRooms > 0
+        assert maxNumRooms >= minNumRooms
+        assert maxRoomSize >= 4
+
+        self.minNumRooms = minNumRooms
+        self.maxNumRooms = maxNumRooms
+        self.maxRoomSize = maxRoomSize
+
+        self.rooms = []
+
+        super(MultiRoomEnv, self).__init__(
+            gridSize=25,
+            maxSteps=self.maxNumRooms * 20
+        )
+
+    def _genGrid(self, width, height):
+
+        roomList = []
+
+        # Choose a random number of rooms to generate
+        numRooms = self._randInt(self.minNumRooms, self.maxNumRooms+1)
+
+        while len(roomList) < numRooms:
+            curRoomList = []
+
+            entryDoorPos = (
+                self._randInt(0, width - 2),
+                self._randInt(0, height - 2)
+            )
+
+            # Recursively place the rooms
+            self._placeRoom(
+                numRooms,
+                roomList=curRoomList,
+                minSz=4,
+                maxSz=self.maxRoomSize,
+                entryDoorWall=2,
+                entryDoorPos=entryDoorPos
+            )
+
+            if len(curRoomList) > len(roomList):
+                roomList = curRoomList
+
+        # Store the list of rooms in this environment
+        assert len(roomList) > 0
+        self.rooms = roomList
+
+        # Randomize the starting agent position and direction
+        topX, topY = roomList[0].top
+        sizeX, sizeY = roomList[0].size
+        self.startPos = (
+            self._randInt(topX + 1, topX + sizeX - 2),
+            self._randInt(topY + 1, topY + sizeY - 2)
+        )
+        self.startDir = self._randInt(0, 4)
+
+        # Create the grid
+        grid = Grid(width, height)
+        wall = Wall()
+
+        prevDoorColor = None
+
+        # For each room
+        for idx, room in enumerate(roomList):
+
+            topX, topY = room.top
+            sizeX, sizeY = room.size
+
+            # Draw the top and bottom walls
+            for i in range(0, sizeX):
+                grid.set(topX + i, topY, wall)
+                grid.set(topX + i, topY + sizeY - 1, wall)
+
+            # Draw the left and right walls
+            for j in range(0, sizeY):
+                grid.set(topX, topY + j, wall)
+                grid.set(topX + sizeX - 1, topY + j, wall)
+
+            # If this isn't the first room, place the entry door
+            if idx > 0:
+                # Pick a door color different from the previous one
+                doorColors = set(COLORS.keys())
+                if prevDoorColor:
+                    doorColors.remove(prevDoorColor)
+                doorColor = self._randElem(doorColors)
+
+                entryDoor = Door(doorColor)
+                grid.set(*room.entryDoorPos, entryDoor)
+                prevDoorColor = doorColor
+
+                prevRoom = roomList[idx-1]
+                prevRoom.exitDoorPos = room.entryDoorPos
+
+        # Place the final goal
+        while True:
+            goalX = self._randInt(topX + 1, topX + sizeX - 1)
+            goalY = self._randInt(topY + 1, topY + sizeY - 1)
+
+            # Make sure the goal doesn't overlap with the agent
+            if (goalX, goalY) != self.startPos:
+                grid.set(goalX, goalY, Goal())
+                break
+
+        return grid
+
+    def _placeRoom(
+        self,
+        numLeft,
+        roomList,
+        minSz,
+        maxSz,
+        entryDoorWall,
+        entryDoorPos
+    ):
+        # Choose the room size randomly
+        sizeX = self._randInt(minSz, maxSz+1)
+        sizeY = self._randInt(minSz, maxSz+1)
+
+        # The first room will be at the door position
+        if len(roomList) == 0:
+            topX, topY = entryDoorPos
+        # Entry wall on the right
+        elif entryDoorWall == 0:
+            topX = entryDoorPos[0] - sizeX + 1
+            y = entryDoorPos[1]
+            topY = self._randInt(y - sizeY + 2, y)
+        # Entry wall on the south
+        elif entryDoorWall == 1:
+            x = entryDoorPos[0]
+            topX = self._randInt(x - sizeX + 2, x)
+            topY = entryDoorPos[1] - sizeY + 1
+        # Entry wall on the left
+        elif entryDoorWall == 2:
+            topX = entryDoorPos[0]
+            y = entryDoorPos[1]
+            topY = self._randInt(y - sizeY + 2, y)
+        # Entry wall on the top
+        elif entryDoorWall == 3:
+            x = entryDoorPos[0]
+            topX = self._randInt(x - sizeX + 2, x)
+            topY = entryDoorPos[1]
+        else:
+            assert False, entryDoorWall
+
+        # If the room is out of the grid, can't place a room here
+        if topX < 0 or topY < 0:
+            return False
+        if topX + sizeX > self.gridSize or topY + sizeY > self.gridSize:
+            return False
+
+        # If the room intersects with previous rooms, can't place it here
+        for room in roomList[:-1]:
+            nonOverlap = \
+                topX + sizeX < room.top[0] or \
+                room.top[0] + room.size[0] <= topX or \
+                topY + sizeY < room.top[1] or \
+                room.top[1] + room.size[1] <= topY
+
+            if not nonOverlap:
+                return False
+
+        # Add this room to the list
+        roomList.append(Room(
+            (topX, topY),
+            (sizeX, sizeY),
+            entryDoorPos,
+            None
+        ))
+
+        # If this was the last room, stop
+        if numLeft == 1:
+            return True
+
+        # Try placing the next room
+        for i in range(0, 8):
+
+            # Pick which wall to place the out door on
+            wallSet = set((0, 1, 2, 3))
+            wallSet.remove(entryDoorWall)
+            exitDoorWall = self._randElem(wallSet)
+            nextEntryWall = (exitDoorWall + 2) % 4
+
+            # Pick the exit door position
+            # Exit on right wall
+            if exitDoorWall == 0:
+                exitDoorPos = (
+                    topX + sizeX - 1,
+                    topY + self._randInt(1, sizeY - 1)
+                )
+            # Exit on south wall
+            elif exitDoorWall == 1:
+                exitDoorPos = (
+                    topX + self._randInt(1, sizeX - 1),
+                    topY + sizeY - 1
+                )
+            # Exit on left wall
+            elif exitDoorWall == 2:
+                exitDoorPos = (
+                    topX,
+                    topY + self._randInt(1, sizeY - 1)
+                )
+            # Exit on north wall
+            elif exitDoorWall == 3:
+                exitDoorPos = (
+                    topX + self._randInt(1, sizeX - 1),
+                    topY
+                )
+            else:
+                assert False
+
+            # Recursively create the other rooms
+            success = self._placeRoom(
+                numLeft - 1,
+                roomList=roomList,
+                minSz=minSz,
+                maxSz=maxSz,
+                entryDoorWall=nextEntryWall,
+                entryDoorPos=exitDoorPos
+            )
+
+            if success:
+                break
+
+        return True
+
+class MultiRoomEnvN6(MultiRoomEnv):
+    def __init__(self):
+        super(MultiRoomEnvN6, self).__init__(
+            minNumRooms=6,
+            maxNumRooms=6
+        )
+
+register(
+    id='MiniGrid-Multi-Room-N6-v0',
+    entry_point='gym_minigrid.envs:MultiRoomEnvN6',
+    reward_threshold=1000.0
+)
+
+class FetchEnv(MiniGridEnv):
+    """
+    Environment in which the agent has to fetch a random object
+    named using English text strings
+    """
+
+    def __init__(
+        self,
+        size=8,
+        numObjs=3):
+        self.numObjs = numObjs
+        super(FetchEnv, self).__init__(gridSize=size, maxSteps=5*size)
+
+    def _genGrid(self, width, height):
+        assert width == height
+        gridSz = width
+
+        # Create a grid surrounded by walls
+        grid = Grid(width, height)
+        for i in range(0, width):
+            grid.set(i, 0, Wall())
+            grid.set(i, height-1, Wall())
+        for j in range(0, height):
+            grid.set(0, j, Wall())
+            grid.set(width-1, j, Wall())
+
+        types = ['key', 'ball']
+        colors = list(COLORS.keys())
+
+        objs = []
+
+        # For each object to be generated
+        for i in range(0, self.numObjs):
+            objType = self._randElem(types)
+            objColor = self._randElem(colors)
+
+            if objType == 'key':
+                obj = Key(objColor)
+            elif objType == 'ball':
+                obj = Ball(objColor)
+
+            while True:
+                pos = (
+                    self._randInt(1, gridSz - 1),
+                    self._randInt(1, gridSz - 1)
+                )
+
+                if pos != self.startPos:
+                    grid.set(*pos, obj)
+                    break
+
+            objs.append(obj)
+
+        # Choose a random object to be picked up
+        target = objs[self._randInt(0, len(objs))]
+        self.targetType = target.type
+        self.targetColor = target.color
+
+        descStr = '%s %s' % (self.targetColor, self.targetType)
+
+        # Generate the mission string
+        idx = self._randInt(0, 5)
+        if idx == 0:
+            self.mission = 'get a %s' % descStr
+        elif idx == 1:
+            self.mission = 'go get a %s' % descStr
+        elif idx == 2:
+            self.mission = 'fetch a %s' % descStr
+        elif idx == 3:
+            self.mission = 'go fetch a %s' % descStr
+        elif idx == 4:
+            self.mission = 'you must fetch a %s' % descStr
+        assert hasattr(self, 'mission')
+
+        return grid
+
+    def _reset(self):
+        obs = MiniGridEnv._reset(self)
+
+        obs = {
+            'image': obs,
+            'mission': self.mission,
+            'advice' : ''
+        }
+
+        return obs
+
+    def _step(self, action):
+        obs, reward, done, info = MiniGridEnv._step(self, action)
+
+        if self.carrying:
+            if self.carrying.color == self.targetColor and \
+               self.carrying.type == self.targetType:
+                reward = 1000 - self.stepCount
+                done = True
+            else:
+                reward = -1000
+                done = True
+
+        obs = {
+            'image': obs,
+            'mission': self.mission,
+            'advice': ''
+        }
+
+        return obs, reward, done, info
+
+register(
+    id='MiniGrid-Fetch-8x8-v0',
+    entry_point='gym_minigrid.envs:FetchEnv',
+    reward_threshold=900.0
+)
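+
+# Extension sketch (editor's addition, not part of the original commit): new
+# environments follow the same pattern -- subclass MiniGridEnv or one of the
+# classes above, override _genGrid if needed, then register an id. The 12x12
+# variant below is a hypothetical example, not an environment shipped here;
+# once registered it can be created with gym.make('MiniGrid-Empty-12x12-v0').
+class EmptyEnv12x12(EmptyEnv):
+    def __init__(self):
+        super(EmptyEnv12x12, self).__init__(size=12)
+
+register(
+    id='MiniGrid-Empty-12x12-v0',
+    entry_point='gym_minigrid.envs:EmptyEnv12x12',
+    reward_threshold=1000.0
+)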

+ 13 - 0
setup.py

@@ -0,0 +1,13 @@
+from setuptools import setup
+
+setup(
+    name='gym_minigrid',
+    version='0.0.1',
+    keywords='memory, environment, agent, rl, openaigym, openai-gym, gym',
+    install_requires=[
+        'gym>=0.9.0',
+        'numpy>=1.10.0',
+        'pyqt5',
+        'matplotlib'
+    ]
+)

+ 37 - 0
standalone.py

@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+from __future__ import division, print_function
+
+import numpy
+import gym
+
+import gym_minigrid
+
+def main():
+
+    env = gym.make('MiniGrid-Multi-Room-N6-v0')
+    env.reset()
+
+    # Create a window to render into
+    renderer = env.render('human')
+
+    while True:
+
+        env.render('human')
+
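+        # Placeholder policy: action 0 is ACTION_LEFT, so the agent simply
+        # rotates in place in this demo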
+        action = 0
+
+        obs, reward, done, info = env.step(action)
+
+        print('reward=%s' % reward)
+
+        if done:
+            print('done!')
+            env.reset()
+
+        # If the window was closed
+        if not renderer.window:
+            break
+
+if __name__ == "__main__":
+    main()