
Save the kernel parameters after warmup #98

@rlouf


mcx/mcx/inference/hmc.py

Lines 198 to 205 in 2a2b948

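```python
# Build the HMC parameters found during warmup; the number of integration
# steps is broadcast to an int32 array with one entry per row of
# `initial_state.position`.
parameters = HMCParameters(
    jnp.ones(initial_state.position.shape[0], dtype=jnp.int32)
    * num_integration_steps,
    step_size,
    inverse_mass_matrix,
)
return last_chain_state, parameters, warmup_chain
```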

We currently pass the parameters directly to the runtime. Although the parameter values are passed in the Trace object, it would be convenient to also update the Kernel's parameter values after warmup, so that later sampling can reuse them.
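A minimal sketch of the proposal, assuming a kernel object that can hold its tuned parameters. `HMCKernel`, `with_warmup_results`, `warmup` and the `HMCParameters` field names below are hypothetical stand-ins, not mcx's actual API, and the warmup body performs no real adaptation; it only mirrors the return value of the excerpt above, using plain Python lists instead of `jnp` arrays.

```python
from dataclasses import dataclass, replace
from typing import NamedTuple, Optional, Sequence


class HMCParameters(NamedTuple):
    """Hypothetical stand-in for mcx's HMCParameters (field names assumed)."""
    num_integration_steps: Sequence[int]
    step_size: float
    inverse_mass_matrix: Sequence[float]


@dataclass(frozen=True)
class HMCKernel:
    """Hypothetical kernel that remembers the parameters tuned during warmup."""
    parameters: Optional[HMCParameters] = None

    def with_warmup_results(self, parameters: HMCParameters) -> "HMCKernel":
        # Return a copy of the kernel carrying the warmup parameters, so later
        # sampling calls can reuse them without reading them from the trace.
        return replace(self, parameters=parameters)


def warmup(initial_positions, num_integration_steps, step_size, inverse_mass_matrix):
    """Hypothetical warmup ending like the excerpt (adaptation is elided)."""
    num_chains = len(initial_positions)
    parameters = HMCParameters(
        [num_integration_steps] * num_chains,  # one entry per chain
        step_size,
        inverse_mass_matrix,
    )
    last_chain_state = initial_positions  # placeholder: no actual sampling
    warmup_chain = []  # placeholder for the warmup trace
    return last_chain_state, parameters, warmup_chain


# Usage: run warmup, then store its parameters on the kernel.
kernel = HMCKernel()
state, params, _ = warmup([[0.0, 0.0], [1.0, 1.0]], 10, 0.1, [1.0, 1.0])
kernel = kernel.with_warmup_results(params)
print(kernel.parameters.num_integration_steps)  # [10, 10]
```

Returning a new kernel instead of mutating the existing one keeps with the functional style of JAX code; assigning to a mutable attribute would work just as well if the kernel is a regular object.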
