Regression colab example - JAX implementation #8

Open
conorhassan opened this issue Apr 21, 2024 · 2 comments

@conorhassan
Contributor

conorhassan commented Apr 21, 2024

Hi, great paper!

I implemented the regression Colab example (or at least the first VBLLMLP example) in JAX. I wrote the equivalent of distributions.py by subclassing numpyro.distributions, and implemented the Regression and VBLLMLP classes in Flax. The model trains, but the uncertainty bands are a bit of a mess.
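For comparison with a port like this, here is a minimal sketch of a VBLL-style regression head written directly in jax.numpy (no Flax or NumPyro). It assumes a single output, a dense last-layer weight covariance S stored as a Cholesky factor, a learned homoscedastic noise variance, and an isotropic Gaussian prior N(0, prior_var·I). The function names and the parameter tuple are hypothetical, not the library's API, and any prior on the noise variance used by the original implementation is omitted. The two pieces that drive the uncertainty bands are the predictive variance φᵀSφ + σ² and the expected log-likelihood log N(y | w̄ᵀφ, σ²) - φᵀSφ / (2σ²).

```python
import jax
import jax.numpy as jnp


def unpack_cov_factor(raw_tril):
    """Lower-triangular Cholesky factor L with a positive (exp) diagonal, so S = L L^T is PD."""
    strictly_lower = jnp.tril(raw_tril, k=-1)
    return strictly_lower + jnp.diag(jnp.exp(jnp.diag(raw_tril)))


def predictive(phi, w_mean, raw_tril, log_noise_var):
    """Predictive mean and variance of y for features phi of shape (batch, d)."""
    L = unpack_cov_factor(raw_tril)
    mean = phi @ w_mean                               # (batch,)
    epistemic_var = jnp.sum((phi @ L) ** 2, axis=-1)  # phi^T S phi = ||L^T phi||^2
    return mean, epistemic_var + jnp.exp(log_noise_var)


def kl_to_isotropic_prior(w_mean, raw_tril, prior_var):
    """KL( N(w_mean, L L^T) || N(0, prior_var * I) ) for a d-dimensional weight vector."""
    L = unpack_cov_factor(raw_tril)
    d = w_mean.shape[0]
    trace_term = jnp.sum(L ** 2) / prior_var          # tr(S) / prior_var
    mahalanobis = jnp.sum(w_mean ** 2) / prior_var
    logdet_q = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    logdet_p = d * jnp.log(prior_var)
    return 0.5 * (trace_term + mahalanobis - d + logdet_p - logdet_q)


def vbll_regression_loss(params, phi, y, reg_weight, prior_var):
    """Negative (expected log-likelihood - reg_weight * KL), averaged over the batch."""
    w_mean, raw_tril, log_noise_var = params
    L = unpack_cov_factor(raw_tril)
    noise_var = jnp.exp(log_noise_var)
    mean = phi @ w_mean
    epistemic_var = jnp.sum((phi @ L) ** 2, axis=-1)
    # Gaussian log-likelihood at the mean weights...
    log_lik = -0.5 * (jnp.log(2.0 * jnp.pi * noise_var) + (y - mean) ** 2 / noise_var)
    # ...minus the penalty from averaging over q(w), which gives E_q[log N(y | w^T phi, noise_var)]
    expected_log_lik = log_lik - 0.5 * epistemic_var / noise_var
    kl = kl_to_isotropic_prior(w_mean, raw_tril, prior_var)
    return -(jnp.mean(expected_log_lik) - reg_weight * kl)


# Usage on random features (stand-ins for the MLP's last hidden layer).
d, n = 8, 32
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
phi = jax.random.normal(k1, (n, d))
y = jax.random.normal(k2, (n,))
params = (jnp.zeros(d), jnp.zeros((d, d)), jnp.array(0.0))
loss, grads = jax.value_and_grad(vbll_regression_loss)(params, phi, y, 1.0 / n, 1.0)
mean, var = predictive(phi, *params)
```

In a sketch like this, two things worth checking when the bands look off are that the covariance factor keeps a positive diagonal and that the plotted band uses the full predictive variance (epistemic plus noise) rather than only one of the two terms.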

Are there any plans for a JAX implementation? If so, I would be keen to help out. I would also like to track down the errors in my Colab.

Here is the Colab notebook: https://colab.research.google.com/drive/1Rh895u0jP9xEpK7eMOz9JHUX_2CluyLO?usp=sharing

Thanks,
Conor

@2592761383

2592761383 commented Apr 25, 2024

This is wonderful work.

However, I have the same problem with the classification task: the uncertainty quantification (UQ) is a bit of a mess.
Which parameter do you think matters most?

Best,

@2592761383

This is my setting, where 7232 is the total number of samples:
self.output = vbll.DiscClassification(64, 2, 1.0 / 7232, parameterization='diagonal', prior_scale=1.0)
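To make the question about which parameter matters concrete, here is a minimal sketch of how a 1/N regularization weight and a prior variance could enter a discriminative VBLL classification loss with a diagonal weight covariance. It assumes that the third positional argument in the setting above (1.0 / 7232) is the weight on the KL term, treats the prior as N(0, prior_var·I) (the library's prior_scale may be defined differently), and uses a Jensen-type bound on the expected log-softmax. The function names are hypothetical and any extra noise term in the library is omitted, so this is an illustration, not the vbll implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def disc_classification_loss(params, phi, labels, reg_weight, prior_var):
    """Negative (Jensen bound on E_q[log softmax(W phi)_y] - reg_weight * KL), batch-averaged."""
    w_mean, log_w_var = params               # both (num_classes, d); diagonal covariance per row
    w_var = jnp.exp(log_w_var)               # keep variances positive
    logit_mean = phi @ w_mean.T              # (batch, K)
    logit_var = (phi ** 2) @ w_var.T         # (batch, K): phi^T diag(s_k) phi
    # E[log sum_k exp(z_k)] <= log sum_k exp(mu_k + var_k / 2)
    lse = logsumexp(logit_mean + 0.5 * logit_var, axis=-1)
    true_logit = jnp.take_along_axis(logit_mean, labels[:, None], axis=-1)[:, 0]
    bound = true_logit - lse
    # KL( N(w_mean, diag(w_var)) || N(0, prior_var I) ), summed over all weights
    kl = 0.5 * jnp.sum(w_var / prior_var + w_mean ** 2 / prior_var - 1.0
                       + jnp.log(prior_var) - log_w_var)
    return -(jnp.mean(bound) - reg_weight * kl)


def predictive_probs(key, params, phi, num_samples=64):
    """Monte Carlo estimate of the predictive class probabilities."""
    w_mean, log_w_var = params
    eps = jax.random.normal(key, (num_samples,) + w_mean.shape)
    w = w_mean + jnp.exp(0.5 * log_w_var) * eps          # (S, K, d) weight samples
    logits = jnp.einsum("bd,skd->sbk", phi, w)
    return jax.nn.softmax(logits, axis=-1).mean(axis=0)  # (batch, K)


# Usage with the sizes from the setting above: 64 features, 2 classes, N = 7232.
num_train = 7232
reg_weight = 1.0 / num_train
d, K, n = 64, 2, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
phi = jax.random.normal(k1, (n, d))
labels = jax.random.randint(k2, (n,), 0, K)
params = (jnp.zeros((K, d)), jnp.full((K, d), -2.0))
loss, grads = jax.value_and_grad(disc_classification_loss)(params, phi, labels, reg_weight, 1.0)
probs = predictive_probs(k3, params, phi)
```

In this form, reg_weight = 1/N makes the batch-averaged loss the negative ELBO divided by N. Raising reg_weight or shrinking prior_var pulls the weight posterior toward the prior, which directly changes the spread of the predictive probabilities, so both are natural first knobs to inspect.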
