mnist_datamodule.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright The PyTorch Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. import platform
  17. import random
  18. import time
  19. import urllib
  20. from typing import Optional, Tuple
  21. from urllib.error import HTTPError
  22. from warnings import warn
  23. import torch
  24. from torch.utils.data import DataLoader, Dataset, random_split
  25. from pl_examples import _DATASETS_PATH
  26. from pytorch_lightning import LightningDataModule
  27. from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
  28. if _TORCHVISION_AVAILABLE:
  29. from torchvision import transforms as transform_lib
  30. class _MNIST(Dataset):
  31. """Carbon copy of ``tests.helpers.datasets.MNIST``.
  32. We cannot import the tests as they are not distributed with the package.
  33. See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
  34. """
  35. RESOURCES = (
  36. "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
  37. "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
  38. )
  39. TRAIN_FILE_NAME = "training.pt"
  40. TEST_FILE_NAME = "test.pt"
  41. cache_folder_name = "complete"
  42. def __init__(
  43. self, root: str, train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, **kwargs
  44. ):
  45. super().__init__()
  46. self.root = root
  47. self.train = train # training set or test set
  48. self.normalize = normalize
  49. self.prepare_data(download)
  50. data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
  51. self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))
  52. def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
  53. img = self.data[idx].float().unsqueeze(0)
  54. target = int(self.targets[idx])
  55. if self.normalize is not None and len(self.normalize) == 2:
  56. img = self.normalize_tensor(img, *self.normalize)
  57. return img, target
  58. def __len__(self) -> int:
  59. return len(self.data)
  60. @property
  61. def cached_folder_path(self) -> str:
  62. return os.path.join(self.root, "MNIST", self.cache_folder_name)
  63. def _check_exists(self, data_folder: str) -> bool:
  64. existing = True
  65. for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
  66. existing = existing and os.path.isfile(os.path.join(data_folder, fname))
  67. return existing
  68. def prepare_data(self, download: bool = True):
  69. if download and not self._check_exists(self.cached_folder_path):
  70. self._download(self.cached_folder_path)
  71. if not self._check_exists(self.cached_folder_path):
  72. raise RuntimeError("Dataset not found.")
  73. def _download(self, data_folder: str) -> None:
  74. os.makedirs(data_folder, exist_ok=True)
  75. for url in self.RESOURCES:
  76. logging.info(f"Downloading {url}")
  77. fpath = os.path.join(data_folder, os.path.basename(url))
  78. urllib.request.urlretrieve(url, fpath)
  79. @staticmethod
  80. def _try_load(path_data, trials: int = 30, delta: float = 1.0):
  81. """Resolving loading from the same time from multiple concurrent processes."""
  82. res, exception = None, None
  83. assert trials, "at least some trial has to be set"
  84. assert os.path.isfile(path_data), f"missing file: {path_data}"
  85. for _ in range(trials):
  86. try:
  87. res = torch.load(path_data)
  88. # todo: specify the possible exception
  89. except Exception as e:
  90. exception = e
  91. time.sleep(delta * random.random())
  92. else:
  93. break
  94. if exception is not None:
  95. # raise the caught exception
  96. raise exception
  97. return res
  98. @staticmethod
  99. def normalize_tensor(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
  100. mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
  101. std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
  102. return tensor.sub(mean).div(std)
  103. def MNIST(*args, **kwargs):
  104. torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
  105. if torchvision_mnist_available:
  106. try:
  107. from torchvision.datasets import MNIST
  108. MNIST(_DATASETS_PATH, download=True)
  109. except HTTPError as e:
  110. print(f"Error {e} downloading `torchvision.datasets.MNIST`")
  111. torchvision_mnist_available = False
  112. if not torchvision_mnist_available:
  113. print("`torchvision.datasets.MNIST` not available. Using our hosted version")
  114. MNIST = _MNIST
  115. return MNIST(*args, **kwargs)
  116. class MNISTDataModule(LightningDataModule):
  117. """Standard MNIST, train, val, test splits and transforms.
  118. >>> MNISTDataModule() # doctest: +ELLIPSIS
  119. <...mnist_datamodule.MNISTDataModule object at ...>
  120. """
  121. name = "mnist"
  122. def __init__(
  123. self,
  124. data_dir: str = _DATASETS_PATH,
  125. val_split: int = 5000,
  126. num_workers: int = 16,
  127. normalize: bool = False,
  128. seed: int = 42,
  129. batch_size: int = 32,
  130. *args,
  131. **kwargs,
  132. ):
  133. """
  134. Args:
  135. data_dir: where to save/load the data
  136. val_split: how many of the training images to use for the validation split
  137. num_workers: how many workers to use for loading data
  138. normalize: If true applies image normalize
  139. seed: starting seed for RNG.
  140. batch_size: desired batch size.
  141. """
  142. super().__init__(*args, **kwargs)
  143. if num_workers and platform.system() == "Windows":
  144. # see: https://stackoverflow.com/a/59680818
  145. warn(
  146. f"You have requested num_workers={num_workers} on Windows,"
  147. " but currently recommended is 0, so we set it for you"
  148. )
  149. num_workers = 0
  150. self.data_dir = data_dir
  151. self.val_split = val_split
  152. self.num_workers = num_workers
  153. self.normalize = normalize
  154. self.seed = seed
  155. self.batch_size = batch_size
  156. self.dataset_train = ...
  157. self.dataset_val = ...
  158. @property
  159. def num_classes(self):
  160. return 10
  161. def prepare_data(self):
  162. """Saves MNIST files to `data_dir`"""
  163. MNIST(self.data_dir, train=True, download=True)
  164. MNIST(self.data_dir, train=False, download=True)
  165. def setup(self, stage: Optional[str] = None):
  166. """Split the train and valid dataset."""
  167. extra = dict(transform=self.default_transforms) if self.default_transforms else {}
  168. dataset = MNIST(self.data_dir, train=True, download=False, **extra)
  169. train_length = len(dataset)
  170. self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
  171. def train_dataloader(self):
  172. """MNIST train set removes a subset to use for validation."""
  173. loader = DataLoader(
  174. self.dataset_train,
  175. batch_size=self.batch_size,
  176. shuffle=True,
  177. num_workers=self.num_workers,
  178. drop_last=True,
  179. pin_memory=True,
  180. )
  181. return loader
  182. def val_dataloader(self):
  183. """MNIST val set uses a subset of the training set for validation."""
  184. loader = DataLoader(
  185. self.dataset_val,
  186. batch_size=self.batch_size,
  187. shuffle=False,
  188. num_workers=self.num_workers,
  189. drop_last=True,
  190. pin_memory=True,
  191. )
  192. return loader
  193. def test_dataloader(self):
  194. """MNIST test set uses the test split."""
  195. extra = dict(transform=self.default_transforms) if self.default_transforms else {}
  196. dataset = MNIST(self.data_dir, train=False, download=False, **extra)
  197. loader = DataLoader(
  198. dataset,
  199. batch_size=self.batch_size,
  200. shuffle=False,
  201. num_workers=self.num_workers,
  202. drop_last=True,
  203. pin_memory=True,
  204. )
  205. return loader
  206. @property
  207. def default_transforms(self):
  208. if not _TORCHVISION_AVAILABLE:
  209. return None
  210. if self.normalize:
  211. mnist_transforms = transform_lib.Compose(
  212. [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
  213. )
  214. else:
  215. mnist_transforms = transform_lib.ToTensor()
  216. return mnist_transforms