memory.py 4.9 KB

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()

def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]

class MemoryBuffer:
    """Contiguous memory buffer.

    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
        index of the memory. A memory chunk starting from the `_start` index
        can be `allocated` for an input tensor, with the elements of the
        tensor being copied. The buffer can be reused by resetting the
        `_start` index.
    """
    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0
    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return the number of elements in use."""
        return self._start
    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor
    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                flush=True)

class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        """Advance the ring index and return the next buffer, asserting
        that it is not already in use."""
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
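

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original module): the block below illustrates
# how MemoryBuffer and RingMemBuffer might be exercised. It assumes a single
# process with one CUDA device; the single-rank `gloo` process-group init, the
# MASTER_ADDR/MASTER_PORT values, and the buffer names and sizes are
# assumptions made only so the example can run stand-alone.
if __name__ == '__main__':
    import os

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)

    # Allocate a 1M-element fp16 buffer with usage tracking enabled.
    buff = allocate_mem_buff('activations', 1024 * 1024, torch.half,
                             track_usage=True)

    # Copy a tensor into the buffer; the returned tensor is a view into it.
    x = torch.randn(256, 1024, dtype=torch.half,
                    device=torch.cuda.current_device())
    x_view = buff.add(x)
    assert x_view.shape == x.shape

    # Reading the in-use region also records usage statistics.
    _ = buff.get_data()
    buff.print_average_usage()

    # Reuse the buffer for the next iteration.
    buff.reset()
    assert not buff.is_in_use()

    # A ring of two buffers lets consecutive iterations alternate between
    # memory chunks instead of immediately overwriting the same one.
    ring = RingMemBuffer('activations ring', 2, 1024 * 1024, torch.half,
                         track_usage=False)
    b0 = ring.get_next_buffer()
    _ = b0.add(x)
    b1 = ring.get_next_buffer()
    b0.reset()
    b1.reset()

    torch.distributed.destroy_process_group()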