Including user-defined Jacobian #17
Okay, many things to respond to here!

**Speed**

With respect to the speed, for your JAX code are you:

In practice this means writing things out something like:

```python
import timeit
import jax
import optimistix as optx

@jax.jit
def run(y, b):
    sol = optx.root_find(jax.vmap(fn), solver, y, b)  # fn, solver, y, b as in your MWE
    return sol.value

run(y, b)  # first call compiles; keep it out of the timing
times = timeit.repeat(lambda: jax.block_until_ready(run(y, b)), number=1, repeat=10)
print(min(times))
```

**Recalculating Jacobians**

You commented on calculating the Jacobian afresh every iteration. If you are using the typical Newton algorithm, then this is expected (desired) behaviour. But if you're saying that you'd prefer to use a quasi-Newton algorithm like the chord method (which computes the Jacobian once at the initial point and then re-uses it), then there is

**Analytical Jacobians**

You commented on supplying an analytical Jacobian. This isn't necessary, as the analytical Jacobian is actually already derived from

**Custom Jacobians**

If despite everything you really do want to provide a custom Jacobian, then this can be done using

Does the above help?
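To make the Newton-versus-chord distinction concrete, here is a minimal hand-written sketch in plain JAX. It is not the library's implementation; the system `F`, its root, and the helpers `F_jac`, `newton` and `chord` are hypothetical. Newton re-evaluates the Jacobian with `jax.jacfwd` at every iterate, while the chord variant evaluates and LU-factorizes it once at `y0` and reuses that factorization. The first check also shows that forward-mode autodiff matches the hand-derived Jacobian, which is why supplying one by hand is usually unnecessary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)  # float64 so convergence checks can be tight


def F(y):
    # Hypothetical 2x2 system with a root at y = (1, 2).
    return jnp.array([y[0] ** 2 + y[1] - 3.0, y[0] + y[1] ** 2 - 5.0])


def F_jac(y):
    # Hand-derived Jacobian dF/dy, used only to compare against autodiff.
    return jnp.array([[2.0 * y[0], 1.0], [1.0, 2.0 * y[1]]])


def newton(F, y0, steps=20):
    # Newton: re-evaluate the Jacobian at every iterate.
    def body(_, y):
        J = jax.jacfwd(F)(y)
        return y - jnp.linalg.solve(J, F(y))

    return jax.lax.fori_loop(0, steps, body, y0)


def chord(F, y0, steps=50):
    # Chord: evaluate and factorize the Jacobian once at y0, then reuse it.
    lu_piv = lu_factor(jax.jacfwd(F)(y0))

    def body(_, y):
        return y - lu_solve(lu_piv, F(y))

    return jax.lax.fori_loop(0, steps, body, y0)


y0 = jnp.array([1.2, 1.8])
print(jnp.allclose(jax.jacfwd(F)(y0), F_jac(y0)))  # True: autodiff matches the hand-derived Jacobian
y_newton = newton(F, y0)
y_chord = chord(F, y0)
print(y_newton, y_chord)
```

The chord method gives up Newton's quadratic convergence in exchange for cheaper iterations, so it typically needs more steps (hence the larger default `steps` here) and only converges from a sufficiently good starting point.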
Hi devs, looks like a really nice library. I've been looking for a JAX-native root-finding method that supports `vmap` for some time. Currently I am using an external call to `scipy.optimize.root` together with the `multiprocessing` library, which is quite slow.

The runtime for root finding using the Newton method in this library is slower than the above approach, though. I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?
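One general JAX mechanism for supplying a hand-derived Jacobian (an assumption here, one option among several) is `jax.custom_jvp`: any code that differentiates `F` through `jax.jacfwd` or `jax.jvp` then uses your rule instead of tracing through `F`. The sketch below uses a hypothetical residual `F(y, b)`, a hypothetical `analytic_jac`, and a hand-rolled fixed-count Newton loop (not this library's solver), and batches the solve with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.custom_jvp
def F(y, b):
    # Hypothetical residual; each problem in the batch has its own parameters b.
    return jnp.array([y[0] ** 2 + y[1] - b[0], y[0] + y[1] ** 2 - b[1]])


def analytic_jac(y):
    # Hand-derived dF/dy; for this residual it happens not to depend on b.
    return jnp.array([[2.0 * y[0], 1.0], [1.0, 2.0 * y[1]]])


@F.defjvp
def F_jvp(primals, tangents):
    y, b = primals
    dy, db = tangents
    # dF = (dF/dy) dy + (dF/db) db, and dF/db = -I for this residual.
    return F(y, b), analytic_jac(y) @ dy - db


def newton_solve(y0, b, steps=20):
    # Fixed-count Newton loop; jacfwd goes through F_jvp, i.e. analytic_jac.
    def body(_, y):
        J = jax.jacfwd(F)(y, b)
        return y - jnp.linalg.solve(J, F(y, b))

    return jax.lax.fori_loop(0, steps, body, y0)


# vmap over a batch of (initial guess, parameters) pairs, then compile once.
batched_solve = jax.jit(jax.vmap(newton_solve))

bs = jnp.array([[3.0, 5.0], [2.0, 2.0]])  # roots at (1, 2) and (1, 1)
y0s = jnp.array([[1.2, 1.8], [1.2, 0.8]])
ys = batched_solve(y0s, bs)
print(ys)
```

Each batch element runs its own fixed-count loop here. A solver that instead stops on a tolerance (e.g. via `jax.lax.while_loop`) keeps iterating under `vmap` until every element in the batch has converged.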
For reference, this is my MWE in case I am not doing things efficiently: