Refactor/handle solve args #748

Open
wants to merge 10 commits into
base: main

Collaborator

@selmanozleyen selmanozleyen commented Sep 22, 2024

Hi @MUCDK,

The good news is that we already do a good job of partitioning the kwargs for solve. In solve, we pass any kwarg we don't recognize to either the SinkhornSolver or the GWSolver constructor. SinkhornSolver uses Sinkhorn or LRSinkhorn from ott-jax; these classes don't take kwargs in their constructors, so we are fine when SinkhornSolver is the backend. GWSolver uses GromovWasserstein or LRGromovWasserstein from ott-jax. The parent class of these classes, WassersteinSolver, doesn't raise an error on unrecognized args. The tests will pass once the ott-jax PR is merged.
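To make the partitioning idea concrete, here is a minimal sketch of splitting kwargs by constructor signature and rejecting whatever is left over. `split_kwargs`, `StrictSolver` and `build_solver` are hypothetical names for illustration only; they are not the moscot or ott-jax API.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_kwargs(
    target: Callable[..., Any], kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into those named in target's signature and the rest."""
    params = inspect.signature(target).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    rest = {k: v for k, v in kwargs.items() if k not in params}
    return accepted, rest


class StrictSolver:
    """Hypothetical stand-in for a backend whose constructor has no **kwargs."""

    def __init__(self, threshold: float = 1e-3, max_iterations: int = 2000):
        self.threshold = threshold
        self.max_iterations = max_iterations


def build_solver(**kwargs: Any) -> StrictSolver:
    # Anything the constructor does not name is reported instead of being dropped.
    accepted, unknown = split_kwargs(StrictSolver, kwargs)
    if unknown:
        raise TypeError(f"Unrecognized arguments: {sorted(unknown)}")
    return StrictSolver(**accepted)
```

A constructor that swallows `**kwargs` would hide typos such as `treshold=...`, which is the case the explicit check above is meant to catch.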

Here is the PR in ott-jax: ott-jax/ott#579

Other things done:

  • Added tests that check we raise appropriate errors on completely unrecognized arguments.
  • Refactored where we handle the checks for alpha; it's now completely independent of CompoundProblem or any other more abstract class. It's handled in GWSolver, as it should be.
  • Added extra tests for the errors we raise for alpha or the data given.

Additionally closes:

@selmanozleyen selmanozleyen marked this pull request as draft September 22, 2024 21:35
@@ -435,7 +435,7 @@ def _prepare(
         cost_matrix_rank: Optional[int] = None,
         time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
         # problem
-        alpha: float = 0.5,
+        alpha: Optional[float] = None,

Collaborator

Can we just add a comment after this saying that the default is 0.5?

Collaborator

I'm actually wondering whether we should have a default alpha whenever we don't want it to be 1.0

I.e. always set it explicitly in the classes which use GW, wdyt?

Collaborator Author

I set it like this so that it doesn't change the current behaviour.

> I.e. always set it explicitly in the classes which use GW, wdyt?

You mean to make it non-optional? I personally prefer non-optional parameters, especially if the class is an internal solver. I also think we should rename GWSolver to FGWSolver (because it can technically solve both FGW and GW) and just set alpha=1 when in a GWProblem.

Collaborator

I think that makes sense
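For reference, the proposal in this thread could look roughly like the sketch below: the solver takes alpha as a required argument and each problem class sets it explicitly. The class bodies and the `make_solver` and `is_fused` names are hypothetical; only the `FGWSolver` rename and `alpha=1` for `GWProblem` come from the discussion.

```python
class FGWSolver:
    """Hypothetical internal solver that handles both FGW and GW."""

    def __init__(self, alpha: float):
        # No default: every caller has to state which problem it solves.
        self.alpha = alpha

    @property
    def is_fused(self) -> bool:
        # alpha == 1.0 means there is no linear term, i.e. pure GW.
        return self.alpha != 1.0


class GWProblem:
    """Hypothetical problem class: always pure GW."""

    def make_solver(self) -> FGWSolver:
        return FGWSolver(alpha=1.0)


class FGWProblem:
    """Hypothetical problem class: the caller picks alpha."""

    def make_solver(self, alpha: float) -> FGWSolver:
        return FGWSolver(alpha=alpha)
```

With a required argument, forgetting alpha fails at construction time rather than silently falling back to a data-dependent default.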

Comment on lines +459 to +465
+        if alpha is None:
+            alpha = 1.0 if xy is None else 0.5  # set defaults according to the data provided
+        if alpha <= 0.0:
+            raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
+        if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
+            raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.")

Collaborator

related to above, let's throw an error if alpha is None and xy is not None?
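A minimal sketch of the stricter behaviour suggested here, assuming alpha still defaults to 1.0 when no linear term is given. `resolve_alpha` is a hypothetical helper, not a function from the PR, and the error messages are illustrative wording. The range check is deliberately left out of this helper.

```python
from typing import Any, Optional


def resolve_alpha(alpha: Optional[float], xy: Optional[Any]) -> float:
    """Return the alpha to use, or raise if alpha and xy are inconsistent."""
    if alpha is None:
        if xy is not None:
            # Stricter than the diff above: no silent 0.5 when a linear term is given.
            raise ValueError("Expected `alpha` to be set explicitly when `xy` is provided.")
        return 1.0  # no linear term, so this is pure GW
    if alpha == 1.0 and xy is not None:
        raise ValueError(f"Expected `xy` to be `None` if `alpha` is 1.0, found alpha={alpha}.")
    if alpha != 1.0 and xy is None:
        raise ValueError(f"Expected `xy` to be provided if `alpha` is not 1.0, found alpha={alpha}.")
    return alpha
```

The two consistency checks are the same condition as in the diff, split in two so that each message states exactly one expectation.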

@@ -456,6 +456,13 @@ def _prepare(
         geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
         geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
         geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
+        if alpha is None:
+            alpha = 1.0 if xy is None else 0.5  # set defaults according to the data provided
+        if alpha <= 0.0:

Collaborator

also raise error if alpha>1.0

Collaborator Author

Actually, `alpha_to_fused_penalty(alpha: float) -> float` already raises the error for an invalid range. I think it's better if only one place raises this error. Thinking about it, I think even the check on line 461 is redundant, so I can remove it.

Collaborator

ok makes sense!
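To illustrate keeping the range check in a single place, here is a sketch of what such a conversion helper might look like. Only the signature and the `(0, 1]` interval appear in this thread; the body, including the `(1 - alpha) / alpha` formula, is an assumption and may differ from the real function.

```python
def alpha_to_fused_penalty(alpha: float) -> float:
    """Sketch: validate alpha once and convert it to a fused penalty."""
    # Single source of truth for the range check; callers do not repeat it.
    if not (0.0 < alpha <= 1.0):
        raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
    # Assumed conversion: alpha == 1.0 gives a penalty of 0.0, i.e. no linear term.
    return (1.0 - alpha) / alpha
```

Because the helper rejects both `alpha <= 0` and `alpha > 1`, a separate `alpha <= 0.0` check in `_prepare` would only duplicate half of it.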

Collaborator

@MUCDK MUCDK left a comment

Thanks, overall really nice. Let's see what Mike and Marco say on the ott-jax side!
