
Inference and MCMC

Ariel Shurygin edited this page Mar 12, 2024 · 2 revisions

Parameter inference is performed by the MechanisticInferer class. For a brief description of the three class types, including MechanisticInferer, see the Software Design page.

Currently, MechanisticInferer supports only Markov chain Monte Carlo (MCMC) with NUTS (the No-U-Turn Sampler), though this may be expanded in the future.

NumPyro sampling and inferer.likelihood()

The MechanisticInferer class is built on NumPyro's MCMC and NUTS implementations for Bayesian inference. Priors for parameters are defined in configuration files in the same way as static values; during inference, however, each distribution is sampled by calling numpyro.sample(), and the sampled values are passed to the ODEs. It can be easier to think of sampled values as ordinary Python types such as floats, but in reality they are JAX arrays, and while the model is being traced for compilation they are JAX tracer objects rather than concrete numbers. JAX provides a NumPy-like interface for computation, but code is traced and compiled just-in-time (JIT) by the XLA compiler, which can run it on CPU, GPU, or TPU. For an introduction, see the JAX 101 tutorials in the JAX documentation.

To infer parameters you must have observed data to fit against. MechanisticInferer's likelihood() method, which is plugged into the sampler, compares the model output produced by the sampled parameters with the observed data. It first calls MechanisticRunner to solve the ordinary differential equations (ODEs) with the sampled parameters, then transforms the returned Solution object into the same format as the observed data, and finally compares the two timelines.

Modifying inference

If you want to modify the inference procedure, the likelihood function is the best place to start. A simple child class can inherit all necessary methods from MechanisticInferer and override only likelihood() to compare model output against the observed data in a different way; there is usually no need to re-implement the initialization and parameter-sampling methods.

In many scenarios, the observed data have a different structure from the compartments returned by MechanisticRunner. To handle this common situation, users can override likelihood() in a child class of MechanisticInferer, reshaping or transforming the model output into the same format as the observed data before the two are compared.
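As an example of such a transformation, the sketch below collapses a hypothetical `(days, age_groups)` incidence array into weekly totals across all age groups. The input shape and the weekly, all-ages format of the observed data are assumptions chosen for illustration; the compartments MechanisticRunner actually returns may be laid out differently.

```python
import jax.numpy as jnp


def compartments_to_observed(daily_incidence, days_per_bin=7):
    """Collapse a hypothetical (days, age_groups) array into per-bin totals.

    Sums across age groups, then across each block of `days_per_bin` days,
    returning an array of shape (days // days_per_bin,).
    """
    total_per_day = daily_incidence.sum(axis=1)
    n_bins = total_per_day.shape[0] // days_per_bin
    # Drop a trailing partial bin so the reshape is exact
    trimmed = total_per_day[: n_bins * days_per_bin]
    return trimmed.reshape(n_bins, days_per_bin).sum(axis=1)


print(compartments_to_observed(jnp.ones((15, 4))))
```

Array shapes are static during JAX tracing, so slicing and reshaping based on `.shape` like this works inside a jitted likelihood.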

It is crucial to note that, at some point, the likelihood must call numpyro.sample(..., obs=obs_metrics). Passing obs= marks that sample site as observed, which tells the sampler which model output is compared against the observed metrics: instead of drawing a new value, NumPyro adds the log-probability of the observed data to the joint density that NUTS explores during parameter inference.

By customizing the likelihood function in this way, users can handle differences in data structure and ensure that model outputs and observed data are compared in matching formats.