浏览代码

Add Training Tutorial and CITATION.cff (#379)

Bolun 1 年之前
父节点
当前提交
546c040a94
共有 3 个文件被更改,包括 119 次插入0 次删除
  1. 51 0
      CITATION.cff
  2. 67 0
      docs/content/training.md
  3. 1 0
      docs/index.md

+ 51 - 0
CITATION.cff

@@ -0,0 +1,51 @@
+cff-version: 1.2.0
+title: Minigrid
+message: >-
+  If you use this software, please cite it using the
+  metadata from this file.
+authors:
+- family-names: Chevalier-Boisvert
+  given-names: Maxime
+- family-names: Dai
+  given-names: Bolun
+- family-names: Towers
+  given-names: Mark
+- family-names: de Lazcano
+  given-names: Rodrigo
+- family-names: Willems
+  given-names: Lucas
+- family-names: Lahlou
+  given-names: Salem
+- family-names: Pal
+  given-names: Suman
+- family-names: Castro
+  given-names: Pablo Samuel
+- family-names: Terry
+  given-names: Jordan
+url: "https://github.com/Farama-Foundation/Minigrid"
+
+preferred-citation:
+  type: article
+  authors:
+  - family-names: Chevalier-Boisvert
+    given-names: Maxime
+  - family-names: Dai
+    given-names: Bolun
+  - family-names: Towers
+    given-names: Mark
+  - family-names: de Lazcano
+    given-names: Rodrigo
+  - family-names: Willems
+    given-names: Lucas
+  - family-names: Lahlou
+    given-names: Salem
+  - family-names: Pal
+    given-names: Suman
+  - family-names: Castro
+    given-names: Pablo Samuel
+  - family-names: Terry
+    given-names: Jordan
+  journal: CoRR
+  title: Minigrid
+  volume: abs/2306.13831
+  year: 2023

+ 67 - 0
docs/content/training.md

@@ -0,0 +1,67 @@
+---
+layout: "contents"
+title: Training Minigrid Environments
+firstpage:
+---
+
+## Training Minigrid Environments
+
+The environments in the Minigrid library can be trained easily using [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/). In this tutorial we show how a PPO agent can be trained on the `MiniGrid-Empty-16x16-v0` environment.
+
+## Create Custom Feature Extractor 
+
+Although `StableBaselines3` is fully compatible with `Gymnasium`-based environments, including Minigrid, the default CNN architecture does not directly support the Minigrid observation space. Thus, to train an agent on Minigrid environments, we therefore need to create a custom feature extractor. This can be done by creating a feature extractor class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`
+
+```python
+class MinigridFeaturesExtractor(BaseFeaturesExtractor):
+    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
+        super().__init__(observation_space, features_dim)
+        n_input_channels = observation_space.shape[0]
+        self.cnn = nn.Sequential(
+            nn.Conv2d(n_input_channels, 16, (2, 2)),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, (2, 2)),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, (2, 2)),
+            nn.ReLU(),
+            nn.Flatten(),
+        )
+
+        # Compute shape by doing one forward pass
+        with torch.no_grad():
+            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]
+
+        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+    def forward(self, observations: torch.Tensor) -> torch.Tensor:
+        return self.linear(self.cnn(observations))
+```
+
+This class is created based on the custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want), the CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18).
+
+## Train a PPO Agent
+
+The using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment. The following code snippet shows how this can be done.
+
+```python
+import minigrid
+from minigrid.wrappers import ImgObsWrapper
+from stable_baselines3 import PPO
+
+policy_kwargs = dict(
+    features_extractor_class=MinigridFeaturesExtractor,
+    features_extractor_kwargs=dict(features_dim=128),
+)
+
+env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
+env = ImgObsWrapper(env)
+
+model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
+model.learn(2e5)
+```
+
+By default the observation of Minigrid environments are dictionaries. Since the `CnnPolicy` from StableBaseline3 by default takes in image observations, we need to wrap the environment using the `ImgObsWrapper` from the Minigrid library. This wrapper converts the dictionary observation to an image observation.
+
+## Further Reading
+
+One can also pass dictionary observations to StableBaseline3 policies, for a walkthrough the process of doing so see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation utilizing this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py).

+ 1 - 0
docs/index.md

@@ -57,6 +57,7 @@ To cite this project please use:
 content/basic_usage
 content/publications
 content/create_env_tutorial
+content/training
 ```
 
 ```{toctree}