|
@@ -0,0 +1,253 @@
|
|
|
+# Copyright The PyTorch Lightning team.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import platform
|
|
|
+import random
|
|
|
+import time
|
|
|
+import urllib
|
|
|
+from typing import Optional, Tuple
|
|
|
+from urllib.error import HTTPError
|
|
|
+from warnings import warn
|
|
|
+
|
|
|
+import torch
|
|
|
+from torch.utils.data import DataLoader, Dataset, random_split
|
|
|
+
|
|
|
+from pl_examples import _DATASETS_PATH
|
|
|
+from pytorch_lightning import LightningDataModule
|
|
|
+from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
|
|
+
|
|
|
+if _TORCHVISION_AVAILABLE:
|
|
|
+ from torchvision import transforms as transform_lib
|
|
|
+
|
|
|
+
|
|
|
class _MNIST(Dataset):
    """Carbon copy of ``tests.helpers.datasets.MNIST``.

    We cannot import the tests as they are not distributed with the package.
    See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = "training.pt"
    TEST_FILE_NAME = "test.pt"
    cache_folder_name = "complete"

    def __init__(
        self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
    ):
        """
        Args:
            root: directory under which the ``MNIST/<cache_folder_name>`` cache is stored
            train: load the training split when True, otherwise the test split
            normalize: ``(mean, std)`` applied to every image; ``None`` (or a tuple of
                a length other than 2) disables normalization
            download: fetch the split files from ``RESOURCES`` when not cached yet
        """
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the ``(image, label)`` pair at ``idx``; a channel dim is prepended to the image."""
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None and len(self.normalize) == 2:
            img = self.normalize_tensor(img, *self.normalize)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        """Absolute location of the cached split files."""
        return os.path.join(self.root, "MNIST", self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        """Return True only when both the train and test files are present in ``data_folder``."""
        return all(
            os.path.isfile(os.path.join(data_folder, fname))
            for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)
        )

    def prepare_data(self, download: bool = True):
        """Ensure the dataset files exist locally, optionally downloading them first.

        Raises:
            RuntimeError: if the files are still missing after the (optional) download.
        """
        if download and not self._check_exists(self.cached_folder_path):
            self._download(self.cached_folder_path)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError("Dataset not found.")

    def _download(self, data_folder: str) -> None:
        """Download every file listed in ``RESOURCES`` into ``data_folder``."""
        # `import urllib` alone does not guarantee that the `urllib.request`
        # submodule is initialized, so import it explicitly before use.
        from urllib import request

        os.makedirs(data_folder, exist_ok=True)
        for url in self.RESOURCES:
            logging.info(f"Downloading {url}")
            fpath = os.path.join(data_folder, os.path.basename(url))
            request.urlretrieve(url, fpath)

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.0):
        """Resolving loading from the same time from multiple concurrent processes.

        Retries ``torch.load`` up to ``trials`` times, sleeping a random fraction of
        ``delta`` seconds between attempts so concurrent readers/writers de-synchronize.
        """
        res, exception = None, None
        assert trials, "at least some trial has to be set"
        assert os.path.isfile(path_data), f"missing file: {path_data}"
        for _ in range(trials):
            try:
                res = torch.load(path_data)
            # todo: specify the possible exception
            except Exception as e:
                exception = e
                time.sleep(delta * random.random())
            else:
                # success: discard any exception captured by earlier failed attempts,
                # otherwise a transient failure followed by a success would still raise
                exception = None
                break
        if exception is not None:
            # raise the caught exception
            raise exception
        return res

    @staticmethod
    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
        """Return ``(tensor - mean) / std`` computed in the tensor's dtype and on its device."""
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
        return tensor.sub(mean).div(std)
|
|
|
+
|
|
|
+
|
|
|
def MNIST(*args, **kwargs):
    """Instantiate an MNIST dataset, preferring ``torchvision.datasets.MNIST``.

    Falls back to the hosted ``_MNIST`` copy when the torchvision download fails
    with an HTTP error, or when the ``PL_USE_MOCKED_MNIST`` env var is set.
    """
    dataset_cls = None
    if not bool(os.getenv("PL_USE_MOCKED_MNIST", False)):
        try:
            from torchvision.datasets import MNIST as _TorchvisionMNIST

            # Trigger the download once so a broken mirror is detected here.
            _TorchvisionMNIST(_DATASETS_PATH, download=True)
            dataset_cls = _TorchvisionMNIST
        except HTTPError as e:
            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
    if dataset_cls is None:
        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
        dataset_cls = _MNIST
    return dataset_cls(*args, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
class MNISTDataModule(LightningDataModule):
    """Standard MNIST, train, val, test splits and transforms.

    >>> MNISTDataModule()  # doctest: +ELLIPSIS
    <...mnist_datamodule.MNISTDataModule object at ...>
    """

    name = "mnist"

    def __init__(
        self,
        data_dir: str = _DATASETS_PATH,
        val_split: int = 5000,
        num_workers: int = 16,
        normalize: bool = False,
        seed: int = 42,
        batch_size: int = 32,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: directory the dataset is saved to / loaded from
            val_split: number of training images held out for validation
            num_workers: worker process count for the dataloaders
            normalize: when True, apply image normalization in the transforms
            seed: starting seed for RNG.
            batch_size: desired batch size.
        """
        super().__init__(*args, **kwargs)
        if num_workers and platform.system() == "Windows":
            # see: https://stackoverflow.com/a/59680818
            warn(
                f"You have requested num_workers={num_workers} on Windows,"
                " but currently recommended is 0, so we set it for you"
            )
            num_workers = 0

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.normalize = normalize
        self.seed = seed
        # Populated in `setup`; Ellipsis marks "not initialized yet".
        self.dataset_train = ...
        self.dataset_val = ...

    @property
    def num_classes(self):
        """MNIST has ten digit classes."""
        return 10

    def prepare_data(self):
        """Saves MNIST files to `data_dir`"""
        for train in (True, False):
            MNIST(self.data_dir, train=train, download=True)

    def setup(self, stage: Optional[str] = None):
        """Split the train and valid dataset."""
        transforms = self.default_transforms
        extra = {"transform": transforms} if transforms else {}
        dataset = MNIST(self.data_dir, train=True, download=False, **extra)
        n_total = len(dataset)
        lengths = [n_total - self.val_split, self.val_split]
        self.dataset_train, self.dataset_val = random_split(dataset, lengths)

    def _data_loader(self, dataset, shuffle: bool) -> DataLoader:
        # Single place for the DataLoader settings shared by all three loaders.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )

    def train_dataloader(self):
        """MNIST train set removes a subset to use for validation."""
        return self._data_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self):
        """MNIST val set uses a subset of the training set for validation."""
        return self._data_loader(self.dataset_val, shuffle=False)

    def test_dataloader(self):
        """MNIST test set uses the test split."""
        transforms = self.default_transforms
        extra = {"transform": transforms} if transforms else {}
        dataset = MNIST(self.data_dir, train=False, download=False, **extra)
        return self._data_loader(dataset, shuffle=False)

    @property
    def default_transforms(self):
        """Transform pipeline applied to each image, or ``None`` when torchvision is unavailable."""
        if not _TORCHVISION_AVAILABLE:
            return None
        if not self.normalize:
            return transform_lib.ToTensor()
        return transform_lib.Compose(
            [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
        )
|
|
|
+
|