log_prob of blocked programs #81

Open
jhn-nt opened this issue Aug 29, 2024 · 0 comments
jhn-nt commented Aug 29, 2024

Hello,

First, thanks for developing such an amazing package.
I am a newbie to oryx and was playing around with its functionality.
Perhaps naively, I have been attempting to evaluate the log_prob of blocked programs, as below:

from jax.random import split
from oryx.core import ppl
import tensorflow_probability.substrates.jax.distributions as tfd

def latent_normal(key):
    z_key,x_key= split(key)
    z=ppl.random_variable(tfd.Normal(0,1),name="z")(z_key)
    return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)


blocked=ppl.block(latent_normal,names=["z"])
ppl.joint_log_prob(blocked)({"x":10})

However, it returns:
{
"name": "ValueError",
"message": "Cannot compute log_prob of function.",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 12
8 return ppl.random_variable(tfd.Normal(z,1e-1),name="x")(x_key)
11 blocked=ppl.block(latent_normal,names=["z"])
---> 12 ppl.joint_log_prob(blocked)({"x":10})

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:71, in log_prob.<locals>.wrapped(sample, *args, **kwargs)
67 flat_incells = [
68 InverseAndILDJ.unknown(trace_util.get_shaped_aval(dummy_seed))
69 ] + [InverseAndILDJ.new(val) for val in flat_inargs]
70 flat_outcells = [InverseAndILDJ.new(a) for a in flat_outargs]
---> 71 return log_prob_jaxpr(jaxpr.jaxpr, constcells, flat_incells, flat_outcells)

File /opt/conda/lib/python3.11/site-packages/oryx/core/interpreters/log_prob.py:128, in log_prob_jaxpr(jaxpr, constcells, flat_incells, flat_outcells)
118 _, final_log_prob = propagate.propagate(
119 InverseAndILDJ,
120 log_prob_rules,
(...)
125 reducer=reducer,
126 initial_state=0.)
127 if final_log_prob is failed_log_prob:
--> 128 raise ValueError('Cannot compute log_prob of function.')
129 return final_log_prob

ValueError: Cannot compute log_prob of function."
}
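
To make the quantities involved concrete, here is a minimal plain-Python sketch of the same model written out by hand. It does not use oryx and says nothing about what `ppl.block` or `ppl.joint_log_prob` do internally. It assumes the value being sought is the marginal log-density of `x` with `z` integrated out, which for this Gaussian model has a closed form. The function names are hypothetical and chosen only for illustration.

```python
import math


def normal_log_prob(value, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at value."""
    return (-0.5 * math.log(2.0 * math.pi * scale**2)
            - 0.5 * ((value - loc) / scale) ** 2)


def joint_log_prob(z, x):
    """log p(z, x) for z ~ Normal(0, 1) and x | z ~ Normal(z, 0.1)."""
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 0.1)


def marginal_log_prob(x):
    """log p(x) with z integrated out: x ~ Normal(0, sqrt(1 + 0.1**2))."""
    return normal_log_prob(x, 0.0, math.sqrt(1.0 + 0.1**2))


def marginal_log_prob_quadrature(x, lo=-20.0, hi=20.0, n=40001):
    """Numerical cross-check: trapezoid rule over z on [lo, hi]."""
    step = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        z = lo + i * step
        weight = 0.5 if i in (0, n - 1) else 1.0  # half weight at endpoints
        total += weight * math.exp(joint_log_prob(z, x))
    return math.log(total * step)


# Assumption: x = 10 mirrors the observed value in the snippet above.
print(marginal_log_prob(10.0))
print(marginal_log_prob_quadrature(10.0))
```

The closed form and the quadrature should agree closely. The integration bounds are an arbitrary choice and would need widening for observations far from zero.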

Am I missing something?

Thanks again

Very Best
Giovanni
