# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.

    Allocate a contiguous block of memory of type `dtype` and size `numel`.
    It is used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
    index of the memory. A memory chunk starting from the `_start` index
    can be `allocated` for an input tensor, with the elements of the
    tensor being copied. The buffer can be reused by resetting the
    `_start` index.

    """

    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return the number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer for the input tensor
        and copy its values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                flush=True)


class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        """Advance to the next buffer in the ring and return it, asserting
        that it is not already in use."""
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
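

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): it shows how a
# buffer might be allocated, filled, tracked, and reused, and how a ring of
# buffers hands out the next free buffer. It assumes a single-process 'gloo'
# process group and at least one visible CUDA device; the MASTER_ADDR /
# MASTER_PORT values and the helper name `_demo_usage` are placeholders for
# this example only.
# ---------------------------------------------------------------------------
def _demo_usage():
    import os

    # The buffer queries torch.distributed.get_rank(), so a process group
    # must be initialized even for a single process.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('gloo', rank=0, world_size=1)

    # Allocate a 1M-element fp16 buffer and copy two tensors into it.
    buff = allocate_mem_buff('demo', 1024 * 1024, torch.half, track_usage=True)
    a = buff.add(torch.ones(256, 512, dtype=torch.half))
    b = buff.add(torch.zeros(128, 512, dtype=torch.half))
    assert buff.numel_in_use() == a.numel() + b.numel()

    # Record usage statistics, then reuse the same memory for the next step.
    buff.get_data()
    buff.reset()
    buff.print_average_usage()

    # A ring of two buffers lets one be filled while the other is still held.
    ring = RingMemBuffer('demo ring', 2, 1024, torch.half, track_usage=False)
    chunk = ring.get_next_buffer().add(torch.ones(4, 4, dtype=torch.half))
    print(chunk.shape)

    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    _demo_usage()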