
Including user-defined Jacobian #17

Open
Justin-Tan opened this issue Oct 19, 2023 · 1 comment
Labels
question User queries

Comments

@Justin-Tan

Hi devs, looks like a really nice library. I've been looking for a Jax-native root finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

The runtime for root finding using the Newton method in this library is slower than the above method though - I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?

For reference, this is my MWE in case I am not doing things efficiently:

from jax import vmap, random
import optimistix as optx

def fn(y, b):
    return (y - b) ** 2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))

solver = optx.Newton(rtol=1e-8, atol=1e-8)
sol = optx.root_find(vmap(fn), solver, y, b)

patrick-kidger commented Oct 20, 2023

Okay, many things to respond to here!

Speed

With respect to the speed, for your JAX code are you:

  • JIT'ing everything;
  • excluding compile time;
  • including block_until_ready?

In practice this means writing things out something like:

@jax.jit
def run(y, b):
    sol = optx.root_find(vmap(fn), solver, y, b)
    return sol.value

run(y, b)  # compile
times = timeit.repeat(lambda: jax.block_until_ready(run(y, b)), number=1, repeat=10)
print(min(times))

Recalculating Jacobians

You commented on calculating the Jacobian afresh every iteration. If using the typical Newton algorithm then this is expected (desired) behaviour. But if you're saying that you'd prefer to use a quasi-Newton algorithm like the chord method (that computes the Jacobian once at the initial point and then re-uses it), then there is optx.Chord as well.

Analytical Jacobians

You commented on supplying an analytical Jacobian. This isn't necessary: the Jacobian is already derived from fn automatically via autodifferentiation, and the result is exact (not a finite-difference approximation). Unless the autodiff does something surprisingly inefficient, providing one manually wouldn't meaningfully improve things.

Custom Jacobians

If despite everything you really do want to provide a custom Jacobian, then this can be done using jax.custom_jvp. By wrapping your fn in a jax.custom_jvp, you can override how JAX differentiates your code. (This will then be picked up by the autodiff that Optimistix uses to calculate the Jacobian.)
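A minimal sketch of that pattern, using the fn from your MWE (the analytic derivative here is just for illustration):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def fn(y, b):
    return (y - b) ** 2

@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents
    primal_out = (y - b) ** 2
    # Hand-written JVP: d/dy = 2*(y - b), d/db = -2*(y - b).
    tangent_out = 2.0 * (y - b) * (y_dot - b_dot)
    return primal_out, tangent_out

print(jax.jacfwd(fn)(3.0, 1.0))  # 4.0, i.e. 2 * (3 - 1)
```

Any jax.jacfwd / jax.jvp of fn (including those performed internally by a solver) will now use your tangent rule instead of tracing through the function body.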

Does the above help?

@patrick-kidger patrick-kidger added the question User queries label Oct 20, 2023