Browse Source

add torch example

Andre Vauvelle 2 years ago
parent
commit
9d35f2deaa

+ 0 - 0
data/.gitkeep


+ 9 - 0
environment.yml

@@ -0,0 +1,9 @@
+name: hpc-example
+dependencies:
+        - pytorch
+        - pytorch-lightning
+        - torchvision
+        - torchmetrics
+        - setuptools==59.5.0
+        - lightning-bolts
+        - jsonargparse

+ 0 - 0
lightning_logs/.gitkeep


BIN
pl_examples/.__init__.py.swp


+ 43 - 0
pl_examples/__init__.py

@@ -0,0 +1,43 @@
+import os
+from pathlib import Path
+
+
+
+# Absolute path of the ``data`` directory located next to this ``__init__.py``.
+_DATASETS_PATH = os.path.join(Path(__file__).resolve().parents[0], "data")
+
+
+# ASCII-art logo; printed by ``cli_lightning_logo`` below.
+LIGHTNING_LOGO = """
+                    ####
+                ###########
+             ####################
+         ############################
+    #####################################
+##############################################
+#########################  ###################
+#######################    ###################
+####################      ####################
+##################       #####################
+################        ######################
+#####################        #################
+######################     ###################
+#####################    #####################
+####################   #######################
+###################  #########################
+##############################################
+    #####################################
+         ############################
+             ####################
+                  ##########
+                     ####
+"""
+
+
+def nice_print(msg, last=False):
+    print()
+    print("\033[0;35m" + msg + "\033[0m")
+    if last:
+        print()
+
+
+def cli_lightning_logo():
+    nice_print(LIGHTNING_LOGO)

BIN
pl_examples/__pycache__/__init__.cpython-39.pyc


BIN
pl_examples/__pycache__/image_classifier_1_pytorch.cpython-39.pyc


BIN
pl_examples/__pycache__/mnist_datamodule.cpython-39.pyc


BIN
pl_examples/data/MNIST/raw/t10k-images-idx3-ubyte


BIN
pl_examples/data/MNIST/raw/t10k-images-idx3-ubyte.gz


BIN
pl_examples/data/MNIST/raw/t10k-labels-idx1-ubyte


BIN
pl_examples/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


BIN
pl_examples/data/MNIST/raw/train-images-idx3-ubyte


BIN
pl_examples/data/MNIST/raw/train-images-idx3-ubyte.gz


BIN
pl_examples/data/MNIST/raw/train-labels-idx1-ubyte


BIN
pl_examples/data/MNIST/raw/train-labels-idx1-ubyte.gz


+ 154 - 0
pl_examples/image_classifier_1_pytorch.py

@@ -0,0 +1,154 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision.transforms as T
+from torch.optim.lr_scheduler import StepLR
+
+from pl_examples.mnist_datamodule import MNIST
+
+# Credit to the PyTorch Team
+# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+def run(hparams):
+
+    torch.manual_seed(hparams.seed)
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
+    train_dataset = MNIST("./data", train=True, download=True, transform=transform)
+    test_dataset = MNIST("./data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=hparams.batch_size,
+    )
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)
+
+    model = Net().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)
+
+    # EPOCH LOOP
+    for epoch in range(1, hparams.epochs + 1):
+
+        # TRAINING LOOP
+        model.train()
+        for batch_idx, (data, target) in enumerate(train_loader):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.step()
+            if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):
+                print(
+                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                        epoch,
+                        batch_idx * len(data),
+                        len(train_loader.dataset),
+                        100.0 * batch_idx / len(train_loader),
+                        loss.item(),
+                    )
+                )
+                if hparams.dry_run:
+                    break
+        scheduler.step()
+
+        # TESTING LOOP
+        model.eval()
+        test_loss = 0
+        correct = 0
+        with torch.no_grad():
+            for data, target in test_loader:
+                data, target = data.to(device), target.to(device)
+                output = model(data)
+                test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
+                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
+                if hparams.dry_run:
+                    break
+
+        test_loss /= len(test_loader.dataset)
+
+        print(
+            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+                test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
+            )
+        )
+
+        if hparams.dry_run:
+            break
+
+    if hparams.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
+    )
+    parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)")
+    parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
+    parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
+    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
+    hparams = parser.parse_args()
+    run(hparams)
+
+
+# Script entry point: only runs when this file is executed directly.
+if __name__ == "__main__":
+    main()

