
Why is a non-materialised FunctionLinearOperator recompiled for every new input? #83

Open · gawelk opened this issue Feb 23, 2024 · 1 comment
Labels: question (User queries)


gawelk commented Feb 23, 2024

The lineax.materialise() docs contain the following example:

```python
import lineax as lx

operator = lx.FunctionLinearOperator(large_function, ...)

# Option 1
out1 = operator.mv(vector1)  # Traces and compiles `large_function`
out2 = operator.mv(vector2)  # Traces and compiles `large_function` again!
out3 = operator.mv(vector3)  # Traces and compiles `large_function` a third time!
# All that compilation might lead to long compile times.
# If `large_function` takes a long time to run, then this might also lead to long
# run times.

# Option 2
operator = lx.materialise(operator)  # Traces and compiles `large_function` and
                                     # stores the result as a matrix.
out1 = operator.mv(vector1)  # Each of these just computes a matrix-vector product
out2 = operator.mv(vector2)  # against the stored matrix.
out3 = operator.mv(vector3)  #
# Now, `large_function` is only compiled once, and only run once.
# However, storing the matrix might take a lot of memory, and the initial
# computation may-or-may-not take a long time to run.
```
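To make Option 2 concrete without relying on lineax internals, here is a minimal pure-JAX sketch of what "materialising" a linear function means. It assumes `large_function` is linear (the body below is a hypothetical stand-in) and uses `jax.jacfwd`, since the Jacobian of a linear map is exactly its matrix. This is an illustration of the idea, not a description of how `lx.materialise` is implemented.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def large_function(v):
    # Hypothetical expensive *linear* function from R^3 to R^3.
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return jnp.stack([v[0] + 2.0 * v[1], 3.0 * v[2], v[0] - v[2]])

# "Materialise": for a linear function, the Jacobian (at any point) is its matrix.
# jax.jacfwd traces large_function once and pushes all basis vectors through it.
matrix = jax.jacfwd(large_function)(jnp.zeros(3))

vector1 = jnp.array([1.0, 0.0, 0.0])
vector2 = jnp.array([0.0, 1.0, 0.0])
vector3 = jnp.array([1.0, 2.0, 3.0])
# Each product now only touches the stored matrix; large_function is not called again.
out1, out2, out3 = (matrix @ v for v in (vector1, vector2, vector3))
print(trace_count)  # 1
```

The trade-off is the one the docs describe: an n-by-n matrix has to be stored, but every later product is a cheap matmul.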

In Option 1, why does JAX trace and recompile the operator for every new input? Does this mean that, when using an iterative solver, the operator is recompiled on every iteration? Is it possible to JIT-compile it only once?

patrick-kidger (Owner) commented:

So JAX will trace and compile separately for every call site in your code. If you're familiar with traditional compiled languages, then what is happening is essentially that every single function call is inlined.
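One way to observe this in plain JAX is to put a Python side effect (a counter) inside the function. Side effects only run while JAX is tracing, so the counter records how many times the function was traced. In this sketch `large_function` and `option1` are hypothetical; calling the function from three call sites inside one `jax.jit` traces it three times, and each copy of its operations ends up inlined in the compiled program.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def large_function(v):
    # Hypothetical stand-in for an expensive function.
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return jnp.sin(v) * 2.0

@jax.jit
def option1(vector1, vector2, vector3):
    # Three call sites: each one traces large_function again and inlines its operations.
    return large_function(vector1), large_function(vector2), large_function(vector3)

v = jnp.ones(3)
outs = option1(v, v, v)
print(trace_count)  # 3
```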

In large part, the reason for this is the way JAX does JIT compilation: it runs your code with "tracers" that record every operation that happens to them, but this means that it can't see Python-level constructs like functions or for loops. See point 7 in this post.
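The same effect is visible in the jaxpr (JAX's traced intermediate representation), which `jax.make_jaxpr` returns. In this sketch, with a hypothetical helper `step`, a Python `for` loop that calls `step` five times leaves no record of either the loop or the function: the traced program is just five `sin` operations in a row.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x):
    # Hypothetical helper; lax.sin maps directly onto JAX's `sin` primitive.
    return lax.sin(x)

def python_loop(x):
    # A plain Python loop: tracers only ever see the individual operations.
    for _ in range(5):
        x = step(x)
    return x

jaxpr = jax.make_jaxpr(python_loop)(jnp.ones(3)).jaxpr
primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]
print(primitive_names)  # ['sin', 'sin', 'sin', 'sin', 'sin']
```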

In an iterative solver, we wrap the repeated invocations inside a jax.lax.while_loop. As this is a JAX-level loop construct, JAX knows about the loop, and so it only needs to compile the body function once.
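A minimal sketch of this, assuming a hypothetical matrix-free `matvec` in place of the operator and a simple Richardson iteration in place of a real solver (lineax's solvers are more sophisticated): the trace counter ends at 1 even though the loop body runs 100 times. Calling the jitted function again with a new right-hand side of the same shape and dtype reuses the compiled program instead of retracing, which answers the "only once" part of the question for this setup.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0

def matvec(v):
    # Hypothetical matrix-free operator: v -> A v with A = diag(2, 3, 4).
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return jnp.array([2.0, 3.0, 4.0]) * v

@jax.jit
def richardson_solve(b):
    # Solve A x = b with the Richardson iteration x <- x + omega * (b - A x).
    omega = 0.2
    def cond(carry):
        k, _ = carry
        return k < 100
    def body(carry):
        k, x = carry
        return k + 1, x + omega * (b - matvec(x))
    _, x = lax.while_loop(cond, body, (0, jnp.zeros_like(b)))
    return x

b = jnp.array([2.0, 3.0, 4.0])
x1 = richardson_solve(b)        # traces once: the while_loop body is traced once
x2 = richardson_solve(2.0 * b)  # same shape/dtype: reuses the compiled program
print(trace_count)  # 1
```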

In cases like Option 1 above, another option (other than lx.materialise) is to call either

```python
import jax
import jax.numpy as jnp
stacked_inputs = jax.tree_util.tree_map(lambda *a: jnp.stack(a), vector1, vector2, vector3)
outs = jax.vmap(operator.mv)(stacked_inputs)
```

if vector{1,2,3} are independent of each other, or to do something similar with lax.scan if they depend on each other.
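Both patterns can be sketched end to end in plain JAX. Assumptions: `large_function` below is a hypothetical stand-in for the operator's function (it swaps and doubles the two entries of a vector), and the trace-counter trick shows that `jax.vmap` (independent inputs) and `lax.scan` (each input is the previous output) each trace the function only once. With a real `lx.FunctionLinearOperator`, `operator.mv` would take the place of `large_function`.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0

def large_function(v):
    # Hypothetical stand-in for an expensive linear map: swap the entries, then double.
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return 2.0 * v[::-1]

vector1 = jnp.array([1.0, 0.0])
vector2 = jnp.array([0.0, 1.0])
vector3 = jnp.array([1.0, 1.0])

# Independent inputs: stack them along a new leading axis and vmap over it.
stacked_inputs = jax.tree_util.tree_map(lambda *a: jnp.stack(a), vector1, vector2, vector3)
batched_out = jax.jit(jax.vmap(large_function))(stacked_inputs)
vmap_traces = trace_count

# Dependent inputs: feed each output back in as the next input with lax.scan.
def step(v, _):
    w = large_function(v)
    return w, w  # (new carry, value recorded for this step)

final, history = jax.jit(lambda v: lax.scan(step, v, None, length=3))(vector1)
scan_traces = trace_count - vmap_traces
print(vmap_traces, scan_traces)  # 1 1
```

`history` holds the three intermediate vectors stacked along the leading axis, so the scan version gives the same per-step results as three separate calls without tracing three copies.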

patrick-kidger added the question (User queries) label Feb 23, 2024