### 🐛 Describe the bug
I'm trying to use `CUDAPluggableAllocator`, following https://pytorch.org/docs/stable/notes/cuda.html#using-custom-memory-allocators-for-cuda . However, it has a critical limitation: `torch.cuda.memory.change_current_allocator` must be called before any allocation takes place, so we cannot switch allocators afterwards.
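For reference, this is the failure mode (a minimal sketch, assuming an `alloc.so` exporting `my_malloc`/`my_free` like the one built below):

```python
import torch

# Any prior allocation initializes the default caching allocator:
x = torch.empty(1, device="cuda")
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    "alloc.so", "my_malloc", "my_free")
# Raises RuntimeError, since the default allocator was already used:
torch.cuda.memory.change_current_allocator(new_alloc)
```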
Following @syed-ahmed's suggestion, I'm trying to use `CUDAPluggableAllocator` with `MemPool` instead, and it seems to work, in the sense that I can switch between allocators. However, I find that, used this way, the pool never returns memory to the underlying allocator.
Here is a simple demonstration code snippet:
```python
import torch
import torch.utils.cpp_extension

cpp_sources = """
// (standalone: save as alloc.cc and compile with
//  g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC)
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

extern "C" {

// Allocation hook: logs and forwards to cudaMalloc.
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
    void* ptr;
    cudaMalloc(&ptr, size);
    std::cout << "C side: alloc " << ptr << " " << size << std::endl;
    return ptr;
}

// Deallocation hook: logs and forwards to cudaFree.
void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
    std::cout << "C side: free " << ptr << " " << size << std::endl;
    cudaFree(ptr);
}

// Hack: add this placeholder function to let PyTorch generate the
// module extension template.
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
    return x.sin() + y.sin();
}

}
"""

module = torch.utils.cpp_extension.load_inline(
    "alloc", cpp_sources, with_cuda=True, functions=["sin_add"])
so_file = module.__file__


def f():
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        so_file, "my_malloc", "my_free")
    with torch.cuda.use_mem_pool(torch.cuda.MemPool(new_alloc._allocator)):
        for factor in (1024, 1024 ** 2):
            print(f"Allocate {60 * factor} bytes of memory on the GPU from Python")
            data = torch.empty((60, factor), dtype=torch.uint8, device="cuda")
            print(f"Free {60 * factor} bytes of memory on the GPU from Python")
            del data
            print("Python side: memory is released")
            print(f"Allocate {70 * factor} bytes of memory on the GPU from Python")
            data = torch.empty((70, factor), dtype=torch.uint8, device="cuda")
            print(f"Free {70 * factor} bytes of memory on the GPU from Python")
            del data
            print("Python side: memory is released")
        # torch.cuda.empty_cache() here will error:
        # RuntimeError: captures_underway.empty() INTERNAL ASSERT FAILED at
        # "../c10/cuda/CUDACachingAllocator.cpp":2967, please report a bug to PyTorch.
    # torch.cuda.empty_cache() here does not take effect.


f()

import gc
gc.collect()
```
Running the code, we can see that `C side: alloc` is called properly. However, `C side: free` is never called.
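After the script finishes (and `gc.collect()` has run), the reserved memory also stays nonzero, assuming `MemPool` blocks are counted in the caching allocator's statistics:

```python
# Run after the snippet above; all tensors are gone on the Python side.
# Assumption: blocks owned by the MemPool show up in reserved_bytes.
print(torch.cuda.memory_reserved())  # still > 0: my_free was never called
```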
In addition, if I call `torch.cuda.empty_cache()` inside the `with torch.cuda.use_mem_pool(...)` block, it triggers an assertion error.
Ultimately, my goal is to switch between `CUDAPluggableAllocator` and the default allocator, and also to `empty_cache` the memory held on behalf of the `CUDAPluggableAllocator`.
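Roughly, I'd like something along these lines to work (a hypothetical sketch of the desired behavior, not something that works today):

```python
pool = torch.cuda.MemPool(new_alloc._allocator)
with torch.cuda.use_mem_pool(pool):
    data = torch.empty(1024, device="cuda")  # served by my_malloc
    del data
# Desired: return the pool's cached blocks to my_free. Today this call
# either has no effect (outside the context) or trips the internal
# assertion (inside the context).
torch.cuda.empty_cache()
```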
### Versions
PyTorch 2.5.1+cu124