Is lineax slower than the linear solver in JAX? #52

Open
ToshiyukiBandai opened this issue Oct 20, 2023 · 2 comments
Labels
question User queries

Comments

@ToshiyukiBandai

Hi, thank you for creating these awesome JAX libraries. I recently started using lineax and compared it with JAX's built-in linear solver. The code below gives 931 µs for lineax and 171 µs for `jnp.linalg.solve`. Is there anything wrong with my implementation, or should I just stick with `jnp.linalg.solve`? Is there no way to use the `_gesv` Fortran routine through lineax?

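```python
# Run in IPython/Jupyter: %timeit is an IPython magic command.
# Note: JAX dispatches asynchronously, so without jax.block_until_ready
# these timings may measure dispatch rather than the completed solve.
from jax import random
import jax.numpy as jnp
import lineax as lx

matrix_key, vector_key = random.split(random.PRNGKey(0))
matrix = random.normal(matrix_key, (10, 10))
vector = random.normal(vector_key, (10,))

operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector)

%timeit lx.linear_solve(operator, vector, solver=lx.LU())

%timeit jnp.linalg.solve(matrix, vector)
```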
@patrick-kidger
Owner

Looks like the overhead comes from two things:

  1. Error-checking on the Lineax output. By default, Lineax performs an extra check that the output contains no NaNs etc., i.e. that the solve was successful. This can be disabled by passing `linear_solve(..., throw=False)`.

  2. PyTree flattening/unflattening across JIT boundaries. `operator` and `lx.LU()` are more complex PyTrees than the plain arrays `matrix` and `vector`, so calling `lx.linear_solve` directly from Python pays more of this overhead on every call.

With this benchmark, which builds the operator inside a `jax.jit`-compiled function and passes `throw=False`, I obtain identical performance:

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))

# Build the operator inside jit so only plain arrays cross the JIT boundary.
@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    # throw=False skips the runtime check that the solve succeeded.
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value

@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)

# block_until_ready waits for JAX's asynchronous dispatch to finish.
time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

# Taking the min over repeats discards the first call, which includes compilation.
print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))
```

@patrick-kidger patrick-kidger added the question User queries label Oct 21, 2023
@ToshiyukiBandai
Author

Hi Patrick,

I got the same results too. Thank you!
