Description
🐛 Describe the bug
This is a very complicated bug. I find that torch.cuda.set_device(0) behaves differently from torch.cuda.set_device(1) in terms of the CUDA context: calling torch.cuda.set_device(0) will not initialize a CUDA context, but torch.cuda.set_device(1) will.
import torch

torch.cuda.set_device(1)  # 1 will initialize a CUDA context, while 0 will not

import ctypes

# Load the CUDA driver library
cuda = ctypes.CDLL("libcuda.so")  # Linux
# cuda = ctypes.CDLL("nvcuda.dll")  # Windows

# Define CUcontext as a pointer type
CUcontext = ctypes.c_void_p

# Define return and argument types for cuCtxGetCurrent
cuda.cuCtxGetCurrent.restype = ctypes.c_int  # CUresult
cuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(CUcontext)]

# Create a CUcontext variable and call cuCtxGetCurrent
ctx = CUcontext()
result = cuda.cuCtxGetCurrent(ctypes.byref(ctx))

# Check the result
if result != 0:
    print(f"cuCtxGetCurrent failed with error code {result}")
elif not ctx:
    print("No active CUDA context.")
else:
    print("Active CUDA context detected.")
After digging into more details, I think this is related to undocumented behavior of the NVIDIA runtime API, and I'm cross-posting here from https://forums.developer.nvidia.com/t/whats-the-expected-behavior-of-calling-cudagetdevice-when-the-process-has-no-cuda-context/335784 .
I'm not sure whether I should call it a PyTorch bug.
This affects some usage of NVSHMEM, because it requires that either all processes have a CUDA context or none do, while this behavior of torch.cuda.set_device() leaves device 0 without a CUDA context while the other devices have one.
My workaround is to call data = torch.zeros(1024, device=f"cuda:{index}") on every device to make sure all of them have a CUDA context initialized.
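Since NVSHMEM needs all ranks to agree on whether a context exists, it can help to check each process programmatically. Here is a minimal, stdlib-only sketch that wraps the driver-API check from above into a reusable helper (the function name has_cuda_context is mine, not a PyTorch or CUDA API); it returns None when the driver library is unavailable or the call fails:

```python
import ctypes

def has_cuda_context():
    """True/False for an active CUDA context; None if the driver is unavailable."""
    try:
        cuda = ctypes.CDLL("libcuda.so")  # Linux driver library
    except OSError:
        return None  # no NVIDIA driver on this machine
    # cuCtxGetCurrent(CUcontext *pctx) -> CUresult
    cuda.cuCtxGetCurrent.restype = ctypes.c_int
    cuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    ctx = ctypes.c_void_p()
    if cuda.cuCtxGetCurrent(ctypes.byref(ctx)) != 0:
        return None  # the driver call itself failed
    return ctx.value is not None

print(has_cuda_context())
```

Each rank could run this before and after the torch.zeros() warm-up to confirm that every device ends up with a context.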
If we want to fix it from the PyTorch side, we need to remove this line:
pytorch/c10/cuda/CUDAFunctions.cpp
Line 239 in a2b0b26
if (device == cur_device) {
and call cudaSetDevice directly, so that torch.cuda.set_device behaves the same as its CUDA runtime API counterpart.
Versions
pytorch 2.7.0