Skip to content
21 changes: 12 additions & 9 deletions python/ray/_private/accelerators/amd_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,24 @@ def get_resource_name() -> str:

@staticmethod
def get_visible_accelerator_ids_env_var() -> str:
    """Return the name of the env var that selects visible AMD GPUs.

    ``HIP_VISIBLE_DEVICES_ENV_VAR`` is returned unless only
    ``CUDA_VISIBLE_DEVICES_ENV_VAR`` is set, in which case the CUDA variable
    is used instead. When both are set they must hold the same value.

    Returns:
        Either ``HIP_VISIBLE_DEVICES_ENV_VAR`` or
        ``CUDA_VISIBLE_DEVICES_ENV_VAR``.

    Raises:
        RuntimeError: If ``ROCR_VISIBLE_DEVICES`` is set.
        ValueError: If both the HIP and CUDA variables are set but differ.
    """
    if "ROCR_VISIBLE_DEVICES" in os.environ:
        raise RuntimeError(
            f"Please use {HIP_VISIBLE_DEVICES_ENV_VAR} instead of ROCR_VISIBLE_DEVICES"
        )

    # Plain assignments instead of `x := os.environ.get(...) is None`:
    # `:=` binds more loosely than `is`, so that form stores the boolean
    # comparison result and makes two equal values look inconsistent.
    cuda_val = os.environ.get(CUDA_VISIBLE_DEVICES_ENV_VAR, None)
    if cuda_val is None:
        return HIP_VISIBLE_DEVICES_ENV_VAR

    hip_val = os.environ.get(HIP_VISIBLE_DEVICES_ENV_VAR, None)
    if hip_val is None:
        # Only the CUDA variable is set; honor it.
        return CUDA_VISIBLE_DEVICES_ENV_VAR
    if hip_val != cuda_val:
        raise ValueError(
            f"Inconsistant values found. Please use either {HIP_VISIBLE_DEVICES_ENV_VAR} or {CUDA_VISIBLE_DEVICES_ENV_VAR}."
        )
    return HIP_VISIBLE_DEVICES_ENV_VAR

@staticmethod
def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
amd_visible_devices = os.environ.get(
AMDGPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
)
Expand Down
33 changes: 22 additions & 11 deletions python/ray/tests/accelerators/test_amd_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
from ray._private.accelerators import get_accelerator_manager_for_resource


@pytest.mark.parametrize(
"visible_devices_env_var", ("HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES")
)
@patch(
"ray._private.accelerators.AMDGPUAcceleratorManager.get_current_node_num_accelerators", # noqa: E501
return_value=4,
)
def test_visible_amd_gpu_ids(mock_get_num_accelerators, monkeypatch, shutdown_only):
monkeypatch.setenv("HIP_VISIBLE_DEVICES", "0,1,2")
def test_visible_amd_gpu_ids(
mock_get_num_accelerators, visible_devices_env_var, monkeypatch, shutdown_only
):
monkeypatch.setenv(visible_devices_env_var, "0,1,2")
# Delete the cache so it can be re-populated the next time
# we call get_accelerator_manager_for_resource
del get_accelerator_manager_for_resource._resource_name_to_accelerator_manager
Expand Down Expand Up @@ -45,43 +50,49 @@ def test_visible_amd_gpu_type_bad_device_id(mock_get_num_accelerators, shutdown_
assert AMDGPUAcceleratorManager.get_current_node_accelerator_type() is None


@pytest.mark.parametrize(
    "visible_devices_env_var", ("HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES")
)
def test_get_current_process_visible_accelerator_ids(
    visible_devices_env_var, monkeypatch
):
    """Visible AMD GPU ids are read from either HIP_ or CUDA_VISIBLE_DEVICES."""
    # Clear any ambient device variables first. A pre-set value in the other
    # variable would change which one the manager reads (or trip its
    # HIP/CUDA consistency check), and ROCR_VISIBLE_DEVICES makes it raise.
    for env_var in (
        "HIP_VISIBLE_DEVICES",
        "CUDA_VISIBLE_DEVICES",
        "ROCR_VISIBLE_DEVICES",
    ):
        monkeypatch.delenv(env_var, raising=False)

    for value, expected_ids in (
        ("0,1,2", ["0", "1", "2"]),
        ("0,2,7", ["0", "2", "7"]),
        ("", []),
    ):
        monkeypatch.setenv(visible_devices_env_var, value)
        assert (
            AMDGPUAcceleratorManager.get_current_process_visible_accelerator_ids()
            == expected_ids
        )

    # Once the variable is unset, the manager reports None (not an empty list).
    monkeypatch.delenv(visible_devices_env_var)
    assert (
        AMDGPUAcceleratorManager.get_current_process_visible_accelerator_ids() is None
    )


def test_set_current_process_visible_accelerator_ids():
    """Setting visible ids writes them comma-joined to the active env var."""
    # patch.dict snapshots os.environ and restores it on exit, so nothing
    # written here leaks into later tests, even when an assertion fails.
    with patch.dict(os.environ):
        # Start from a known state so the selected env var is deterministic
        # and ROCR_VISIBLE_DEVICES cannot make the manager raise.
        for var in (
            "HIP_VISIBLE_DEVICES",
            "CUDA_VISIBLE_DEVICES",
            "ROCR_VISIBLE_DEVICES",
        ):
            os.environ.pop(var, None)

        for ids, expected in (
            (["0"], "0"),
            (["0", "1"], "0,1"),
            (["0", "1", "7"], "0,1,7"),
        ):
            AMDGPUAcceleratorManager.set_current_process_visible_accelerator_ids(ids)
            env_var = AMDGPUAcceleratorManager.get_visible_accelerator_ids_env_var()
            assert os.environ[env_var] == expected


if __name__ == "__main__":
Expand Down