# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import platform
import random
import time
import urllib
import urllib.request  # `import urllib` alone does not guarantee the `request` submodule is loaded
from typing import Optional, Tuple
from urllib.error import HTTPError
from warnings import warn

import torch
from torch.utils.data import DataLoader, Dataset, random_split

from pl_examples import _DATASETS_PATH
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib


class _MNIST(Dataset):
    """Adapted from ``tests.helpers.datasets.MNIST``.

    We cannot import the tests as they are not distributed with the package.
    See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = "training.pt"
    TEST_FILE_NAME = "test.pt"
    cache_folder_name = "complete"

    def __init__(
        self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
    ):
        """
        Args:
            root: directory under which ``MNIST/complete`` holds the processed ``.pt`` files
            train: load ``training.pt`` if True, otherwise ``test.pt``
            normalize: ``(mean, std)`` applied to every image in ``__getitem__``; ``None`` disables it
            download: fetch the processed files from ``RESOURCES`` if they are not already present
            **kwargs: extra keyword arguments (e.g. ``transform``) are accepted and ignored
        """
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        self.prepare_data(download)

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` with the image as a float tensor with a leading singleton dim."""
        img = self.data[idx].float().unsqueeze(0)  # leading dim presumably acts as the channel axis
        target = int(self.targets[idx])

        if self.normalize is not None and len(self.normalize) == 2:
            img = self.normalize_tensor(img, *self.normalize)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        """Folder holding the processed train/test files: ``<root>/MNIST/complete``."""
        return os.path.join(self.root, "MNIST", self.cache_folder_name)

    def _check_exists(self, data_folder: str) -> bool:
        """Return True only if both the train and the test file exist in ``data_folder``."""
        return all(
            os.path.isfile(os.path.join(data_folder, fname)) for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)
        )

    def prepare_data(self, download: bool = True):
        """Download the files when requested and missing.

        Raises:
            RuntimeError: if the files are still missing afterwards.
        """
        if download and not self._check_exists(self.cached_folder_path):
            self._download(self.cached_folder_path)
        if not self._check_exists(self.cached_folder_path):
            raise RuntimeError("Dataset not found.")

    def _download(self, data_folder: str) -> None:
        """Fetch every URL in ``RESOURCES`` into ``data_folder``, keeping the URL's base name."""
        os.makedirs(data_folder, exist_ok=True)
        for url in self.RESOURCES:
            logging.info("Downloading %s", url)
            fpath = os.path.join(data_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.0):
        """Resolving loading from the same time from multiple concurrent processes.

        Retries ``torch.load`` up to ``trials`` times, sleeping a random fraction of ``delta`` seconds
        between failed attempts, and returns the first successful result.

        Raises:
            Exception: the exception from the last attempt if every attempt failed.
        """
        assert trials > 0, "at least some trial has to be set"
        assert os.path.isfile(path_data), f"missing file: {path_data}"
        exception = None
        for _ in range(trials):
            try:
                # NOTE(review): `torch.load` unpickles a file downloaded over the network;
                # only safe as long as RESOURCES points at a trusted host.
                return torch.load(path_data)
            # todo: specify the possible exception
            except Exception as e:
                exception = e
                time.sleep(delta * random.random())
        # only reached when every attempt failed; re-raise the last error
        raise exception

    @staticmethod
    def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
        """Return ``(tensor - mean) / std`` computed in the tensor's dtype and on its device."""
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
        return tensor.sub(mean).div(std)


def MNIST(*args, **kwargs):
    """Instantiate ``torchvision.datasets.MNIST`` when usable, otherwise the hosted ``_MNIST`` copy.

    Torchvision is used only if it is installed, the ``PL_USE_MOCKED_MNIST`` environment variable is unset
    or empty, and a probe download into ``_DATASETS_PATH`` does not raise ``HTTPError``.
    All arguments are forwarded to the chosen dataset class.
    """
    torchvision_mnist_available = _TORCHVISION_AVAILABLE and not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
    dataset_cls = _MNIST
    if torchvision_mnist_available:
        try:
            from torchvision.datasets import MNIST as TorchvisionMNIST

            # probe that torchvision can actually fetch the data before relying on it
            TorchvisionMNIST(_DATASETS_PATH, download=True)
        except HTTPError as e:
            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
            torchvision_mnist_available = False
        else:
            dataset_cls = TorchvisionMNIST
    if not torchvision_mnist_available:
        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
    return dataset_cls(*args, **kwargs)


class MNISTDataModule(LightningDataModule):
    """Standard MNIST, train, val, test splits and transforms.

    >>> MNISTDataModule()  # doctest: +ELLIPSIS
    <...mnist_datamodule.MNISTDataModule object at ...>
    """

    name = "mnist"

    def __init__(
        self,
        data_dir: str = _DATASETS_PATH,
        val_split: int = 5000,
        num_workers: int = 16,
        normalize: bool = False,
        seed: int = 42,
        batch_size: int = 32,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: where to save/load the data
            val_split: how many of the training images to use for the validation split
            num_workers: how many workers to use for loading data (forced to 0 on Windows)
            normalize: If true applies image normalize
            seed: seed for the generator used by the train/validation ``random_split``
            batch_size: desired batch size.
        """
        super().__init__(*args, **kwargs)
        if num_workers and platform.system() == "Windows":
            # see: https://stackoverflow.com/a/59680818
            warn(
                f"You have requested num_workers={num_workers} on Windows,"
                " but currently recommended is 0, so we set it for you"
            )
            num_workers = 0

        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        # placeholders until `setup` assigns the actual splits
        self.dataset_train = ...
        self.dataset_val = ...

    @property
    def num_classes(self):
        return 10

    def prepare_data(self):
        """Saves MNIST files to `data_dir`"""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Split the train and valid dataset.

        The split is seeded with ``self.seed`` so every run and every process gets the same partition.
        """
        transforms = self.default_transforms
        extra = dict(transform=transforms) if transforms else {}
        dataset = MNIST(self.data_dir, train=True, download=False, **extra)
        train_length = len(dataset)
        self.dataset_train, self.dataset_val = random_split(
            dataset,
            [train_length - self.val_split, self.val_split],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """MNIST train set removes a subset to use for validation."""
        loader = DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def val_dataloader(self):
        """MNIST val set uses a subset of the training set for validation."""
        loader = DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def test_dataloader(self):
        """MNIST test set uses the test split."""
        transforms = self.default_transforms
        extra = dict(transform=transforms) if transforms else {}
        dataset = MNIST(self.data_dir, train=False, download=False, **extra)
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    @property
    def default_transforms(self):
        """``ToTensor`` (plus ``Normalize(0.5, 0.5)`` when ``normalize``), or ``None`` without torchvision."""
        if not _TORCHVISION_AVAILABLE:
            return None
        if self.normalize:
            mnist_transforms = transform_lib.Compose(
                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
            )
        else:
            mnist_transforms = transform_lib.ToTensor()

        return mnist_transforms