Skip to content

tests/test_least_squares.py::test_least_squares[rosenbrock-minimum2-init2-args2-solver15] fails with jax==0.7.1 #160

@GaetanLepage

Description

@GaetanLepage

Context: bumping jax from 0.6.2 to 0.7.1 in nixpkgs.

Since updating jax to 0.7.1, the following test fails:

_________ test_least_squares[rosenbrock-minimum2-init2-args2-solver15] _________
[gw15] linux -- Python 3.12.11 /nix/store/18l19gnl8dq75v6knghbcgkrikmpc0yv-python3-3.12.11/bin/python3.12

solver = OptaxMinimiser(
  optim=GradientTransformationExtraArgs(
    init=_Closure(
      fn=<equinox.internal._closure_to_pyt...        )
        ),
      )
    )
  ),
  rtol=1e-08,
  atol=1e-08,
  norm=<function max_norm>,
  verbose=frozenset()
)
_fn = <function rosenbrock at 0x7fff5c3113a0>
minimum = Array(0., dtype=float64, weak_type=True)
init = (Array([[1.5, 1.5, 1.5, 1.5],
       [1.5, 1.5, 1.5, 1.5]], dtype=float64), {'a': Array([[[1.5, 1.5],
        [1.5, 1.5],
        [1.5, 1.5]],

       [[1.5, 1.5],
        [1.5, 1.5],
        [1.5, 1.5]]], dtype=float64)}, ())
args = Array(1., dtype=float64, weak_type=True)

    @pytest.mark.parametrize("solver", least_squares_optimisers)
    @pytest.mark.parametrize("_fn, minimum, init, args", least_squares_fn_minima_init_args)
    def test_least_squares(solver, _fn, minimum, init, args):
        atol = rtol = 1e-4
        has_aux = random.choice([True, False])
        if has_aux:
            fn = lambda x, args: (_fn(x, args), smoke_aux)
        else:
            fn = _fn

        if isinstance(solver, optx.OptaxMinimiser):
            context = jax.numpy_dtype_promotion("standard")
        else:
            context = contextlib.nullcontext()
        with context:
            optx_argmin = optx.least_squares(
                fn, solver, init, has_aux=has_aux, args=args, max_steps=10_000, throw=False
            ).value
        out = fn(optx_argmin, args)
        if has_aux:
            residual, _ = out
        else:
            residual = out
        optx_min = jtu.tree_reduce(
            lambda x, y: x + y, jtu.tree_map(lambda x: jnp.sum(x**2), residual)
        )
>       assert tree_allclose(optx_min, minimum, atol=atol, rtol=rtol)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = tree_allclose(Array(0.12993518, dtype=float64), Array(0., dtype=float64, weak_type=True), atol=0.0001, rtol=0.0001)

tests/test_least_squares.py:53: AssertionError

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions