autoaugment.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. """AutoAugment data augmentation policy for ImageNet.
  2. -- Begin license text.
  3. MIT License
  4. Copyright (c) 2018 Philip Popien
  5. Permission is hereby granted, free of charge, to any person obtaining a copy
  6. of this software and associated documentation files (the "Software"), to deal
  7. in the Software without restriction, including without limitation the rights
  8. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. copies of the Software, and to permit persons to whom the Software is
  10. furnished to do so, subject to the following conditions:
  11. The above copyright notice and this permission notice shall be included in all
  12. copies or substantial portions of the Software.
  13. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  14. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  15. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  16. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  17. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  18. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  19. SOFTWARE.
  20. -- End license text.
  21. Code adapted from https://github.com/DeepVoltaire/AutoAugment.
  22. This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
  23. Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
  24. policies.
  25. Reference:
  26. [1] https://arxiv.org/abs/1805.09501
  27. """
  28. import random
  29. import numpy as np
  30. from PIL import Image
  31. from PIL import ImageEnhance
  32. from PIL import ImageOps
  33. _MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
  34. class ImageNetPolicy:
  35. """Definition of an ImageNetPolicy.
  36. Implements a fixed AutoAugment data augmentation policy targeted at
  37. ImageNet training by randomly applying at runtime one of the 25 pre-defined
  38. data augmentation sub-policies provided in Reference [1].
  39. Usage example as a Pytorch Transform:
  40. >>> transform=transforms.Compose([transforms.Resize(256),
  41. >>> ImageNetPolicy(),
  42. >>> transforms.ToTensor()])
  43. """
  44. def __init__(self, fillcolor=(128, 128, 128)):
  45. """Initialize an ImageNetPolicy.
  46. Args:
  47. fillcolor (tuple): RGB color components of the color to be used for
  48. filling when needed (default: (128, 128, 128), which
  49. corresponds to gray).
  50. """
  51. # Instantiate a list of sub-policies.
  52. # Each entry of the list is a SubPolicy which consists of
  53. # two augmentation operations,
  54. # each of those parametrized as operation, probability, magnitude.
  55. # Those two operations are applied sequentially on the image upon call.
  56. self.policies = [
  57. SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
  58. SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
  59. SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
  60. SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
  61. SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
  62. SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
  63. SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
  64. SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
  65. SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
  66. SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
  67. SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
  68. SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
  69. SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
  70. SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
  71. SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
  72. SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
  73. SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
  74. SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
  75. SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
  76. SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
  77. SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
  78. SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
  79. SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
  80. SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
  81. SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
  82. ]
  83. def __call__(self, img):
  84. """Define call method for ImageNetPolicy class."""
  85. policy_idx = random.randint(0, len(self.policies) - 1)
  86. return self.policies[policy_idx](img)
  87. def __repr__(self):
  88. """Define repr method for ImageNetPolicy class."""
  89. return "ImageNetPolicy"
  90. class SubPolicy:
  91. """Definition of a SubPolicy.
  92. A SubPolicy consists of two augmentation operations,
  93. each of those parametrized as operation, probability, magnitude.
  94. The two operations are applied sequentially on the image upon call.
  95. """
  96. def __init__(
  97. self,
  98. operation1,
  99. probability1,
  100. magnitude_idx1,
  101. operation2,
  102. probability2,
  103. magnitude_idx2,
  104. fillcolor,
  105. ):
  106. """Initialize a SubPolicy.
  107. Args:
  108. operation1 (str): Key specifying the first augmentation operation.
  109. There are fourteen key values altogether (see supported_ops below
  110. listing supported operations). probability1 (float): Probability
  111. within [0., 1.] of applying the first augmentation operation.
  112. magnitude_idx1 (int): Integer specifiying the strength of the first
  113. operation as an index further used to derive the magnitude from a
  114. range of possible values.
  115. operation2 (str): Key specifying the second augmentation operation.
  116. probability2 (float): Probability within [0., 1.] of applying the
  117. second augmentation operation.
  118. magnitude_idx2 (int): Integer specifiying the strength of the
  119. second operation as an index further used to derive the magnitude
  120. from a range of possible values.
  121. fillcolor (tuple): RGB color components of the color to be used for
  122. filling.
  123. Returns:
  124. """
  125. # List of supported operations for operation1 and operation2.
  126. supported_ops = [
  127. "shearX",
  128. "shearY",
  129. "translateX",
  130. "translateY",
  131. "rotate",
  132. "color",
  133. "posterize",
  134. "solarize",
  135. "contrast",
  136. "sharpness",
  137. "brightness",
  138. "autocontrast",
  139. "equalize",
  140. "invert",
  141. ]
  142. assert (operation1 in supported_ops) and (
  143. operation2 in supported_ops
  144. ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
  145. assert (
  146. 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
  147. ), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
  148. assert (
  149. isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
  150. ), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
  151. assert (
  152. isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
  153. ), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
  154. # Define a dictionary where each key refers to a specific type of
  155. # augmentation and the corresponding value is a range of ten possible
  156. # magnitude values for that augmentation.
  157. num_levels = _MAX_LEVEL + 1
  158. ranges = {
  159. "shearX": np.linspace(0, 0.3, num_levels),
  160. "shearY": np.linspace(0, 0.3, num_levels),
  161. "translateX": np.linspace(0, 150 / 331, num_levels),
  162. "translateY": np.linspace(0, 150 / 331, num_levels),
  163. "rotate": np.linspace(0, 30, num_levels),
  164. "color": np.linspace(0.0, 0.9, num_levels),
  165. "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
  166. np.int
  167. ),
  168. "solarize": np.linspace(256, 0, num_levels), # range [0, 256]
  169. "contrast": np.linspace(0.0, 0.9, num_levels),
  170. "sharpness": np.linspace(0.0, 0.9, num_levels),
  171. "brightness": np.linspace(0.0, 0.9, num_levels),
  172. "autocontrast": [0]
  173. * num_levels, # This augmentation doesn't use magnitude parameter.
  174. "equalize": [0]
  175. * num_levels, # This augmentation doesn't use magnitude parameter.
  176. "invert": [0]
  177. * num_levels, # This augmentation doesn't use magnitude parameter.
  178. }
  179. def rotate_with_fill(img, magnitude):
  180. """Define rotation transformation with fill.
  181. The input image is first rotated, then it is blended together with
  182. a gray mask of the same size. Note that fillcolor as defined
  183. elsewhere in this module doesn't apply here.
  184. Args:
  185. magnitude (float): rotation angle in degrees.
  186. Returns:
  187. rotated_filled (PIL Image): rotated image with gray filling for
  188. disoccluded areas unveiled by the rotation.
  189. """
  190. rotated = img.convert("RGBA").rotate(magnitude)
  191. rotated_filled = Image.composite(
  192. rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
  193. )
  194. return rotated_filled.convert(img.mode)
  195. # Define a dictionary of augmentation functions where each key refers
  196. # to a specific type of augmentation and the corresponding value defines
  197. # the augmentation itself using a lambda function.
  198. # pylint: disable=unnecessary-lambda
  199. func_dict = {
  200. "shearX": lambda img, magnitude: img.transform(
  201. img.size,
  202. Image.AFFINE,
  203. (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
  204. Image.BICUBIC,
  205. fillcolor=fillcolor,
  206. ),
  207. "shearY": lambda img, magnitude: img.transform(
  208. img.size,
  209. Image.AFFINE,
  210. (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
  211. Image.BICUBIC,
  212. fillcolor=fillcolor,
  213. ),
  214. "translateX": lambda img, magnitude: img.transform(
  215. img.size,
  216. Image.AFFINE,
  217. (
  218. 1,
  219. 0,
  220. magnitude * img.size[0] * random.choice([-1, 1]),
  221. 0,
  222. 1,
  223. 0,
  224. ),
  225. fillcolor=fillcolor,
  226. ),
  227. "translateY": lambda img, magnitude: img.transform(
  228. img.size,
  229. Image.AFFINE,
  230. (
  231. 1,
  232. 0,
  233. 0,
  234. 0,
  235. 1,
  236. magnitude * img.size[1] * random.choice([-1, 1]),
  237. ),
  238. fillcolor=fillcolor,
  239. ),
  240. "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
  241. "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
  242. 1 + magnitude * random.choice([-1, 1])
  243. ),
  244. "posterize": lambda img, magnitude: ImageOps.posterize(
  245. img, magnitude
  246. ),
  247. "solarize": lambda img, magnitude: ImageOps.solarize(
  248. img, magnitude
  249. ),
  250. "contrast": lambda img, magnitude: ImageEnhance.Contrast(
  251. img
  252. ).enhance(1 + magnitude * random.choice([-1, 1])),
  253. "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
  254. img
  255. ).enhance(1 + magnitude * random.choice([-1, 1])),
  256. "brightness": lambda img, magnitude: ImageEnhance.Brightness(
  257. img
  258. ).enhance(1 + magnitude * random.choice([-1, 1])),
  259. "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
  260. "equalize": lambda img, magnitude: ImageOps.equalize(img),
  261. "invert": lambda img, magnitude: ImageOps.invert(img),
  262. }
  263. # Store probability, function and magnitude of the first augmentation
  264. # for the sub-policy.
  265. self.probability1 = probability1
  266. self.operation1 = func_dict[operation1]
  267. self.magnitude1 = ranges[operation1][magnitude_idx1]
  268. # Store probability, function and magnitude of the second augmentation
  269. # for the sub-policy.
  270. self.probability2 = probability2
  271. self.operation2 = func_dict[operation2]
  272. self.magnitude2 = ranges[operation2][magnitude_idx2]
  273. def __call__(self, img):
  274. """Define call method for SubPolicy class."""
  275. # Randomly apply operation 1.
  276. if random.random() < self.probability1:
  277. img = self.operation1(img, self.magnitude1)
  278. # Randomly apply operation 2.
  279. if random.random() < self.probability2:
  280. img = self.operation2(img, self.magnitude2)
  281. return img