Why is a non-materialised FunctionLinearOperator recompiled for every new input? #83
Comments
So JAX will trace and compile separately for every call site in your code. If you're familiar with traditional compiled languages, then basically what is happening is that absolutely every function call is inlined. The reason for this in large part is the way JAX does JIT compilation: it runs your code with "tracers" that record every operation that happens to them, but this means that it can't see Python-level constructs like functions or for loops and so on. See point 7 in this post.

In an iterative solver, we wrap the multiple invocations inside of a [...]

In cases like [...]

    stacked_inputs = jax.tree_util.tree_map(lambda *a: jnp.stack(a), vector1, vector2, vector3)
    jax.vmap(operator.mv)(stacked_inputs)

if each of [...]
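To make the call-site behaviour concrete, here is a minimal sketch in plain JAX, without lineax. `large_function` is a hypothetical stand-in for `operator.mv`, and the `trace_log` list is only a Python side effect used to count how often the function body is traced. Under those assumptions, three separate call sites inside one jitted function trace the body three times, while a single `jax.vmap` over stacked inputs traces it once.

```python
import jax
import jax.numpy as jnp

# Python-level side effect: appended to once per trace, never at run time.
trace_log = []


def large_function(x):
    # Hypothetical stand-in for `operator.mv`.
    trace_log.append("traced")
    return 2.0 * jnp.sin(x)


@jax.jit
def three_call_sites(v1, v2, v3):
    # Each call below is inlined separately into the traced program.
    return large_function(v1), large_function(v2), large_function(v3)


@jax.jit
def one_call_site(v1, v2, v3):
    # Stack the inputs along a new leading axis, then map over that axis.
    stacked = jax.tree_util.tree_map(lambda *a: jnp.stack(a), v1, v2, v3)
    return jax.vmap(large_function)(stacked)


vector1, vector2, vector3 = jnp.ones(4), jnp.zeros(4), jnp.arange(4.0)

separate = three_call_sites(vector1, vector2, vector3)
traces_separate = len(trace_log)  # one trace per call site

trace_log.clear()
batched = one_call_site(vector1, vector2, vector3)
traces_batched = len(trace_log)  # a single trace for all three inputs
```

Stacking only works when all inputs share the same shape and dtype, which is why the batched form is not always a drop-in replacement.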
In the lineax.materialise() docs there is the following example:
In option 1, why is JAX tracing and recompiling the operator for every new input? Does this mean that while using an iterative solver, the operator is recompiled in every iteration? Is it possible to jit it only once?
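On the iterative-solver part of the question, a minimal sketch in plain JAX may help. It assumes a structured loop primitive, here `jax.lax.fori_loop`, wrapped in `jax.jit`; this is an illustrative choice and not necessarily what any lineax solver does internally. The matrix `A`, the step size `0.5`, and the names `matvec` and `solve` are all hypothetical. The `trace_log` list counts how often the matrix-vector product is traced.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2x2 system, chosen only so that the iteration converges.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
NUM_STEPS = 50

# Python-level side effect: appended to once per trace, never at run time.
trace_log = []


def matvec(x):
    # Hypothetical stand-in for a matrix-free `operator.mv`.
    trace_log.append("traced")
    return A @ x


@jax.jit
def solve(b):
    # Richardson iteration: x <- x + 0.5 * (b - A x).
    def step(_, x):
        return x + 0.5 * (b - matvec(x))

    # The loop body is traced as a unit, not once per iteration.
    return jax.lax.fori_loop(0, NUM_STEPS, step, jnp.zeros_like(b))


b1 = jnp.array([1.0, 2.0])
x1 = solve(b1)
traces_first_call = len(trace_log)

# Same shape and dtype as b1, so the cached compilation is reused.
b2 = jnp.array([3.0, -1.0])
x2 = solve(b2)
traces_second_call = len(trace_log)
```

In this sketch the number of traces stays far below `NUM_STEPS`, and the second call with a new right-hand side adds no traces at all, because `jax.jit` caches on input shapes and dtypes rather than on values.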