+ 84 - 0
pl_examples/image_classifier_4_lightning_module.py

@@ -0,0 +1,84 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Simple MNIST image classifier example with LightningModule.
+To run: python image_classifier_4_lightning_module.py --trainer.max_epochs=50
+"""
+import torch
+import torchvision.transforms as T
+from torch.nn import functional as F
+from torchmetrics import Accuracy
+
+from pl_examples import cli_lightning_logo
+from pl_examples.mnist_datamodule import MNIST
+from pl_examples.image_classifier_1_pytorch import Net
+from pytorch_lightning import LightningModule
+from pytorch_lightning.utilities.cli import LightningCLI
+
+
+class ImageClassifier(LightningModule):
+    def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
+        super().__init__()
+        self.save_hyperparameters(ignore="model")
+        self.model = model or Net()
+        self.test_acc = Accuracy()
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y.long())
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y.long())
+        self.test_acc(logits, y)
+        self.log("test_acc", self.test_acc)
+        self.log("test_loss", loss)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
+        return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]
+
+    # Methods for the `LightningDataModule` conversion
+
+    @property
+    def transform(self):
+        return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
+
+    def prepare_data(self) -> None:
+        MNIST("./data", download=True)
+
+    def train_dataloader(self):
+        train_dataset = MNIST("./data", train=True, download=False, transform=self.transform)
+        return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size)
+
+    def test_dataloader(self):
+        test_dataset = MNIST("./data", train=False, download=False, transform=self.transform)
+        return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size)
+
+
+def cli_main():
+    # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional.
+    cli = LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=False)
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
+
+
+# Script entry point: print the logo, then run the CLI.
+if __name__ == "__main__":
+    cli_lightning_logo()
+    cli_main()

+ 253 - 0
pl_examples/mnist_datamodule.py

@@ -0,0 +1,253 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import platform
+import random
+import time
+import urllib
+from typing import Optional, Tuple
+from urllib.error import HTTPError
+from warnings import warn
+
+import torch
+from torch.utils.data import DataLoader, Dataset, random_split
+
+from pl_examples import _DATASETS_PATH
+from pytorch_lightning import LightningDataModule
+from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
+
+# ``transform_lib`` is only bound when ``_TORCHVISION_AVAILABLE`` is true.
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as transform_lib
+
+
+class _MNIST(Dataset):
+    """Carbon copy of ``tests.helpers.datasets.MNIST``.
+    We cannot import the tests as they are not distributed with the package.
+    See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
+    """
+
+    RESOURCES = (
+        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
+        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
+    )
+
+    TRAIN_FILE_NAME = "training.pt"
+    TEST_FILE_NAME = "test.pt"
+    cache_folder_name = "complete"
+
+    def __init__(
+        self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
+    ):
+        super().__init__()
+        self.root = root
+        self.train = train  # training set or test set
+        self.normalize = normalize
+
+        self.prepare_data(download)
+
+        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
+        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
+        img = self.data[idx].float().unsqueeze(0)
+        target = int(self.targets[idx])
+
+        if self.normalize is not None and len(self.normalize) == 2:
+            img = self.normalize_tensor(img, *self.normalize)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def cached_folder_path(self) -> str:
+        return os.path.join(self.root, "MNIST", self.cache_folder_name)
+
+    def _check_exists(self, data_folder: str) -> bool:
+        existing = True
+        for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
+            existing = existing and os.path.isfile(os.path.join(data_folder, fname))
+        return existing
+
+    def prepare_data(self, download: bool = True):
+        if download and not self._check_exists(self.cached_folder_path):
+            self._download(self.cached_folder_path)
+        if not self._check_exists(self.cached_folder_path):
+            raise RuntimeError("Dataset not found.")
+
+    def _download(self, data_folder: str) -> None:
+        os.makedirs(data_folder, exist_ok=True)
+        for url in self.RESOURCES:
+            logging.info(f"Downloading {url}")
+            fpath = os.path.join(data_folder, os.path.basename(url))
+            urllib.request.urlretrieve(url, fpath)
+
+    @staticmethod
+    def _try_load(path_data, trials: int = 30, delta: float = 1.0):
+        """Resolving loading from the same time from multiple concurrent processes."""
+        res, exception = None, None
+        assert trials, "at least some trial has to be set"
+        assert os.path.isfile(path_data), f"missing file: {path_data}"
+        for _ in range(trials):
+            try:
+                res = torch.load(path_data)
+            # todo: specify the possible exception
+            except Exception as e:
+                exception = e
+                time.sleep(delta * random.random())
+            else:
+                break
+        if exception is not None:
+            # raise the caught exception
+            raise exception
+        return res
+
+    @staticmethod
+    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
+        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+        return tensor.sub(mean).div(std)
+
+
+def MNIST(*args, **kwargs):
+    torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+    if torchvision_mnist_available:
+        try:
+            from torchvision.datasets import MNIST
+
+            MNIST(_DATASETS_PATH, download=True)
+        except HTTPError as e:
+            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
+            torchvision_mnist_available = False
+    if not torchvision_mnist_available:
+        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
+        MNIST = _MNIST
+    return MNIST(*args, **kwargs)
+
+
+class MNISTDataModule(LightningDataModule):
+    """Standard MNIST, train, val, test splits and transforms.
+    >>> MNISTDataModule()  # doctest: +ELLIPSIS
+    <...mnist_datamodule.MNISTDataModule object at ...>
+    """
+
+    name = "mnist"
+
+    def __init__(
+        self,
+        data_dir: str = _DATASETS_PATH,
+        val_split: int = 5000,
+        num_workers: int = 16,
+        normalize: bool = False,
+        seed: int = 42,
+        batch_size: int = 32,
+        *args,
+        **kwargs,
+    ):
+        """
+        Args:
+            data_dir: where to save/load the data
+            val_split: how many of the training images to use for the validation split
+            num_workers: how many workers to use for loading data
+            normalize: If true applies image normalize
+            seed: starting seed for RNG.
+            batch_size: desired batch size.
+        """
+        super().__init__(*args, **kwargs)
+        if num_workers and platform.system() == "Windows":
+            # see: https://stackoverflow.com/a/59680818
+            warn(
+                f"You have requested num_workers={num_workers} on Windows,"
+                " but currently recommended is 0, so we set it for you"
+            )
+            num_workers = 0
+
+        self.data_dir = data_dir
+        self.val_split = val_split
+        self.num_workers = num_workers
+        self.normalize = normalize
+        self.seed = seed
+        self.batch_size = batch_size
+        self.dataset_train = ...
+        self.dataset_val = ...
+
+    @property
+    def num_classes(self):
+        return 10
+
+    def prepare_data(self):
+        """Saves MNIST files to `data_dir`"""
+        MNIST(self.data_dir, train=True, download=True)
+        MNIST(self.data_dir, train=False, download=True)
+
+    def setup(self, stage: Optional[str] = None):
+        """Split the train and valid dataset."""
+        extra = dict(transform=self.default_transforms) if self.default_transforms else {}
+        dataset = MNIST(self.data_dir, train=True, download=False, **extra)
+        train_length = len(dataset)
+        self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
+
+    def train_dataloader(self):
+        """MNIST train set removes a subset to use for validation."""
+        loader = DataLoader(
+            self.dataset_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            drop_last=True,
+            pin_memory=True,
+        )
+        return loader
+
+    def val_dataloader(self):
+        """MNIST val set uses a subset of the training set for validation."""
+        loader = DataLoader(
+            self.dataset_val,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=True,
+            pin_memory=True,
+        )
+        return loader
+
+    def test_dataloader(self):
+        """MNIST test set uses the test split."""
+        extra = dict(transform=self.default_transforms) if self.default_transforms else {}
+        dataset = MNIST(self.data_dir, train=False, download=False, **extra)
+        loader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=True,
+            pin_memory=True,
+        )
+        return loader
+
+    @property
+    def default_transforms(self):
+        if not _TORCHVISION_AVAILABLE:
+            return None
+        if self.normalize:
+            mnist_transforms = transform_lib.Compose(
+                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+            )
+        else:
+            mnist_transforms = transform_lib.ToTensor()
+
+        return mnist_transforms
+