sinkhorn2 and its functorch.vmap compatibility #482
That is a good point, but POT is implemented in pure Python on top of backends, and getting rid of conditional flows is going to be a pain. Note that for what you want to compute (P Sinkhorn problems in parallel with the same cost C), you do not need a loop or vmap: the Sinkhorn iterations can be implemented with already-parallel matrix products, with very little change to the sinkhorn_knopp function. We do not provide this in POT (maybe we will one day, but we need to find the proper API), but feel free to reach me on the POT Slack if you want some pointers.
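To make the "parallel matrix products" idea concrete, here is a minimal NumPy sketch of P Sinkhorn problems that share one cost matrix C, with each histogram pair stored as one row of `a` and `b`. It is not POT's API: the function name `batched_sinkhorn`, its parameters and the fixed iteration count are assumptions for illustration. The returned cost mirrors the linear transport term ⟨G, C⟩ (no entropy term), which is what sinkhorn2 reports as its loss.

```python
import numpy as np


def batched_sinkhorn(a, b, C, reg, n_iter=1000):
    """Hypothetical helper: P entropic OT problems sharing one cost matrix.

    a: (P, n) source histograms, b: (P, m) target histograms,
    C: (n, m) shared cost matrix, reg: entropic regularization strength.
    Returns plans G of shape (P, n, m) and costs <G_p, C> of shape (P,).
    """
    K = np.exp(-C / reg)          # (n, m) Gibbs kernel, shared by every problem
    u = np.ones_like(a)           # (P, n) left scaling vectors, one row per problem
    v = np.ones_like(b)           # (P, m) right scaling vectors
    for _ in range(n_iter):       # fixed count: no data-dependent stopping test
        v = b / (u @ K)           # (P, n) @ (n, m) -> (P, m), i.e. b / (K^T u) per row
        u = a / (v @ K.T)         # (P, m) @ (m, n) -> (P, n), i.e. a / (K v) per row
    # G_p = diag(u_p) K diag(v_p), built for all problems at once by broadcasting.
    G = u[:, :, None] * K[None, :, :] * v[:, None, :]
    cost = np.einsum("pij,ij->p", G, C)   # linear transport cost per problem
    return G, cost


# Usage: four random problems sharing one 5x6 cost matrix.
rng = np.random.default_rng(0)
P, n, m = 4, 5, 6
C = rng.random((n, m))
a = rng.random((P, n))
a /= a.sum(axis=1, keepdims=True)
b = rng.random((P, m))
b /= b.sum(axis=1, keepdims=True)
G, cost = batched_sinkhorn(a, b, C, reg=0.5)
print(cost.shape)  # (4,)
```

Every problem advances in lockstep through the same two matrix products per iteration, so the batch dimension costs no extra Python overhead. The trade-off is that all problems run the same number of iterations, even those that converge early.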
Thanks @rflamary! I wanted to join the POT Slack, but unfortunately it seems the workspace invite link hasn't been shared. Could you send me the POT Slack invite? Thanks.
Here is the invite link:
🚀 Feature
Making the ot.sinkhorn2 function compatible with functorch.vmap.
Motivation
I'm using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground-truth value. Previously I used a for-loop, but it is far too slow for my application. I was reading through functorch, and it seemed I should be able to use its vmap functionality instead. But after wrapping my function in vmap, I get a strange error.
Pitch
Apparently, the data-dependent if-statement needs to be replaced with an alternative. Any help is appreciated.
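One common way to remove a data-dependent branch is sketched below, under stated assumptions. Instead of stopping the loop with an `if err < tol: break` test (under vmap the condition is a batch of values rather than a single Python bool, so it cannot drive a Python `if`), the loop always runs a fixed number of iterations, and an elementwise select freezes problems that have already converged. This NumPy code is hypothetical and is not POT's sinkhorn_knopp; the function name `sinkhorn_masked`, `n_iter` and `tol` are illustrative. In PyTorch the same select would be written with torch.where.

```python
import numpy as np


def sinkhorn_masked(a, b, C, reg, n_iter=500, tol=1e-9):
    """Hypothetical batched Sinkhorn with no data-dependent Python `if`.

    a: (P, n), b: (P, m), C: (n, m). Each problem stops updating once its
    column-marginal error falls below tol, but the loop itself always runs
    n_iter times, so the control flow never depends on tensor values.
    """
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    done = np.zeros(a.shape[0], dtype=bool)   # per-problem convergence flag
    for _ in range(n_iter):
        v_new = b / (u @ K)
        u_new = a / (v_new @ K.T)
        # Select instead of branching: converged problems keep their old scalings.
        u = np.where(done[:, None], u, u_new)
        v = np.where(done[:, None], v, v_new)
        # Column-marginal error of each current plan: v_j * (K^T u)_j vs b_j.
        err = np.abs(v * (u @ K) - b).sum(axis=1)
        done = done | (err < tol)             # elementwise update, no Python bool
    G = u[:, :, None] * K[None, :, :] * v[:, None, :]
    return G, done


# Usage: four random problems sharing one 5x6 cost matrix.
rng = np.random.default_rng(0)
P, n, m = 4, 5, 6
C = rng.random((n, m))
a = rng.random((P, n))
a /= a.sum(axis=1, keepdims=True)
b = rng.random((P, m))
b /= b.sum(axis=1, keepdims=True)
G, done = sinkhorn_masked(a, b, C, reg=0.5)
```

The select does not save computation: every problem still pays for all `n_iter` iterations. What it buys is a control flow that is identical for every sample, which is the property vmap-style batching needs.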