Improvements to pointwize and sampled GW variants #470
Hello @patrick-nicodemus, this makes sense. You could give the loss either as a string for a precomputed loss or as a function for more general ones. Feel free to propose a PR, and try to respect the API for GW.
rflamary changed the title from "Improvements to GW variants" to "Improvements to pointwize and sampled GW variants" on May 5, 2023
Currently, several algorithms to compute or estimate the Gromov-Wasserstein distance are provided, so users have a lot of freedom to experiment with algorithms appropriate to their particular distribution size, accuracy requirements, loss function, etc.
However, the pointwise_gromov_wasserstein and sampled_gromov_wasserstein functions are substantially slower than gromov_wasserstein for analogous cases. Our lab works with distributions of size N=100. On a 20-core machine, gromov_wasserstein takes about 17-20 milliseconds, whereas pointwise_gromov_wasserstein with max_iter=5 and log=False takes between 40 and 80 milliseconds.
Granted, the original paper on Sampled Gromov-Wasserstein points out that its advantage is strongest for distributions with N >> 100 and for losses other than the square loss. However, I do not think this explains the performance difference. I suspect a large share of it comes from the user-supplied loss function being called element by element, as interpreted Python inside a list comprehension, at each iteration of the loop.
I propose that the interface for pointwise_gromov_wasserstein, sampled_gromov_wasserstein and GW_distance_estimation expose a way for users to select from a fixed list of loss functions, including square loss and absolute value loss, which would be implemented internally in a vectorized way using a performant backend.
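To make the proposal concrete, here is a rough NumPy sketch of what I have in mind. It is not POT's actual code: the names `VECTORIZED_LOSSES`, `resolve_loss`, `loss_list_comprehension` and `loss_vectorized` are hypothetical, and the string keys are only suggestions. It contrasts evaluating a scalar loss in a list comprehension with dispatching a string to a broadcasted array expression, while still accepting an arbitrary callable.

```python
import numpy as np

# Hypothetical registry of built-in losses (illustrative names, not POT's API).
# Each entry operates elementwise on arrays, so broadcasting does the looping.
VECTORIZED_LOSSES = {
    "square_loss": lambda a, b: (a - b) ** 2,
    "abs_loss": lambda a, b: np.abs(a - b),
}


def resolve_loss(loss):
    """Turn a loss name or a scalar callable into an array-wise function."""
    if isinstance(loss, str):
        if loss not in VECTORIZED_LOSSES:
            raise ValueError(f"unknown loss {loss!r}")
        return VECTORIZED_LOSSES[loss]
    if callable(loss):
        # Fallback for general user losses. np.vectorize is a convenience
        # wrapper and still calls the Python function once per element.
        return np.vectorize(loss, otypes=[float])
    raise TypeError("loss must be a string or a callable")


def loss_list_comprehension(loss_fun, x, y):
    """Baseline: one interpreted Python call per pair (x[i], y[j])."""
    return np.array([[loss_fun(a, b) for b in y] for a in x])


def loss_vectorized(loss, x, y):
    """Same matrix of pairwise losses, computed through broadcasting."""
    f = resolve_loss(loss)
    return f(x[:, None], y[None, :])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, y = rng.random(100), rng.random(100)
    slow = loss_list_comprehension(lambda a, b: (a - b) ** 2, x, y)
    fast = loss_vectorized("square_loss", x, y)
    print(np.allclose(slow, fast))
```

Accepting either a string or a callable keeps the general case available, while the common losses avoid the per-element Python call.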