
(With PR) torch.cuda.MemPool() internal assertion failure when changing devices #149802

@fzyzcjy

Description

Potential cause analysis

Some quick thoughts after glancing at the code:

  • When the MemPool is created on device 0, a pool is registered on device 0; say it gets mempool_id 111
  • On the first call to use_mem_pool, Python asks the C++ side to find the mempool with id 111 on device 1 (!); it does not exist there, so the C++ side creates a brand-new pool with id 111 on device 1
  • When the first use_mem_pool context exits, the refcount of pool 111 on device 1 is decremented and drops to zero
  • On the second call to use_mem_pool, the C++ side finds pool 111 on device 1, but sees that its refcount is zero, hence the assertion failure (a minimal model of this bookkeeping is sketched after this list)
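
To make the suspected refcount lifecycle concrete, here is a small Python model of the bookkeeping hypothesized above. The names (create_mempool, begin_allocate_to_pool, end_allocate_to_pool) are made up for this sketch and only approximate the real CUDACachingAllocator internals:

# Illustrative model of the hypothesized per-device pool bookkeeping.
# This is not the real allocator code, just the failure mode described above.
pools = {}  # (device, mempool_id) -> use_count

def create_mempool(device, mempool_id):
    # torch.cuda.MemPool() registers the pool on the *current* device and
    # holds one reference for the lifetime of the Python object.
    pools[(device, mempool_id)] = 1

def begin_allocate_to_pool(device, mempool_id):
    key = (device, mempool_id)
    if key not in pools:
        # First use on device 1: the pool was only registered on device 0,
        # so a brand-new pool is created here with a single reference.
        pools[key] = 1
    else:
        # Second use on device 1: the entry exists, but its refcount has
        # already dropped to zero, which trips the internal assertion.
        assert pools[key] > 0, "it->second->use_count > 0 INTERNAL ASSERT FAILED"
        pools[key] += 1

def end_allocate_to_pool(device, mempool_id):
    pools[(device, mempool_id)] -= 1

create_mempool(0, 111)           # MemPool created while device 0 is current
begin_allocate_to_pool(1, 111)   # first use_mem_pool on device 1: new pool, use_count = 1
end_allocate_to_pool(1, 111)     # context exit: use_count drops to 0
begin_allocate_to_pool(1, 111)   # second use_mem_pool: assertion fires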

A quick fix could be to add an assertion when using a MemPool: if the user is on the wrong device, throw and forbid the action (see the sketch below). A more elaborate fix would be to support using pools across different devices.
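
Concretely, the assertion-style fix could look roughly like the wrapper below. creation_device is a hypothetical attribute (the current MemPool API does not expose one), and the real check would presumably live on the C++ side rather than in a Python wrapper:

from contextlib import contextmanager

import torch

@contextmanager
def use_mem_pool_checked(pool):
    # Hypothetical device check: refuse to use a pool on a device other than
    # the one it was created on. creation_device is assumed, not a real field.
    device_index = torch.cuda.current_device()
    creation_device = getattr(pool, "creation_device", device_index)
    if creation_device != device_index:
        raise RuntimeError(
            f"MemPool was created on device {creation_device} but used on "
            f"device {device_index}; cross-device use is not supported"
        )
    with torch.cuda.use_mem_pool(pool):
        yield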

🐛 Describe the bug

code

import torch

torch.cuda.set_device(0)
pool = torch.cuda.MemPool()
torch.cuda.set_device(1)

with torch.cuda.use_mem_pool(pool):
    a = torch.tensor([10, 20], device='cuda')

with torch.cuda.use_mem_pool(pool):
    b = torch.tensor([30, 40], device='cuda')

print(f'{a=} {b=}')

error

RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":2225, please report a bug to PyTorch. 

full error log

[W322 10:20:27.881420799 Module.cpp:182] symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1...

Traceback (most recent call last):
  File "/host_home/primary_synced/tom_sglang_server/misc/adhoc_ac3369_mem_pool.py", line 10, in <module>
    with torch.cuda.use_mem_pool(pool):
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py", line 1086, in use_mem_pool
    _cuda_beginAllocateToPool(device_index, pool.id)
RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":2225, please report a bug to PyTorch. 
Exception raised from ensure_exists_and_incref_pool at /pytorch/c10/cuda/CUDACachingAllocator.cpp:2225 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::string> const> (), c10::SetStackTraceFetcher(std::function<std::string ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::string) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) from ??:0
#7 c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::beginAllocateToPool(signed char, std::pair<unsigned long long, unsigned long long>, std::function<bool (CUstream_st*)>) from :0
#8 pybind11::cpp_function::initialize<registerCudaPluggableAllocator(_object*)::{lambda(signed char, std::pair<unsigned long long, unsigned long long>)#21}, void, signed char, std::pair<unsigned long long, unsigned long long>, pybind11::name, pybind11::scope, pybind11::sibling>(registerCudaPluggableAllocator(_object*)::{lambda(signed char, std::pair<unsigned long long, unsigned long long>)#21}&&, void (*)(signed char, std::pair<unsigned long long, unsigned long long>), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from Module.cpp:0
#9 pybind11::cpp_function::dispatcher(_object*, _object*, _object*) from :0
#10 PyObject_CallFunctionObjArgs from ??:0
#11 _PyObject_MakeTpCall from ??:0
#12 _PyEval_EvalFrameDefault from ??:0
#13 _PyUnicode_ToDecimalDigit from ??:0
#14 PyCell_New from ??:0
#15 _PyEval_EvalFrameDefault from ??:0
#16 PyMethod_New from ??:0
#17 _PyEval_EvalFrameDefault from ??:0
#18 PyEval_EvalCode from ??:0
#19 PyEval_EvalCode from ??:0
#20 PyUnicode_Tailmatch from ??:0
#21 PyInit__collections from ??:0
#22 PyUnicode_Tailmatch from ??:0
#23 _PyRun_SimpleFileObject from ??:0
#24 _PyRun_AnyFileObject from ??:0
#25 Py_RunMain from ??:0
#26 Py_BytesMain from ??:0
#27 __libc_start_call_main from ./csu/../sysdeps/nptl/libc_start_call_main.h:58
#28 __libc_start_main_impl from ./csu/../csu/libc-start.c:392
#29 _start from ??:0
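
For reference, an untested sketch of a variant that, per the analysis above, should avoid the assertion by creating and using the pool on the same device (it does not cover the original cross-device use case):

import torch

torch.cuda.set_device(1)
pool = torch.cuda.MemPool()   # pool is registered on the current device (1)

with torch.cuda.use_mem_pool(pool):
    a = torch.tensor([10, 20], device='cuda')

with torch.cuda.use_mem_pool(pool):
    b = torch.tensor([30, 40], device='cuda')

print(f'{a=} {b=}')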

Versions

torch 2.6.0

cc @ptrblck @msaroufim @eqy

Labels

module: cuda (Related to torch.cuda, and CUDA support in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
