# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]


class MemoryBuffer:
    """Contiguous memory buffer.

    Allocates a contiguous block of memory of type `dtype` and `numel`
    elements. It is used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
           index of the memory. A memory chunk starting from the `_start`
           index can be `allocated` for an input tensor, with the elements
           of the tensor being copied. The buffer can be reused by resetting
           the `_start` index.

    """

    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return the number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer for `tensor` and copy
        its values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                  flush=True)


class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype, track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        """Advance the ring index and return the next buffer, asserting
        that it is not currently in use."""
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
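

# A minimal usage sketch of the buffer API above. It assumes a CUDA device is
# available and that a single-process 'gloo' process group can be initialized
# so that torch.distributed.get_rank() works inside MemoryBuffer.__init__.
# The buffer names, sizes, and dtype below are arbitrary illustration values.
if __name__ == '__main__':
    import os

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('gloo', rank=0, world_size=1)

    # Allocate a 1M-element fp16 buffer and copy two tensors into it.
    buff = allocate_mem_buff('demo', 1024 * 1024, torch.half, track_usage=True)
    a = buff.add(torch.ones(256, 512, dtype=torch.half, device='cuda'))
    b = buff.add(torch.zeros(128, 1024, dtype=torch.half, device='cuda'))
    # `a` is a view into the start of the underlying buffer, not a copy.
    assert a.data_ptr() == buff.data.data_ptr()

    buff.get_data()            # Records usage statistics for this pass.
    buff.print_average_usage()
    buff.reset()               # Makes the whole buffer reusable again.

    # A ring of two buffers, cycled round-robin via get_next_buffer().
    ring = RingMemBuffer('ring demo', 2, 1024, torch.half, track_usage=False)
    first = ring.get_next_buffer()

    torch.distributed.destroy_process_group()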