# utils.py
from collections import defaultdict, deque
import datetime
import errno
import os
import pickle
import time

import torch
import torch.distributed as dist


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
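
# Example (illustrative sketch, not part of the original module): tracking a
# per-iteration loss with the default 20-step smoothing window.
#
#   loss_meter = SmoothedValue(fmt='{median:.4f} ({global_avg:.4f})')
#   for step in range(100):
#       loss_meter.update(compute_loss())   # compute_loss() is a hypothetical helper
#   print(str(loss_meter))                  # e.g. "0.3124 (0.4032)"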


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
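
# Example (illustrative sketch, not part of the original module): gathering
# per-rank evaluation results into one list on every process. Assumes
# torch.distributed has already been initialized (e.g. via init_distributed_mode).
#
#   local_results = {'rank': get_rank(), 'num_images': 500}   # hypothetical payload
#   merged = all_gather(local_results)
#   # merged is a list with world_size entries, identical on every rank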


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
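
# Example (illustrative sketch, not part of the original module): averaging a
# dictionary of scalar loss tensors across ranks for logging. The keys and
# values below are hypothetical.
#
#   loss_dict = {'loss_classifier': loss_cls, 'loss_box_reg': loss_box}
#   loss_dict_reduced = reduce_dict(loss_dict)   # averaged over world_size
#   total_loss = sum(loss for loss in loss_dict_reduced.values())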


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
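
# Example (illustrative sketch, not part of the original module): driving a
# training loop with MetricLogger. `data_loader`, `model`, and `loss_dict` are
# hypothetical names.
#
#   metric_logger = MetricLogger(delimiter="  ")
#   metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
#   for images, targets in metric_logger.log_every(data_loader, 10, header='Epoch: [0]'):
#       loss_dict = model(images, targets)
#       metric_logger.update(loss=sum(loss_dict.values()), **loss_dict)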


def collate_fn(batch):
    # Transpose a list of (image, target) samples into a tuple of images and a
    # tuple of targets, so detection models can receive variable-sized inputs
    # without stacking them into a single tensor.
    return tuple(zip(*batch))
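
# Example (illustrative sketch, not part of the original module): using
# collate_fn with a DataLoader for a detection dataset. `dataset` is hypothetical.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
#   images, targets = next(iter(loader))   # each is a tuple of length batch_size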


def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):

    def f(x):
        if x >= warmup_iters:
            return 1
        alpha = float(x) / warmup_iters
        # Linearly ramp the LR multiplier from warmup_factor up to 1 over warmup_iters steps.
        return warmup_factor * (1 - alpha) + alpha

    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
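
# Example (illustrative sketch, not part of the original module): warming up
# the learning rate during the first epoch, stepping the scheduler once per
# iteration. `optimizer` and `data_loader` are hypothetical.
#
#   lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters=1000, warmup_factor=1.0 / 1000)
#   for images, targets in data_loader:
#       ...  # forward / backward / optimizer.step()
#       lr_scheduler.step()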


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
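
# Example (illustrative sketch, not part of the original module): the argparse
# namespace that init_distributed_mode reads from. The argument names below are
# assumptions based on the attributes the function accesses (`world_size`, `dist_url`).
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--world-size', default=1, type=int)
#   parser.add_argument('--dist-url', default='env://')
#   args = parser.parse_args()
#   init_distributed_mode(args)   # typically launched via torchrun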