
Tutorial for neural adaptive smc #11

Open
fehiepsi opened this issue Aug 1, 2023 · 3 comments

Comments

@fehiepsi
Collaborator

fehiepsi commented Aug 1, 2023

Neural Adaptive SMC (Gu et al.) is a nice framework that lets us train proposals for non-linear state space models. We can use the forward KL in a nested variational inference scheme because both derivations give similar gradient estimates.
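To make the gradient claim concrete, here is a minimal per-step sketch in plain JAX (not coix). The NASMC proposal update is a weighted sum of the gradients of log q over the particles, with the normalized importance weights held constant, so a surrogate loss with `stop_gradient` on the weights yields that gradient under autodiff. The function name `nasmc_surrogate` and its argument layout are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def nasmc_surrogate(log_q, log_w):
    """Surrogate loss whose gradient is the NASMC proposal update at one step.

    log_q: (K,) log q_phi(z^k | ...) of the K particles, differentiable in phi.
    log_w: (K,) unnormalized log importance weights of the same particles.
    """
    # Normalized weights are treated as constants: no gradient flows through them.
    w = jax.lax.stop_gradient(jax.nn.softmax(log_w))
    # Gradient of this loss is -sum_k w_k * grad log q(z^k), the forward-KL estimate.
    return -jnp.sum(w * log_q)
```

For example, with particles `z = [1, 2, 3]`, uniform weights, and a unit-variance Gaussian `log_q = -0.5 * (z - loc) ** 2`, the gradient with respect to `loc` at `loc = 0` is `-(mean(z) - loc) = -2.0`, i.e. the update pulls the proposal towards the weighted particle mean.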

For state space models, we typically don't have a reverse kernel because the state dimension grows over time. This example will be a good illustration of how to handle growing-dimensional variables in JAX. The trick is to preallocate a full-dimensional variable and perform an index update at each SMC step.
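A minimal sketch of the preallocation trick in plain JAX, assuming a toy random-walk proposal and a made-up log-weight (both hypothetical, for illustration only). The `(particles, T)` buffer has a fixed shape, so `jax.lax.scan` and `jax.jit` accept it; each step writes its column with `.at[...].set(...)`, and resampling reindexes whole trajectories at once.

```python
import jax
import jax.numpy as jnp

NUM_PARTICLES = 3  # hypothetical particle count
T = 5              # hypothetical number of time steps


def smc_step(carry, t):
    zs, key = carry
    key, k_prop, k_res = jax.random.split(key, 3)
    # Toy proposal (illustrative only): previous state plus Gaussian noise.
    z_t = zs[:, t - 1] + jax.random.normal(k_prop, (NUM_PARTICLES,))
    # Write step t into the preallocated buffer; the array shape never changes.
    zs = zs.at[:, t].set(z_t)
    # Toy log-weights (illustrative only), then multinomial resampling of full rows.
    log_w = -0.5 * z_t ** 2
    ancestors = jax.random.categorical(k_res, log_w, shape=(NUM_PARTICLES,))
    return (zs[ancestors], key), None


@jax.jit
def run(key):
    key, k0 = jax.random.split(key)
    # Full-dimensional buffer: column t holds z_t for every particle.
    zs = jnp.zeros((NUM_PARTICLES, T)).at[:, 0].set(jax.random.normal(k0, (NUM_PARTICLES,)))
    (zs, _), _ = jax.lax.scan(smc_step, (zs, key), jnp.arange(1, T))
    return zs
```

Columns beyond the current step keep their placeholder value until they are written, so anything that reads the whole buffer at step t (for example a network over the history) would need the step index or a mask.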

@deoxyribose

Hi @fehiepsi,

I'd like to take a stab at this, but could use a little help getting started. Given the model from Section 5.1 of the paper:

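```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def ssm(xs=None, T_max=1000):
    # Initial state z_0 with standard deviation 5.
    z_0 = numpyro.sample("z_0", dist.Normal(0, 5))
    z_t_m1 = z_0
    for t in range(1, T_max):
        # Non-linear transition mean; process noise has variance 10.
        z_t_loc = z_t_m1 / 2 + 25 * z_t_m1 / (1 + z_t_m1 ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = numpyro.sample(f"z_{t}", dist.Normal(z_t_loc, jnp.sqrt(10)))
        # Observation x_t ~ N(z_t^2 / 20, 1); xs[t - 1] holds x_t when conditioning.
        x_t = numpyro.sample(f"x_{t}", dist.Normal(z_t ** 2 / 20, 1), obs=xs[t - 1] if xs is not None else None)
        z_t_m1 = z_t
    return x_t
```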
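For experimenting, synthetic observations can be drawn from the same model with `jax.lax.scan` instead of an unrolled Python loop, which keeps compile time independent of `T_max`. This is a plain-JAX sketch, and the function name `simulate` is hypothetical; the returned `xs` has length `T_max - 1`, matching the `xs[t - 1]` indexing in `ssm`.

```python
import jax
import jax.numpy as jnp


def simulate(key, T_max=1000):
    """Draw z_0 and (z_t, x_t) for t = 1..T_max-1 from the model above."""
    key, k0 = jax.random.split(key)
    z_0 = 5.0 * jax.random.normal(k0)  # z_0 ~ N(0, 5), 5 being the scale

    def body(z_prev, inputs):
        t, k = inputs
        k_z, k_x = jax.random.split(k)
        loc = z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)
        z_t = loc + jnp.sqrt(10.0) * jax.random.normal(k_z)  # process noise, variance 10
        x_t = z_t ** 2 / 20 + jax.random.normal(k_x)         # observation noise, variance 1
        return z_t, (z_t, x_t)

    keys = jax.random.split(key, T_max - 1)
    _, (zs, xs) = jax.lax.scan(body, z_0, (jnp.arange(1, T_max), keys))
    return z_0, zs, xs
```

`T_max` determines array shapes, so it must be a static Python int if `simulate` is wrapped in `jax.jit`.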

I figure what needs to be implemented is an LSTM-based mixture density network that parametrizes q(z_t | z_{1:t-1}, x_{1:t}) (or q(v_t | z_{1:t-1}, x_{1:t}, f(z_{t-1}, t)), since according to the paper that works better). Should I then make a list of proposals, one for each z_t, each of which is sampled and used to update the full-dimensional variable with `zs.at[t].set(sample)`? Would the targets simply be the model above, conditioned on x_{1:t}?
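One possible shape for such a proposal network, sketched in plain JAX without a neural-network library: an LSTM cell whose hidden state summarizes the history, followed by a mixture-density head giving a K-component Gaussian mixture over z_t. The input `(z_{t-1}, x_t)`, the sizes, the initialization, and all function names are assumptions for illustration, and this sketch parametrizes z_t directly rather than the v_t variant.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

HIDDEN = 16  # hypothetical LSTM hidden size
K = 3        # hypothetical number of mixture components
IN_DIM = 2   # hypothetical per-step input: (z_{t-1}, x_t)


def init_params(key):
    k1, k2 = jax.random.split(key)
    return {
        # One matrix produces all four LSTM gates from [input, hidden].
        "W": 0.1 * jax.random.normal(k1, (IN_DIM + HIDDEN, 4 * HIDDEN)),
        "b": jnp.zeros(4 * HIDDEN),
        # MDN head: mixture logits, means and log-scales for K components.
        "V": 0.1 * jax.random.normal(k2, (HIDDEN, 3 * K)),
        "c": jnp.zeros(3 * K),
    }


def lstm_cell(params, inp, state):
    # Standard LSTM update; (h, c) carries information about z_{1:t-1}, x_{1:t}.
    h, c = state
    gates = jnp.concatenate([inp, h]) @ params["W"] + params["b"]
    i, f, g, o = jnp.split(gates, 4)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return h, c


def mdn(params, h):
    # Map the hidden state to mixture log-weights, means and scales.
    out = h @ params["V"] + params["c"]
    logits, mu, log_sigma = jnp.split(out, 3)
    return jax.nn.log_softmax(logits), mu, jnp.exp(log_sigma)


def mdn_log_prob(log_w, mu, sigma, z):
    # log sum_k w_k N(z; mu_k, sigma_k^2), computed stably with logsumexp.
    comp = -0.5 * ((z - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)
    return logsumexp(log_w + comp)


def mdn_sample(key, log_w, mu, sigma):
    # Pick a component, then sample from that Gaussian.
    k1, k2 = jax.random.split(key)
    k = jax.random.categorical(k1, log_w)
    return mu[k] + sigma[k] * jax.random.normal(k2)
```

At each step, one would feed `(z_{t-1}, x_t)` through `lstm_cell`, sample z_t with `mdn_sample`, and score it with `mdn_log_prob` for the importance weight and the training loss.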

I will try to code something up, but some guidance would be very helpful!

@fehiepsi
Collaborator Author

fehiepsi commented Oct 6, 2024

Great to hear that you are interested in this issue, @deoxyribose! The main idea when using coix is to define subprograms and then combine them. Each subprogram is written in a PPL, e.g. numpyro.

Your model is already written in a "combined" form. You can factor it into subprograms: init_proposal, proposal_t, and target_t. Here target_t is the body of your for loop and proposal_t is your LSTM-based model; target_t defines the joint distribution p(z_t, x_t | ...), while proposal_t defines q(z_t | ...). Let's walk through this step first. Please let me know if you have any questions.

The next step is to combine those programs. You can use the algorithm in coix.algo.nasmc or, even better, combine them in your own way. But let's discuss that later.
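As one example of combining the pieces by hand, here is a sketch in plain JAX; it is not `coix.algo.nasmc`. It runs SMC with multinomial resampling at every step over a preallocated trajectory buffer, accumulating the NASMC surrogate loss (normalized weights held constant) and a log-marginal-likelihood estimate. The shifted-prior Gaussian proposal with parameters `phi = {"shift", "log_scale"}` is a hypothetical stand-in for the LSTM network, and all names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

NUM_PARTICLES = 32  # hypothetical


def trans_loc(z_prev, t):
    # Transition mean of the model in `ssm`.
    return z_prev / 2 + 25 * z_prev / (1 + z_prev ** 2) + 8 * jnp.cos(1.2 * t)


def smc_nasmc_loss(phi, key, xs):
    """Return (NASMC surrogate loss, log Z estimate) for observations xs[0..T-2]."""
    T = xs.shape[0] + 1
    key, sub = jax.random.split(key)
    # Preallocated trajectories; column 0 holds z_0 drawn from the prior N(0, 5).
    z0 = 5.0 * jax.random.normal(sub, (NUM_PARTICLES,))
    zs = jnp.zeros((NUM_PARTICLES, T)).at[:, 0].set(z0)

    def step(carry, t):
        zs, key, loss, log_z = carry
        key, k_prop, k_res = jax.random.split(key, 3)
        z_prev, x_t = zs[:, t - 1], xs[t - 1]
        # Hypothetical proposal q_phi(z_t | z_{t-1}): shifted/rescaled prior transition.
        loc = trans_loc(z_prev, t) + phi["shift"]
        scale = jnp.exp(phi["log_scale"])
        # Particles are fixed samples: no gradient through the sampling path.
        eps = jax.random.normal(k_prop, (NUM_PARTICLES,))
        z_t = jax.lax.stop_gradient(loc + scale * eps)
        log_q = norm.logpdf(z_t, loc, scale)
        log_p = (norm.logpdf(z_t, trans_loc(z_prev, t), jnp.sqrt(10.0))
                 + norm.logpdf(x_t, z_t ** 2 / 20, 1.0))
        log_w = log_p - log_q
        # NASMC surrogate: normalized weights treated as constants.
        w = jax.lax.stop_gradient(jax.nn.softmax(log_w))
        loss = loss - jnp.sum(w * log_q)
        # Resampling every step, so each increment is log of the mean weight.
        log_z = log_z + logsumexp(log_w) - jnp.log(NUM_PARTICLES)
        zs = zs.at[:, t].set(z_t)
        idx = jax.random.categorical(k_res, jax.lax.stop_gradient(log_w), shape=(NUM_PARTICLES,))
        return (zs[idx], key, loss, log_z), None

    init = (zs, key, jnp.zeros(()), jnp.zeros(()))
    (_, _, loss, log_z), _ = jax.lax.scan(step, init, jnp.arange(1, T))
    return loss, jax.lax.stop_gradient(log_z)
```

Training would take `jax.grad` of the first output with respect to `phi`; the second output is a cheap quantity to log as a stability metric.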

@deoxyribose

deoxyribose commented Oct 8, 2024

Thanks @fehiepsi! I've tried to do what you suggest here: https:/deoxyribose/nasmc/blob/master/nasmc.ipynb, but I don't have a good handle yet on what it's supposed to look like. I'm not sure whether init_proposal and the z_0 sampling should be separate from ssm_proposal and ssm_target, respectively. In any case, I don't know how to get past the current error message, but I suspect the code reveals some misconceptions that you could clear up :)

Edit 10/10/24: At present, I can run training if jit compilation is turned off, but judging by the metrics, it's not very stable and eventually crashes. I think I need to break the problem down into smaller tests than just running training, but I'm not sure what those could be.
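One kind of smaller check, sketched in plain JAX: monitor the effective sample size (ESS) and the finiteness of the log-weights at each step, since a collapsing ESS or a non-finite weight can point to the step where a run goes wrong before the training loss does. Another is to first run the loop with the proposal fixed to the prior transition (a bootstrap filter), so that no proposal network is involved. The helper names and the 10% threshold are hypothetical.

```python
import jax
import jax.numpy as jnp


def ess(log_w):
    # Effective sample size 1 / sum_k w_k^2 of the normalized weights.
    w = jax.nn.softmax(log_w)
    return 1.0 / jnp.sum(w ** 2)


def check_step(log_w, min_ess_frac=0.1):
    # Hypothetical per-step diagnostic: are all log-weights finite,
    # and is the ESS above a fraction of the particle count?
    finite = jnp.all(jnp.isfinite(log_w))
    healthy = ess(log_w) >= min_ess_frac * log_w.shape[0]
    return finite, healthy
```

Uniform weights give an ESS equal to the number of particles; a single dominant particle gives an ESS close to 1.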


2 participants