
torch.cuda.set_device(0) behaves differently from torch.cuda.set_device(1) in terms of cuda context #155668

Description

@youkaichao

🐛 Describe the bug

This is a rather subtle bug. I find that torch.cuda.set_device(0) behaves differently from torch.cuda.set_device(1) with respect to the CUDA context: calling torch.cuda.set_device(0) does not initialize a CUDA context, while torch.cuda.set_device(1) does.

import torch
torch.cuda.set_device(1) # 1 will initialize cuda context, while 0 will not
import ctypes

# Load the CUDA driver library
cuda = ctypes.CDLL("libcuda.so")  # Linux
# cuda = ctypes.CDLL("nvcuda.dll")  # Windows

# Define CUcontext as a pointer type
CUcontext = ctypes.c_void_p

# Define return and argument types for cuCtxGetCurrent
cuda.cuCtxGetCurrent.restype = ctypes.c_int  # CUresult
cuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(CUcontext)]

# Create a CUcontext variable
ctx = CUcontext()

# Call cuCtxGetCurrent
result = cuda.cuCtxGetCurrent(ctypes.byref(ctx))

# Check the result
if result != 0:
    print(f"cuCtxGetCurrent failed with error code {result}")
elif not ctx:
    print("No active CUDA context.")
else:
    print("Active CUDA context detected.")

After digging into the details, I think this is related to undocumented behavior of the CUDA runtime API, and I'm cross-posting here from https://forums.developer.nvidia.com/t/whats-the-expected-behavior-of-calling-cudagetdevice-when-the-process-has-no-cuda-context/335784 .
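
For reference, the runtime-API side can be checked without PyTorch. Below is a minimal sketch (my addition, assuming a CUDA 12.x toolkit whose runtime is loadable as libcudart.so.12; adjust the library names for your installation): cudaGetDevice reports device 0 even though no context exists yet, which is presumably why the current-device check inside torch.cuda.set_device(0) ends up doing nothing.

import ctypes

# Library names are assumptions; adjust them for your CUDA installation.
cudart = ctypes.CDLL("libcudart.so.12")
cuda = ctypes.CDLL("libcuda.so")

# cudaError_t cudaGetDevice(int* device)
cudart.cudaGetDevice.restype = ctypes.c_int
cudart.cudaGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)]

# CUresult cuCtxGetCurrent(CUcontext* pctx)
cuda.cuCtxGetCurrent.restype = ctypes.c_int
cuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(ctypes.c_void_p)]

dev = ctypes.c_int(-1)
err = cudart.cudaGetDevice(ctypes.byref(dev))
# Typically returns cudaSuccess (0) and reports device 0, even with no context.
print(f"cudaGetDevice -> error {err}, device {dev.value}")

ctx = ctypes.c_void_p()
err = cuda.cuCtxGetCurrent(ctypes.byref(ctx))
# Either the driver is not initialized yet (non-zero CUresult) or the current
# context is NULL; in both cases no CUDA context has been created so far.
if err != 0 or not ctx:
    print("No active CUDA context after cudaGetDevice.")
else:
    print("Active CUDA context detected.")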

I'm not sure whether I should call this a PyTorch bug.

This affects some uses of nvshmem, which requires that either all processes have a CUDA context or none of them do; the torch.cuda.set_device() behavior above leaves the process on device 0 without a CUDA context while the processes on other devices do have one.

My workaround is to call data = torch.zeros(1024, device=f"cuda:{index}") on every rank to make sure a CUDA context is initialized on every device; see the sketch below.
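
A slightly fuller sketch of this workaround in a multi-process setup (my addition; the LOCAL_RANK environment variable is an assumption, adjust to whatever your launcher provides):

import os

import torch

# Each rank binds to its local GPU, then touches it so that a CUDA context is
# created even on device 0, where set_device alone is effectively a no-op.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
_ = torch.zeros(1024, device=f"cuda:{local_rank}")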

If we want to fix this on the PyTorch side, we need to remove this early-return check:

if (device == cur_device) {

and call cudaSetDevice unconditionally, so that torch.cuda.set_device behaves the same as its CUDA runtime API counterpart.
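
As a sanity check for that proposal (my sketch, not the actual patch): calling the runtime API cudaSetDevice(0) directly does create a context on recent toolkits (CUDA 12.x initializes the device's primary context inside cudaSetDevice), so an unconditional call inside torch.cuda.set_device should make device 0 behave like every other device. The libcudart.so.12 name is again an assumption.

import ctypes

cudart = ctypes.CDLL("libcudart.so.12")  # adjust for your CUDA installation
cuda = ctypes.CDLL("libcuda.so")

# cudaError_t cudaSetDevice(int device)
cudart.cudaSetDevice.restype = ctypes.c_int
cudart.cudaSetDevice.argtypes = [ctypes.c_int]

# CUresult cuCtxGetCurrent(CUcontext* pctx)
cuda.cuCtxGetCurrent.restype = ctypes.c_int
cuda.cuCtxGetCurrent.argtypes = [ctypes.POINTER(ctypes.c_void_p)]

assert cudart.cudaSetDevice(0) == 0  # 0 == cudaSuccess

ctx = ctypes.c_void_p()
assert cuda.cuCtxGetCurrent(ctypes.byref(ctx)) == 0  # 0 == CUDA_SUCCESS
# Expected to be non-NULL on CUDA 12.x, unlike after torch.cuda.set_device(0).
print("context after cudaSetDevice(0):", ctx.value)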

Versions

PyTorch 2.7.0

cc @ptrblck @msaroufim @eqy @jerryzh168
