🚀 The feature, motivation and pitch
There is a growing need to mix and match CUDA allocators in PyTorch for new NVIDIA architectures. For instance:
- Reductions on NVIDIA switches (NVLS reductions) require buffers with a specific alignment (currently facilitated by the ncclMemAlloc API). The buffers then need to be "registered" with the process group to finish the setup for the reduction.
- Extended GPU Memory (EGM) based all-gathers also require buffers with a specific alignment and, optionally, a NUMA location (which can be specified by creating memory with cuMemCreate and CU_MEM_LOCATION_TYPE_HOST_NUMA; see the sketch below).
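For context, a minimal sketch of what such an EGM-style allocation looks like with the CUDA driver API (assumes CUDA 12.2+ for CU_MEM_LOCATION_TYPE_HOST_NUMA; error handling is elided and the helper is illustrative, not part of this proposal):

```cpp
#include <cuda.h>

// Allocate `size` bytes of host memory on `numa_node`, mapped for `device`.
void* egm_numa_alloc(size_t size, int numa_node, CUdevice device) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_HOST_NUMA;
  prop.location.id = numa_node;

  // Round up to the allocation granularity (the required alignment).
  size_t gran = 0;
  cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
  size = (size + gran - 1) / gran * gran;

  CUmemGenericAllocationHandle handle;
  cuMemCreate(&handle, size, &prop, /*flags=*/0);

  // Reserve a VA range, map the physical pages, and grant the GPU access.
  CUdeviceptr dptr;
  cuMemAddressReserve(&dptr, size, gran, /*addr=*/0, /*flags=*/0);
  cuMemMap(dptr, size, /*offset=*/0, handle, /*flags=*/0);

  CUmemAccessDesc access = {};
  access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  access.location.id = device;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  cuMemSetAccess(dptr, size, &access, /*count=*/1);

  return reinterpret_cast<void*>(dptr);
}
```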
Currently, a user cannot mark regions of PyTorch code to use a different allocator while keeping the default allocator for unmarked regions. The only two options available (the allocator backend setting in PYTORCH_CUDA_ALLOC_CONF, or CUDAPluggableAllocator) override the CUDACachingAllocator globally, giving up the benefits of the CUDACachingAllocator.
We propose to expose private pools in a more first-class manner to user land, along with the ability for users to provide their own {allocator, deleter} functions specifying how the blocks in a pool should be allocated. In addition, users should be able to mark when to begin allocating to a pool, when to stop allocating to a pool, and finally when to destroy a pool.
Proposed Approach
- A user can already create a private pool in PyTorch:
  - `torch.cuda.graph_pool_handle()` provides a unique MemPool ID.
  - `torch._C._cuda_beginAllocateCurrentStreamToPool(...)` can use this ID to create a PrivatePool object in the CUDACachingAllocator, managed by the `graph_pools` container.
  - `torch._C._cuda_endAllocateCurrentStreamToPool(...)` can be used to mark when to stop allocating to a pool.
  - `torch._C._cuda_releasePool(...)` can be used to mark when the pool's memory can be deleted or is safe to be reused by other consumers.
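The C++ entry points underneath these bindings look roughly as follows (a sketch based on recent PyTorch sources; exact signatures may differ across versions):

```cpp
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDACachingAllocator.h>

void route_stream_to_private_pool(c10::DeviceIndex device, cudaStream_t stream) {
  // A unique {uid, uuid} pair identifying the private pool.
  c10::cuda::MempoolId_t id = at::cuda::graph_pool_handle();

  // Allocations made on `stream` are now served from the private pool.
  c10::cuda::CUDACachingAllocator::beginAllocateToPool(
      device, id, [stream](cudaStream_t s) { return s == stream; });

  // ... allocate tensors on `stream` here ...

  // Stop routing new allocations to the pool, then allow its blocks to be
  // freed or reused by other consumers.
  c10::cuda::CUDACachingAllocator::endAllocateToPool(device, id);
  c10::cuda::CUDACachingAllocator::releasePool(device, id);
}
```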
- We would like to encapsulate `torch.cuda.graph_pool_handle()` into a class:

```cpp
using CaptureId_t = unsigned long long;
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
using AllocFuncType = void*(size_t);
using DeleteFuncType = void(void*);

struct C10_CUDA_API MemPool {
  MemPool(
      std::function<AllocFuncType> alloc_fn = {},
      std::function<DeleteFuncType> delete_fn = {},
      bool is_user_created = true);

  // new members
  std::mutex mutex_;
  std::function<AllocFuncType> alloc_fn_;
  std::function<DeleteFuncType> delete_fn_;
  // mempool_id_ holds the uid/uuid logic from graph_pool_handle()
  MempoolId_t mempool_id_;
};
```
- An "allocator" function pointer and a "delete" function pointer are added (similar to CUDAPluggableAllocator). The atomic objects (`uid` and `uuid`) are changed to use a mutex, so that we get similar thread semantics when setting `alloc_fn_` and `delete_fn_`.
- `graph_pool_handle()` can then be modified as follows to preserve existing functionality:

```cpp
auto new_pool = c10::cuda::MemPool();
return new_pool.mempool_id_;
```
- Additionally, places using `c10::cuda::MempoolId_t{0, 0}` can be changed to `c10::cuda::MemPool(/*alloc_fn=*/{}, /*delete_fn=*/{}, /*is_user_created=*/false)`. `MempoolId_t{0, 0}` is currently used in the CUDAGraphs `capture_begin` implementation to track whether a pool is user created. Here we can change the logic to carry an explicit bool in the MemPool constructor, and never have a mempool_id of {0, 0}.
- Now that we have a MemPool object that can take allocator and delete function pointers from the user, we can pass this info to the CUDACachingAllocator. We would like to add function pointer members to BlockPool and PrivatePool, so that the pools use these function pointers when they are non-null:

```cpp
BlockPool(
    bool small,
    PrivatePool* private_pool = nullptr,
    AllocFnPtr allocator = nullptr,
    DeleteFnPtr deleter = nullptr);

PrivatePool(
    AllocFnPtr allocator = nullptr,
    DeleteFnPtr deleter = nullptr);
```
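For illustration, the new members could be wired up roughly like this (a hypothetical sketch; the real PrivatePool also tracks use counts and cudaMalloc counts):

```cpp
// Hypothetical sketch: PrivatePool forwards the user's function pointers
// to the BlockPools it owns, per the constructor signatures above.
struct PrivatePool {
  PrivatePool(AllocFnPtr allocator = nullptr, DeleteFnPtr deleter = nullptr)
      : allocator(allocator),
        deleter(deleter),
        large_blocks(/*small=*/false, this, allocator, deleter),
        small_blocks(/*small=*/true, this, allocator, deleter) {}

  // Used by alloc_block/release_block when non-null; the default
  // cudaMalloc/cudaFree paths are used otherwise.
  AllocFnPtr allocator;
  DeleteFnPtr deleter;
  BlockPool large_blocks;
  BlockPool small_blocks;
};
```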
- We would create another container (`user_pools`) in the CUDACachingAllocator, similar to `graph_pools`, to distinguish user pools from CUDAGraph-related pool usage.
- We would create another variable similar to `captures_underway`, so that `get_pool` can pick between `graph_pools` and `user_pools`. `beginAllocateToPool` in the CUDACachingAllocator can then create the PrivatePools with the allocator and deleter functions. `alloc_block` can be modified to use the allocator function pointer:

```cpp
auto allocator = p.pool->allocator;
if (allocator) {
  p.err = allocator(&ptr, size);
} else {
  p.err = cudaMallocMaybeCapturing(&ptr, size);
}
```

  `release_block` can be modified similarly to use the deleter. Tensors can then be allocated with the user-specified allocator and safely deleted with the correct deleter; a sketch of the release path follows.
- In a similar manner, we can modify `endAllocateToPool` and `releasePools`.
- We can then provide a context manager similar to `_use_cuda_memory_pool_manager`, or use the APIs directly in Python:

```python
import ctypes

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group


def nccl_mem_alloc(size):
    nccl = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libnccl.so")
    assert nccl.ncclMemAlloc is not None
    ptr = ctypes.c_void_p()
    err = nccl.ncclMemAlloc(ctypes.byref(ptr), ctypes.c_size_t(size))
    if err != 0:
        raise RuntimeError(f"Failed to allocate memory with ncclMemAlloc with error code: {err}")
    return ptr


def nccl_mem_free(ptr):
    nccl = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libnccl.so")
    assert nccl.ncclMemFree is not None
    err = nccl.ncclMemFree(ctypes.c_void_p(ptr))
    if err != 0:
        raise RuntimeError(f"Failed to free memory with ncclMemFree with error code: {err}")


pool = torch.cuda.MemPool(nccl_mem_alloc, nccl_mem_free)
device = torch.device("cuda:0")
stream = torch.cuda.Stream()

dist.init_process_group(backend="nccl")
default_pg = _get_default_group()
backend = default_pg._get_backend(device)

with torch.cuda.mempool(pool, device, stream):
    special_tensor = torch.randn(2**32, device="cuda")

# Use in distributed for NVLS reduction (pseudocode)
backend.register_user_buffers(pool)

# collective uses NVLS reduction
dist.all_reduce(special_tensor)
```
Alternatives
- CUDAPluggableAllocator: this can be used to override the CUDACachingAllocator, and the use cases mentioned above can be implemented successfully with it. However, it takes over the allocator for the entire lifetime of the program; once it is enabled, there is no way to go back to the CUDACachingAllocator. Even if we could go back, it wouldn't be safe, as tensors wouldn't know the correct deleter function to use.
- Subclassing a Tensor: we could use an approach similar to FBGEMM_GPU, where we modify the Storage of the tensor and plug in the user-specified allocator and deleter there. However, this approach doesn't compose or generalize well, and can only be used for very specific use cases, like a custom UVM-based embedding kernel in TorchRec.
Additional context
cc: @ptrblck @eqy @Aidyn-A @zdevito @kwen2501 @minsii
Relevant PRs: