
Orthogonal Initializer raises gpusolverDnCreate(&handle) failed: cuSolver internal error #23616

@brianorbrain

Description


I am having issues initializing a flax.linen neural network when running with GPU support. I have narrowed it down to flax.linen.initializers.orthogonal. Running the code below results in:
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

However, running the same code in another venv with CPU-only support works fine, and running it on GPU without the orthogonal kernel initializer also works fine.
JAX was installed with pip install -U "jax[cuda12]"

I have included a minimal example below that reproduces the issue.


import os

os.environ['JAX_TRACEBACK_FILTERING'] = 'off'

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal




class SingleLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        layer = nn.Dense(64, kernel_init=orthogonal())(x)
        return layer


network = SingleLayer()
init_x = jnp.zeros(128)
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
print(network_params)
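
For comparison, the same module initializes fine on the GPU if I drop the orthogonal initializer and rely on Dense's default kernel init (lecun_normal), which does not go through a QR decomposition. A minimal sketch of that variant (SingleLayerDefault is just an illustrative name):

import jax
import jax.numpy as jnp
import flax.linen as nn


class SingleLayerDefault(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Default kernel init; no orthogonal() and therefore no QR on the GPU
        return nn.Dense(64)(x)


network = SingleLayerDefault()
init_x = jnp.zeros(128)
network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
print(network_params)

The full traceback from the failing orthogonal version: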


/home/brain/Tensor/JaxRL/.venv/bin/python /home/brain/Tensor/JaxRL/flax_lax.py 
Traceback (most recent call last):
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
    network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
    layer = nn.Dense(64, kernel_init=orthogonal())(x)
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
    kernel = self.param(
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 1300, in qr
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 22, in <module>
    network_params = network.init(rngs=jax.random.PRNGKey(0), x=init_x)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2442, in init
    _, v_out = self.init_with_output(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 2294, in init_with_output
    return init_with_output(
           ^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1144, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 1108, in wrapper
    y = fn(root, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 3081, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/flax_lax.py", line 16, in __call__
    layer = nn.Dense(64, kernel_init=orthogonal())(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1211, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/linear.py", line 251, in __call__
    kernel = self.param(
             ^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/linen/module.py", line 1867, in param
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/flax/core/scope.py", line 997, in param
    value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/nn/initializers.py", line 611, in init
    Q, R = jnp.linalg.qr(A)
           ^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 2782, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1643, in _pjit_call_impl_python
    compiled = _resolve_and_lower(
               ^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1610, in _resolve_and_lower
    lowered = _pjit_lower(
              ^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1748, in _pjit_lower
    return _pjit_lower_cached(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1769, in _pjit_lower_cached
    return pxla.lower_sharding_computation(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2230, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1950, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1132, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1590, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2036, in f_lowered
    out, tokens = jaxpr_subcomp(
                  ^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1805, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1921, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib/python3.12/site-packages/jax/_src/lax/linalg.py", line 1757, in _geqrf_cpu_gpu_lowering
    a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brain/Tensor/JaxRL/.venv/lib64/python3.12/site-packages/jaxlib/gpu_solver.py", line 164, in _geqrf_hlo
    lwork, opaque = gpu_solver.build_geqrf_descriptor(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
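
Based on the traceback, the error is raised from gpusolverDnCreate during the GPU lowering of the QR decomposition (geqrf), not from anything Flax-specific, so I would expect a bare jnp.linalg.qr call on a GPU-backed array to hit the same failure. A minimal sketch of that check (I have not verified this exact snippet; it is just the call path the orthogonal initializer takes per the traceback above):

import jax
import jax.numpy as jnp

# jnp.linalg.qr on GPU lowers to geqrf via cuSolver, the same path
# the orthogonal initializer hits in the traceback above.
a = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
q, r = jnp.linalg.qr(a)
print(q.shape, r.shape)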

Thanks,
Brian

System info (python version, jaxlib version, accelerator, etc.)

jax.print_environment_info()
jax:    0.4.32
jaxlib: 0.4.32
numpy:  2.1.1
python: 3.12.5 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='programming-desktop', release='6.10.8-200.fc40.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Sep  4 21:41:11 UTC 2024', machine='x86_64')
$ nvidia-smi
Thu Sep 12 22:41:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 59%   57C    P0            180W /  390W |    1759MiB /  24576MiB |     34%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2921      G   /usr/libexec/Xorg                             456MiB |
|    0   N/A  N/A      3172    C+G   ...libexec/gnome-remote-desktop-daemon        258MiB |
|    0   N/A  N/A      3255      G   /usr/bin/gnome-shell                           76MiB |
|    0   N/A  N/A      4696      G   /usr/bin/nautilus                              24MiB |
|    0   N/A  N/A      4978      G   /usr/lib64/firefox/firefox                    190MiB |
|    0   N/A  N/A     64410      C   ...ensor/Sin_PPO_test/.venv/bin/python        366MiB |
|    0   N/A  N/A     98376      G   ...erProcess --variations-seed-version         16MiB |
|    0   N/A  N/A    111528      C   ...brain/Tensor/JaxRL/.venv/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+
nvidia-cublas-cu12==12.6.1.4
nvidia-cuda-cupti-cu12==12.6.68
nvidia-cuda-nvcc-cu12==12.6.68
nvidia-cuda-runtime-cu12==12.6.68
nvidia-cudnn-cu12==9.4.0.58
nvidia-cufft-cu12==11.2.6.59
nvidia-cusolver-cu12==11.6.4.69
nvidia-cusparse-cu12==12.5.3.3
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.68
