group_by_aspect_ratio.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import bisect
  2. from collections import defaultdict
  3. import copy
  4. from itertools import repeat, chain
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.utils.data
  9. from torch.utils.data.sampler import BatchSampler, Sampler
  10. from torch.utils.model_zoo import tqdm
  11. import torchvision
  12. from PIL import Image
  def _repeat_to_at_least(iterable, n):
      """Return ``iterable`` repeated end-to-end until it holds at least ``n`` items.

      Whole copies are concatenated, so the result may be longer than ``n``;
      callers slice it down to the size they need.

      Args:
          iterable: a sized collection (``len`` is taken) that can be iterated
              once per repetition, e.g. a list.
          n (int): minimum number of elements wanted.

      Returns:
          list: ``ceil(n / len(iterable))`` concatenated copies of ``iterable``,
          or ``[]`` when ``n <= 0``.

      Raises:
          ValueError: if ``iterable`` is empty while ``n > 0`` -- no amount of
              repetition can produce ``n`` elements.
      """
      if n <= 0:
          return []
      size = len(iterable)
      if size == 0:
          raise ValueError(
              "cannot repeat an empty iterable to reach {} elements".format(n))
      repeat_times = math.ceil(n / size)
      return list(chain.from_iterable(repeat(iterable, repeat_times)))
  class GroupedBatchSampler(BatchSampler):
      """
      Wraps another sampler to yield mini-batches of indices.

      It enforces that each batch only contains elements from the same group.
      It also tries to provide mini-batches that follow an ordering as close
      as possible to the ordering of the original sampler.

      Exactly ``len(sampler) // batch_size`` batches are produced per pass.
      Once the base sampler is exhausted, leftover partial batches are padded
      with (possibly repeated) samples of their own group, fullest buffers
      first, until that count is reached; any other partial batches are
      dropped.

      Arguments:
          sampler (Sampler): Base sampler.
          group_ids (list[int]): If the sampler produces indices in range [0, N),
              `group_ids` must be a list of `N` ints which contains the group id
              of each sample. The group ids must be a continuous set of integers
              starting from 0, i.e. they must be in the range [0, num_groups).
          batch_size (int): Size of mini-batch; must be positive.

      Raises:
          ValueError: if ``sampler`` is not a ``torch.utils.data.Sampler`` or
              ``batch_size`` is not positive.
      """

      def __init__(self, sampler, group_ids, batch_size):
          if not isinstance(sampler, Sampler):
              raise ValueError(
                  "sampler should be an instance of "
                  "torch.utils.data.Sampler, but got sampler={}".format(sampler)
              )
          # Reject up front; otherwise __len__ fails later with a confusing
          # ZeroDivisionError (0) or a negative-length TypeError (< 0).
          if batch_size <= 0:
              raise ValueError(
                  "batch_size should be a positive integer, "
                  "but got batch_size={}".format(batch_size)
              )
          self.sampler = sampler
          self.group_ids = group_ids
          self.batch_size = batch_size

      def __iter__(self):
          buffer_per_group = defaultdict(list)   # current partial batch of each group
          samples_per_group = defaultdict(list)  # every index seen so far, per group
          num_batches = 0
          for idx in self.sampler:
              group_id = self.group_ids[idx]
              buffer_per_group[group_id].append(idx)
              samples_per_group[group_id].append(idx)
              if len(buffer_per_group[group_id]) == self.batch_size:
                  # pop() hands the full batch over and resets the group's buffer.
                  # Buffers are only touched through append/pop, never read by key
                  # afterwards: reading a missing key of a defaultdict would
                  # silently re-insert an empty buffer and reorder the dict.
                  yield buffer_per_group.pop(group_id)
                  num_batches += 1
          # Now we have run out of elements that satisfy the group criteria;
          # return the remaining elements so that the number of batches is
          # deterministic and always equal to len(self).
          num_remaining = len(self) - num_batches
          if num_remaining > 0:
              # Take the buffers with the most elements first, so the fewest
              # padding samples are needed.
              by_size = sorted(buffer_per_group.items(),
                               key=lambda item: len(item[1]), reverse=True)
              for group_id, buffer in by_size:
                  missing = self.batch_size - len(buffer)
                  # Pad with this group's own samples, cycling through them when
                  # the group has fewer than `missing` samples in total.
                  padding = _repeat_to_at_least(samples_per_group[group_id], missing)
                  buffer.extend(padding[:missing])
                  assert len(buffer) == self.batch_size
                  yield buffer
                  num_remaining -= 1
                  if num_remaining == 0:
                      break
          assert num_remaining == 0

      def __len__(self):
          """Number of batches per pass: ``len(sampler) // batch_size``."""
          return len(self.sampler) // self.batch_size
  def _compute_aspect_ratios_slow(dataset, indices=None, num_workers=14):
      """Compute width / height aspect ratios by fully loading each image.

      Fallback for datasets that expose no cheap way to read image sizes:
      every selected sample is loaded through a ``DataLoader``.

      Args:
          dataset: map-style dataset whose items unpack as ``(img, target)``.
              ``img.shape[-2:]`` is read as ``(height, width)``, so this assumes
              a channels-first image such as a ``C x H x W`` tensor -- confirm
              for new datasets.
          indices (iterable[int], optional): dataset indices to process, in
              order. Defaults to all indices.
          num_workers (int, optional): DataLoader worker processes used to load
              images. Defaults to 14; increase it for faster processing, or use
              0 to load in the main process.

      Returns:
          list[float]: one aspect ratio per selected index, in ``indices`` order.
      """
      print("Your dataset doesn't support the fast path for "
            "computing the aspect ratios, so will iterate over "
            "the full dataset and load every image instead. "
            "This might take some time...")
      # Materialize once: the progress bar needs the count, and any iterable
      # (even a one-shot generator) becomes safely re-iterable for the sampler.
      indices = list(range(len(dataset)) if indices is None else indices)

      class SubsetSampler(Sampler):
          def __init__(self, indices):
              self.indices = indices

          def __iter__(self):
              return iter(self.indices)

          def __len__(self):
              return len(self.indices)

      data_loader = torch.utils.data.DataLoader(
          dataset, batch_size=1, sampler=SubsetSampler(indices),
          num_workers=num_workers,
          # batch_size=1: unwrap the single sample instead of collating it
          collate_fn=lambda x: x[0])

      aspect_ratios = []
      # Size the progress bar by the samples actually visited, not the dataset.
      with tqdm(total=len(indices)) as pbar:
          for img, _ in data_loader:
              pbar.update(1)
              height, width = img.shape[-2:]
              aspect_ratios.append(float(width) / float(height))
      return aspect_ratios
  def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
      """Compute width / height ratios via the dataset's own size accessor.

      Args:
          dataset: object providing ``get_height_and_width(i)`` that returns a
              ``(height, width)`` pair, plus ``__len__`` when ``indices`` is
              omitted.
          indices (iterable[int], optional): indices to query, in order.
              Defaults to every index of ``dataset``.

      Returns:
          list[float]: one aspect ratio per queried index.
      """
      if indices is None:
          indices = range(len(dataset))

      def ratio_at(position):
          h, w = dataset.get_height_and_width(position)
          return float(w) / float(h)

      return [ratio_at(position) for position in indices]
  def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
      """Compute width / height ratios from COCO image metadata, without I/O.

      Args:
          dataset: dataset exposing ``ids`` (position -> image id) and
              ``coco.imgs`` (image id -> dict with ``"width"`` and ``"height"``).
          indices (iterable[int], optional): positions to process, in order.
              Defaults to every position of ``dataset``.

      Returns:
          list[float]: one aspect ratio per processed position.
      """
      if indices is None:
          indices = range(len(dataset))
      # Lazily look up each image's metadata record as the list is built.
      records = (dataset.coco.imgs[dataset.ids[pos]] for pos in indices)
      return [float(rec["width"]) / float(rec["height"]) for rec in records]
  def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
      """Compute width / height aspect ratios for a VOC-style dataset.

      Only image headers are needed: ``PIL.Image.open`` is lazy, so ``.size``
      is available without decoding the pixel data.

      Args:
          dataset: dataset exposing an ``images`` sequence whose entries are
              accepted by ``Image.open`` (presumably file paths -- confirm).
          indices (iterable[int], optional): positions to process, in order.
              Defaults to every position of ``dataset``.

      Returns:
          list[float]: one aspect ratio per processed position.
      """
      if indices is None:
          indices = range(len(dataset))
      aspect_ratios = []
      for i in indices:
          # The context manager closes the file as soon as the size is read,
          # instead of leaving one open handle per image until GC runs.
          with Image.open(dataset.images[i]) as img:
              width, height = img.size
          aspect_ratios.append(float(width) / float(height))
      return aspect_ratios
  def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
      """Compute aspect ratios for a ``Subset`` by delegating to its parent.

      Positions within the subset are translated through ``dataset.indices``
      into positions of ``dataset.dataset``, which is then dispatched through
      ``compute_aspect_ratios`` so the parent's fast path (if any) is used.

      Args:
          dataset: a ``torch.utils.data.Subset``-like object with ``indices``
              and ``dataset`` attributes.
          indices (iterable[int], optional): positions within the subset, in
              order. Defaults to every position of the subset.

      Returns:
          list[float]: one aspect ratio per selected subset position.
      """
      positions = range(len(dataset)) if indices is None else indices
      parent_indices = [dataset.indices[pos] for pos in positions]
      return compute_aspect_ratios(dataset.dataset, parent_indices)
  def compute_aspect_ratios(dataset, indices=None):
      """Return the width / height aspect ratio of each selected sample.

      Dispatches to the cheapest available strategy, checked in this order:
      a ``get_height_and_width`` method on the dataset, ``CocoDetection``
      metadata, ``VOCDetection`` image headers, unwrapping a ``Subset``, and
      finally loading every image (slow path).

      Args:
          dataset: the dataset to inspect.
          indices (iterable[int], optional): positions to process, in order.
              Defaults to every position of ``dataset``.

      Returns:
          list[float]: one aspect ratio per selected position.
      """
      if hasattr(dataset, "get_height_and_width"):
          strategy = _compute_aspect_ratios_custom_dataset
      elif isinstance(dataset, torchvision.datasets.CocoDetection):
          strategy = _compute_aspect_ratios_coco_dataset
      elif isinstance(dataset, torchvision.datasets.VOCDetection):
          strategy = _compute_aspect_ratios_voc_dataset
      elif isinstance(dataset, torch.utils.data.Subset):
          strategy = _compute_aspect_ratios_subset_dataset
      else:
          # slow path: no metadata shortcut, so every image gets loaded
          strategy = _compute_aspect_ratios_slow
      return strategy(dataset, indices)
  def _quantize(x, bins):
      """Map each value in ``x`` to the index of the bin it falls into.

      Uses ``bisect_right`` over the sorted edges: a value gets the number of
      edges ``<= value``. Values below the smallest edge map to 0, and values
      at or above the largest edge map to ``len(bins)``.

      Args:
          x (iterable[float]): values to quantize.
          bins (iterable[float]): bin edges in any order; never modified.

      Returns:
          list[int]: one bin index per value, each in ``[0, len(bins)]``.
      """
      # sorted() already builds a new list, so the caller's ``bins`` is never
      # mutated and no defensive deep copy is needed.
      sorted_bins = sorted(bins)
      return [bisect.bisect_right(sorted_bins, value) for value in x]
  def create_aspect_ratio_groups(dataset, k=0):
      """Assign every sample of ``dataset`` to an aspect-ratio group.

      Aspect ratios (width / height) are quantized against ``2 * k + 1`` bin
      edges ``2 ** linspace(-1, 1, 2 * k + 1)``, i.e. log-spaced from 0.5 to 2,
      or against the single edge 1.0 when ``k <= 0``. The bins and the number
      of samples per bin are printed.

      Args:
          dataset: any dataset accepted by ``compute_aspect_ratios``.
          k (int): controls the number of bin edges (see above).

      Returns:
          list[int]: group id per sample in dataset order, each in
          ``[0, len(bins)]``.
      """
      aspect_ratios = compute_aspect_ratios(dataset)
      bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
      groups = _quantize(aspect_ratios, bins)
      # Count per bin *including empty bins*, so entry i lines up with the
      # interval [fbins[i], fbins[i + 1]) printed below. np.unique would drop
      # empty bins and shift the counts. The explicit int dtype keeps bincount
      # working when the dataset is empty.
      counts = np.bincount(np.asarray(groups, dtype=np.int64),
                           minlength=len(bins) + 1)
      fbins = [0] + bins + [np.inf]
      print("Using {} as bins for aspect ratio quantization".format(fbins))
      print("Count of instances per bin: {}".format(counts))
      return groups