123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- # Copyright The PyTorch Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import os
- import platform
- import random
- import time
- import urllib
- from typing import Optional, Tuple
- from urllib.error import HTTPError
- from warnings import warn
- import torch
- from torch.utils.data import DataLoader, Dataset, random_split
- from pl_examples import _DATASETS_PATH
- from pytorch_lightning import LightningDataModule
- from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
- if _TORCHVISION_AVAILABLE:
- from torchvision import transforms as transform_lib
class _MNIST(Dataset):
    """Minimal MNIST dataset backed by pre-processed ``.pt`` files.

    Carbon copy of ``tests.helpers.datasets.MNIST``.
    We cannot import the tests as they are not distributed with the package.
    See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.

    Args:
        root: directory under which the ``MNIST/<cache_folder_name>`` folder is looked up/created.
        train: load ``TRAIN_FILE_NAME`` if ``True``, otherwise ``TEST_FILE_NAME``.
        normalize: ``(mean, std)`` pair applied to every image; ``None`` or a value whose
            length is not two disables normalization.
        download: fetch the files listed in ``RESOURCES`` when they are not cached yet.
        **kwargs: accepted and ignored (for instance ``transform``).

    Raises:
        RuntimeError: if the data files are missing after the optional download.
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = "training.pt"
    TEST_FILE_NAME = "test.pt"
    cache_folder_name = "complete"

    def __init__(
        self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
    ):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the ``idx``-th image (float, with a leading channel dim) and its integer label."""
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None and len(self.normalize) == 2:
            img = self.normalize_tensor(img, *self.normalize)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        """Folder that holds the cached ``.pt`` files: ``<root>/MNIST/<cache_folder_name>``."""
        return os.path.join(self.root, "MNIST", self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        """Return ``True`` only if both the train and the test file exist in ``data_folder``."""
        return all(
            os.path.isfile(os.path.join(data_folder, fname))
            for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)
        )

    def prepare_data(self, download: bool = True):
        """Download the data files if requested and missing, then verify they are present.

        Raises:
            RuntimeError: if the files are still missing afterwards.
        """
        if download and not self._check_exists(self.cached_folder_path):
            self._download(self.cached_folder_path)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError("Dataset not found.")

    def _download(self, data_folder: str) -> None:
        """Fetch every URL in ``RESOURCES`` into ``data_folder`` (created if needed)."""
        # A bare ``import urllib`` does not guarantee that the ``request`` submodule is
        # loaded, so import it explicitly where it is used.
        import urllib.request

        os.makedirs(data_folder, exist_ok=True)
        for url in self.RESOURCES:
            logging.info(f"Downloading {url}")
            fpath = os.path.join(data_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.0):
        """Load ``path_data`` with ``torch.load``, retrying on failure.

        Retrying resolves loading the same file at the same time from multiple
        concurrent processes.

        Args:
            path_data: path of the file to load.
            trials: maximum number of load attempts.
            delta: upper bound (seconds) of the random pause after a failed attempt.

        Returns:
            Whatever ``torch.load`` returns for the file.

        Raises:
            Exception: the error of the last attempt when every attempt failed.
        """
        assert trials, "at least some trial has to be set"
        assert os.path.isfile(path_data), f"missing file: {path_data}"

        res, exception = None, None
        for _ in range(trials):
            try:
                res = torch.load(path_data)
            # todo: specify the possible exception
            except Exception as e:
                exception = e
                time.sleep(delta * random.random())
            else:
                # A successful attempt supersedes any earlier failure; without this reset
                # the stale error would be raised even though the data was loaded.
                exception = None
                break

        if exception is not None:
            # every attempt failed: raise the last caught exception
            raise exception
        return res

    @staticmethod
    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
        """Return ``(tensor - mean) / std`` computed in the dtype and on the device of ``tensor``."""
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
        return tensor.sub(mean).div(std)
