# coco_utils.py
  1. import copy
  2. import os
  3. from PIL import Image
  4. import torch
  5. import torch.utils.data
  6. import torchvision
  7. from pycocotools import mask as coco_mask
  8. from pycocotools.coco import COCO
  9. import transforms as T
  10. class FilterAndRemapCocoCategories(object):
  11. def __init__(self, categories, remap=True):
  12. self.categories = categories
  13. self.remap = remap
  14. def __call__(self, image, target):
  15. anno = target["annotations"]
  16. anno = [obj for obj in anno if obj["category_id"] in self.categories]
  17. if not self.remap:
  18. target["annotations"] = anno
  19. return image, target
  20. anno = copy.deepcopy(anno)
  21. for obj in anno:
  22. obj["category_id"] = self.categories.index(obj["category_id"])
  23. target["annotations"] = anno
  24. return image, target
  25. def convert_coco_poly_to_mask(segmentations, height, width):
  26. masks = []
  27. for polygons in segmentations:
  28. rles = coco_mask.frPyObjects(polygons, height, width)
  29. mask = coco_mask.decode(rles)
  30. if len(mask.shape) < 3:
  31. mask = mask[..., None]
  32. mask = torch.as_tensor(mask, dtype=torch.uint8)
  33. mask = mask.any(dim=2)
  34. masks.append(mask)
  35. if masks:
  36. masks = torch.stack(masks, dim=0)
  37. else:
  38. masks = torch.zeros((0, height, width), dtype=torch.uint8)
  39. return masks
  40. class ConvertCocoPolysToMask(object):
  41. def __call__(self, image, target):
  42. w, h = image.size
  43. image_id = target["image_id"]
  44. image_id = torch.tensor([image_id])
  45. anno = target["annotations"]
  46. anno = [obj for obj in anno if obj['iscrowd'] == 0]
  47. boxes = [obj["bbox"] for obj in anno]
  48. # guard against no boxes via resizing
  49. boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
  50. boxes[:, 2:] += boxes[:, :2]
  51. boxes[:, 0::2].clamp_(min=0, max=w)
  52. boxes[:, 1::2].clamp_(min=0, max=h)
  53. classes = [obj["category_id"] for obj in anno]
  54. classes = torch.tensor(classes, dtype=torch.int64)
  55. segmentations = [obj["segmentation"] for obj in anno]
  56. masks = convert_coco_poly_to_mask(segmentations, h, w)
  57. keypoints = None
  58. if anno and "keypoints" in anno[0]:
  59. keypoints = [obj["keypoints"] for obj in anno]
  60. keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
  61. num_keypoints = keypoints.shape[0]
  62. if num_keypoints:
  63. keypoints = keypoints.view(num_keypoints, -1, 3)
  64. keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
  65. boxes = boxes[keep]
  66. classes = classes[keep]
  67. masks = masks[keep]
  68. if keypoints is not None:
  69. keypoints = keypoints[keep]
  70. target = {}
  71. target["boxes"] = boxes
  72. target["labels"] = classes
  73. target["masks"] = masks
  74. target["image_id"] = image_id
  75. if keypoints is not None:
  76. target["keypoints"] = keypoints
  77. # for conversion to coco api
  78. area = torch.tensor([obj["area"] for obj in anno])
  79. iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
  80. target["area"] = area
  81. target["iscrowd"] = iscrowd
  82. return image, target
  83. def _coco_remove_images_without_annotations(dataset, cat_list=None):
  84. def _has_only_empty_bbox(anno):
  85. return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
  86. def _count_visible_keypoints(anno):
  87. return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
  88. min_keypoints_per_image = 10
  89. def _has_valid_annotation(anno):
  90. # if it's empty, there is no annotation
  91. if len(anno) == 0:
  92. return False
  93. # if all boxes have close to zero area, there is no annotation
  94. if _has_only_empty_bbox(anno):
  95. return False
  96. # keypoints task have a slight different critera for considering
  97. # if an annotation is valid
  98. if "keypoints" not in anno[0]:
  99. return True
  100. # for keypoint detection tasks, only consider valid images those
  101. # containing at least min_keypoints_per_image
  102. if _count_visible_keypoints(anno) >= min_keypoints_per_image:
  103. return True
  104. return False
  105. assert isinstance(dataset, torchvision.datasets.CocoDetection)
  106. ids = []
  107. for ds_idx, img_id in enumerate(dataset.ids):
  108. ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
  109. anno = dataset.coco.loadAnns(ann_ids)
  110. if cat_list:
  111. anno = [obj for obj in anno if obj["category_id"] in cat_list]
  112. if _has_valid_annotation(anno):
  113. ids.append(ds_idx)
  114. dataset = torch.utils.data.Subset(dataset, ids)
  115. return dataset
  116. def convert_to_coco_api(ds):
  117. coco_ds = COCO()
  118. # annotation IDs need to start at 1, not 0, see torchvision issue #1530
  119. ann_id = 1
  120. dataset = {'images': [], 'categories': [], 'annotations': []}
  121. categories = set()
  122. for img_idx in range(len(ds)):
  123. # find better way to get target
  124. # targets = ds.get_annotations(img_idx)
  125. img, targets = ds[img_idx]
  126. image_id = targets["image_id"].item()
  127. img_dict = {}
  128. img_dict['id'] = image_id
  129. img_dict['height'] = img.shape[-2]
  130. img_dict['width'] = img.shape[-1]
  131. dataset['images'].append(img_dict)
  132. bboxes = targets["boxes"]
  133. bboxes[:, 2:] -= bboxes[:, :2]
  134. bboxes = bboxes.tolist()
  135. labels = targets['labels'].tolist()
  136. areas = targets['area'].tolist()
  137. iscrowd = targets['iscrowd'].tolist()
  138. if 'masks' in targets:
  139. masks = targets['masks']
  140. # make masks Fortran contiguous for coco_mask
  141. masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
  142. if 'keypoints' in targets:
  143. keypoints = targets['keypoints']
  144. keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
  145. num_objs = len(bboxes)
  146. for i in range(num_objs):
  147. ann = {}
  148. ann['image_id'] = image_id
  149. ann['bbox'] = bboxes[i]
  150. ann['category_id'] = labels[i]
  151. categories.add(labels[i])
  152. ann['area'] = areas[i]
  153. ann['iscrowd'] = iscrowd[i]
  154. ann['id'] = ann_id
  155. if 'masks' in targets:
  156. ann["segmentation"] = coco_mask.encode(masks[i].numpy())
  157. if 'keypoints' in targets:
  158. ann['keypoints'] = keypoints[i]
  159. ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
  160. dataset['annotations'].append(ann)
  161. ann_id += 1
  162. dataset['categories'] = [{'id': i} for i in sorted(categories)]
  163. coco_ds.dataset = dataset
  164. coco_ds.createIndex()
  165. return coco_ds
  166. def get_coco_api_from_dataset(dataset):
  167. for _ in range(10):
  168. if isinstance(dataset, torchvision.datasets.CocoDetection):
  169. break
  170. if isinstance(dataset, torch.utils.data.Subset):
  171. dataset = dataset.dataset
  172. if isinstance(dataset, torchvision.datasets.CocoDetection):
  173. return dataset.coco
  174. return convert_to_coco_api(dataset)
  175. class CocoDetection(torchvision.datasets.CocoDetection):
  176. def __init__(self, img_folder, ann_file, transforms):
  177. super(CocoDetection, self).__init__(img_folder, ann_file)
  178. self._transforms = transforms
  179. def __getitem__(self, idx):
  180. img, target = super(CocoDetection, self).__getitem__(idx)
  181. image_id = self.ids[idx]
  182. target = dict(image_id=image_id, annotations=target)
  183. if self._transforms is not None:
  184. img, target = self._transforms(img, target)
  185. return img, target
  186. def get_coco(root, image_set, transforms, mode='instances'):
  187. anno_file_template = "{}_{}2017.json"
  188. PATHS = {
  189. "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
  190. "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
  191. # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
  192. }
  193. t = [ConvertCocoPolysToMask()]
  194. if transforms is not None:
  195. t.append(transforms)
  196. transforms = T.Compose(t)
  197. img_folder, ann_file = PATHS[image_set]
  198. img_folder = os.path.join(root, img_folder)
  199. ann_file = os.path.join(root, ann_file)
  200. dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
  201. if image_set == "train":
  202. dataset = _coco_remove_images_without_annotations(dataset)
  203. # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
  204. return dataset
  205. def get_coco_kp(root, image_set, transforms):
  206. return get_coco(root, image_set, transforms, mode="person_keypoints")