Description
I am having trouble initializing a flax.linen neural network when running with GPU support. I have narrowed it down to the flax.linen.initializers.orthogonal kernel initializer. Running the code below results in:
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
However, the same code runs fine in a separate venv with CPU-only JAX, and it also runs fine on the GPU when the orthogonal kernel initializer is removed.
JAX was installed with pip install -U "jax[cuda12]".
Below is a minimal example that raises the error:
import os
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal


class SingleLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        layer = nn.Dense(64, kernel_init=orthogonal())(x)
        return layer


network = SingleLayer()
init_x = jnp.zeros(128)
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
print(network_params)
/home/brain/Tensor/JaxRL/.venv/bin/python /home/brain/Tensor/JaxRL/flax_lax.py
Traceback (most recent call last):
File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
layer = nn.Dense(64, kernel_init=orthogonal())(x)
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
kernel = self.param(
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
Q, R = jnp.linalg.qr(A)
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 1300, in qr
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2442, in init
_, v_out = self.init_with_output(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2294, in init_with_output
return init_with_output(
^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1144, in wrapper
return apply(fn, mutable=mutable, flags=init_flags)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1108, in wrapper
y = fn(root, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 3081, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
layer = nn.Dense(64, kernel_init=orthogonal())(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
kernel = self.param(
^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1867, in param
v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 997, in param
value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
Q, R = jnp.linalg.qr(A)
^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
return xc._xla.pjit(
^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1643, in _pjit_call_impl_python
compiled = _resolve_and_lower(
^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1610, in _resolve_and_lower
lowered = _pjit_lower(
^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1748, in _pjit_lower
return _pjit_lower_cached(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1769, in _pjit_lower_cached
return pxla.lower_sharding_computation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2230, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1950, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1132, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1590, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2036, in f_lowered
out, tokens = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/lax/linalg.py", line 1757, in _geqrf_cpu_gpu_lowering
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brain/Tensor/JaxRL/.venv/lib64/python3.12/site-packages/jaxlib/gpu_solver.py", line 164, in _geqrf_hlo
lwork, opaque = gpu_solver.build_geqrf_descriptor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
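The traceback ends in jnp.linalg.qr. The orthogonal initializer calls it at jax/_src/nn/initializers.py line 611; as the traceback shows, flax.linen.initializers.orthogonal resolves to that JAX function.

The sketch below is one way to narrow this down further. It drops Flax entirely and runs a QR of the same 128x64 shape on each visible device, plus the CPU backend. If only the CUDA device fails, that would point at the cuSolver setup rather than at Flax.

Assumptions: it is run inside the same GPU venv, and qr_check is a hypothetical helper name, not part of JAX or Flax.

```python
import jax
import jax.numpy as jnp


def qr_check(device):
    """Run a 128x64 QR decomposition on `device`.

    Hypothetical diagnostic helper (not part of JAX or Flax). Returns
    (True, max deviation of q.T @ q from the identity) on success, or
    (False, error message) if the QR call raises.
    """
    # Same shape as the Dense kernel in the repro: 128 inputs -> 64 outputs.
    a = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
    a = jax.device_put(a, device)  # commit the input so the QR runs on `device`
    try:
        q, _ = jnp.linalg.qr(a)
        q.block_until_ready()
    except RuntimeError as exc:  # e.g. "cuSolver internal error"
        return False, str(exc)
    # Columns of q should be orthonormal, so q.T @ q should be close to I.
    max_dev = jnp.max(jnp.abs(q.T @ q - jnp.eye(q.shape[1])))
    return True, float(max_dev)


# Every default device, plus the CPU backend if the default is an accelerator.
devices = list(jax.devices())
if jax.default_backend() != "cpu":
    devices.append(jax.devices("cpu")[0])

results = {str(d): qr_check(d) for d in devices}
for name, (ok, detail) in results.items():
    print(f"{name}: {'OK' if ok else 'FAILED'} ({detail})")
```

On a working install every line should print OK with a small deviation from the identity. A FAILED line for the CUDA device carrying the same gpusolverDnCreate message would reproduce the bug without Flax.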
Thanks,
Brian
System info (python version, jaxlib version, accelerator, etc.)
jax.print_environment_info()
jax: 0.4.32
jaxlib: 0.4.32
numpy: 2.1.1
python: 3.12.5 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='programming-desktop', release='6.10.8-200.fc40.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Sep 4 21:41:11 UTC 2024', machine='x86_64')
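jax.devices lists only CudaDevice(id=0) by default, but jaxlib still ships a CPU backend that can be selected explicitly. Since the CPU-only venv works, one possible stopgap is to run only the initializer on the CPU and then move the result to the GPU.

This is an untested assumption for this machine, not a fix for the underlying cuSolver failure. The sketch calls jax.nn.initializers.orthogonal directly, which is what flax.linen.initializers.orthogonal resolves to according to the traceback.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
# Default device: CudaDevice(id=0) in the reported environment, CPU elsewhere.
accel = jax.devices()[0]

init_fn = jax.nn.initializers.orthogonal()

# Run the QR-based initializer on the CPU backend only, sidestepping cuSolver.
with jax.default_device(cpu):
    kernel = init_fn(jax.random.PRNGKey(0), (128, 64), jnp.float32)

# Move the initialized weights to the accelerator for the rest of the program.
kernel = jax.device_put(kernel, accel)
print(kernel.shape, kernel.dtype)
```

With Flax, the same idea would put the existing network.init(rngs=jax.random.PRNGKey(0), x=init_x) call inside the with block. The returned params could then be moved with jax.device_put(network_params, accel), since device_put accepts pytrees.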
$ nvidia-smi
Thu Sep 12 22:41:39 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 59%   57C    P0            180W /  390W |    1759MiB /  24576MiB |     34%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2921      G   /usr/libexec/Xorg                             456MiB |
|    0   N/A  N/A      3172    C+G   ...libexec/gnome-remote-desktop-daemon        258MiB |
|    0   N/A  N/A      3255      G   /usr/bin/gnome-shell                           76MiB |
|    0   N/A  N/A      4696      G   /usr/bin/nautilus                              24MiB |
|    0   N/A  N/A      4978      G   /usr/lib64/firefox/firefox                    190MiB |
|    0   N/A  N/A     64410      C   ...ensor/Sin_PPO_test/.venv/bin/python        366MiB |
|    0   N/A  N/A     98376      G   ...erProcess --variations-seed-version         16MiB |
|    0   N/A  N/A    111528      C   ...brain/Tensor/JaxRL/.venv/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+
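nvidia-smi shows several other processes on the same GPU, including a second Python process from another .venv (Sin_PPO_test). By default JAX's GPU client preallocates 75% of device memory when the first operation runs.

One unconfirmed hypothesis worth ruling out is that cuSolver cannot create its handle when memory is already reserved. JAX documents two environment variables that control preallocation, XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION, and both must be set before JAX is imported. The sketch below sets the first and re-runs the initializer that failed.

```python
import os

# These must be set before `import jax`; they only affect the GPU backend.
# Allocate GPU memory on demand instead of preallocating a large fraction up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but reserve less, e.g. half of device memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
import jax.numpy as jnp

# Re-run the initializer that failed in the report (Flax re-exports this function).
init_fn = jax.nn.initializers.orthogonal()
kernel = init_fn(jax.random.PRNGKey(0), (128, 64), jnp.float32)
print(jax.default_backend(), kernel.shape)
```

If this runs cleanly on the GPU while the original script still fails, memory contention becomes the more likely explanation. If it fails the same way, the hypothesis can be discarded.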
Installed NVIDIA CUDA packages:
nvidia-cublas-cu12==12.6.1.4
nvidia-cuda-cupti-cu12==12.6.68
nvidia-cuda-nvcc-cu12==12.6.68
nvidia-cuda-runtime-cu12==12.6.68
nvidia-cudnn-cu12==9.4.0.58
nvidia-cufft-cu12==11.2.6.59
nvidia-cusolver-cu12==11.6.4.69
nvidia-cusparse-cu12==12.5.3.3
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.68
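The jax[cuda12] extra installs the CUDA libraries as separate nvidia-*-cu12 wheels, such as the nvidia-cusolver-cu12==11.6.4.69 listed above. Comparing these versions between the working CPU venv, the failing GPU venv, and jaxlib's stated requirements is one way to rule out a library mismatch.

A small sketch for collecting them, using only the standard library's importlib.metadata. The function name nvidia_wheel_versions is hypothetical.

```python
from importlib.metadata import distributions


def nvidia_wheel_versions(prefix="nvidia-"):
    """Return {distribution name: version} for installed wheels matching `prefix`.

    Hypothetical helper built only on the standard library.
    """
    versions = {}
    for dist in distributions():
        name = dist.metadata["Name"]
        if not name:
            continue
        # Normalize so "nvidia_cusolver_cu12" and "nvidia-cusolver-cu12" compare alike.
        name = name.lower().replace("_", "-")
        if name.startswith(prefix):
            versions[name] = dist.version
    return dict(sorted(versions.items()))


for name, version in nvidia_wheel_versions().items():
    print(f"{name}=={version}")
```

The output uses the same name==version format as the list above, so two venvs can be diffed directly.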