Skip to content

Commit

Permalink
jax vmap searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
clbonet committed Oct 16, 2024
1 parent 4862b2c commit 2331d19
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ def searchsorted(self, a, v, side='left'):
if a.ndim == 1:
return jnp.searchsorted(a, v, side)
else:
return jax.vmap(jnp.searchsorted, in_axes=[0, 1, None])(a, v, side)
return jax.vmap(lambda b, u: jnp.searchsorted(b, u, side))(a, v)

def flip(self, a, axis=None):
return jnp.flip(a, axis)
Expand Down

0 comments on commit 2331d19

Please sign in to comment.