def MNIST(*args, **kwargs):
    """Build an MNIST dataset, preferring torchvision and falling back to ``_MNIST``.

    ``torchvision.datasets.MNIST`` is used unless torchvision is not installed, the
    ``PL_USE_MOCKED_MNIST`` environment variable is set to a non-empty value, or a
    probe download into ``_DATASETS_PATH`` fails with an ``HTTPError``. In those
    cases the hosted ``_MNIST`` copy is used instead.

    Args:
        *args: positional arguments forwarded to the selected dataset class.
        **kwargs: keyword arguments forwarded to the selected dataset class.

    Returns:
        An instance of the selected dataset class.
    """
    # Without torchvision the import below would raise an uncaught ImportError, so go
    # straight to the fallback. Any non-empty value of the variable (even "0") counts as set.
    torchvision_mnist_available = bool(_TORCHVISION_AVAILABLE) and not os.getenv("PL_USE_MOCKED_MNIST")
    dataset_cls = _MNIST

    if torchvision_mnist_available:
        try:
            from torchvision.datasets import MNIST as torchvision_mnist

            # Probe instantiation: triggers the download so that failures surface here.
            torchvision_mnist(_DATASETS_PATH, download=True)
        except HTTPError as e:
            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
            torchvision_mnist_available = False
        else:
            dataset_cls = torchvision_mnist

    if not torchvision_mnist_available:
        print("`torchvision.datasets.MNIST` not available. Using our hosted version")

    return dataset_cls(*args, **kwargs)
class MNISTDataModule(LightningDataModule):
    """Standard MNIST, train, val, test splits and transforms.

    >>> MNISTDataModule()  # doctest: +ELLIPSIS
    <...mnist_datamodule.MNISTDataModule object at ...>
    """

    name = "mnist"

    def __init__(
        self,
        data_dir: str = _DATASETS_PATH,
        val_split: int = 5000,
        num_workers: int = 16,
        normalize: bool = False,
        seed: int = 42,
        batch_size: int = 32,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: where to save/load the data
            val_split: how many of the training images to use for the validation split
            num_workers: how many workers to use for loading data
            normalize: If true applies image normalize
            seed: seed of the RNG that draws the train/validation split.
            batch_size: desired batch size.
        """
        super().__init__(*args, **kwargs)
        if num_workers and platform.system() == "Windows":
            # see: https://stackoverflow.com/a/59680818
            warn(
                f"You have requested num_workers={num_workers} on Windows,"
                " but currently recommended is 0, so we set it for you"
            )
            num_workers = 0

        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        # Placeholders until ``setup`` assigns the actual splits.
        self.dataset_train = ...
        self.dataset_val = ...

    @property
    def num_classes(self):
        """Number of target classes."""
        return 10

    def prepare_data(self):
        """Saves MNIST files to `data_dir`"""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Split the train and valid dataset.

        The last ``val_split`` samples of a random permutation become the validation
        set. The permutation is drawn from a generator seeded with ``seed`` so the
        split is the same on every run and in every process.
        """
        dataset = MNIST(self.data_dir, train=True, download=False, **self._transform_kwargs())
        train_length = len(dataset)
        self.dataset_train, self.dataset_val = random_split(
            dataset,
            [train_length - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """MNIST train set removes a subset to use for validation."""
        return self._data_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self):
        """MNIST val set uses a subset of the training set for validation."""
        return self._data_loader(self.dataset_val, shuffle=False)

    def test_dataloader(self):
        """MNIST test set uses the test split."""
        dataset = MNIST(self.data_dir, train=False, download=False, **self._transform_kwargs())
        return self._data_loader(dataset, shuffle=False)

    @property
    def default_transforms(self):
        """Transforms applied to every image, or ``None`` when torchvision is unavailable."""
        if not _TORCHVISION_AVAILABLE:
            return None
        if self.normalize:
            mnist_transforms = transform_lib.Compose(
                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
            )
        else:
            mnist_transforms = transform_lib.ToTensor()

        return mnist_transforms

    def _transform_kwargs(self) -> dict:
        """Return ``{"transform": ...}`` when default transforms exist, else an empty dict."""
        transforms = self.default_transforms
        return dict(transform=transforms) if transforms else {}

    def _data_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the settings shared by all splits."""
        # NOTE: ``drop_last=True`` applies to every split, so an incomplete final batch
        # is discarded for validation and test as well.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
|