Is lineax slower than the linear solver in JAX? #52
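Looks like the overhead is from two things:
- The lineax call isn't JIT-compiled, so its Python-level setup runs on every call.
- By default (`throw=True`), lineax checks that the solve succeeded and raises an error if it didn't. Passing `throw=False` skips that check.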
With this benchmark I obtain identical performance:

```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))


@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    # throw=False skips the runtime check that the solve succeeded.
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value


@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)


# block_until_ready waits for JAX's asynchronous dispatch to finish.
time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

# Taking the minimum over repeats discards the first (compiling) call.
print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))
```
Hi Patrick, I got the same results too. Thank you!
Hi, thank you for creating these awesome JAX libraries. I recently started using lineax and compared it with JAX's linear solver. With my benchmark code, lineax took 931 µs and `jnp.linalg.solve` took 171 µs. Is there anything wrong with my implementation, or should I just stick with `jnp.linalg.solve`? Is there no way to use the `_gesv` Fortran routine through lineax?