-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/handle solve args #748
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
This reverts commit 0b06cc4.
@@ -435,7 +435,7 @@ def _prepare( | |||
cost_matrix_rank: Optional[int] = None, | |||
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None, | |||
# problem | |||
alpha: float = 0.5, | |||
alpha: Optional[float] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we just make a comment behind this that default is 0.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually wondering whether we should have a default alpha
whenever we don't want it to be 1.0
I.e. always set it explicitly in the classes which use GW, wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I set it like this so that it doesn't change the current behaviour.
I.e. always set it explicitly in the classes which use GW, wdyt?
You mean to make non-optional? I personally prefer non-optional parameters, especially if the class is an internal solver. I also think we should rename GWSolver
to FGWSolver
(because it technically can solve fgw and gw) and just set alpha=1
when in a GWProblem
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that makes sense
if alpha is None: | ||
alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided | ||
if alpha <= 0.0: | ||
raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.") | ||
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None): | ||
raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
related to above, let's throw an error if alpha is None
and xy is not None
?
@@ -456,6 +456,13 @@ def _prepare( | |||
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank | |||
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs) | |||
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs) | |||
if alpha is None: | |||
alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided | |||
if alpha <= 0.0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also raise error if alpha>1.0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually
moscot/src/moscot/backends/ott/_utils.py
Line 91 in f303a14
def alpha_to_fused_penalty(alpha: float) -> float: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok makes sense!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, overall really nice. Let's see what Mike and Marco say on the ott-jax side!
hi @MUCDK ,
So good news is we currently do a good job on partitioning the
kwargs
for solve. In solve we give anykwarg
we don't know to eitherSinkhornSolver
orGWSolver
constructors.SinkhornSolver
usesSinkhorn
orLRSinkhorn
fromottjax
, these classes don't havekwargs
in their constructors so when usingSinkhornSolver
as a backend we are good.GWSolver
usesGromovWasserstein
orLRGromovWasserstein
fromottjax
. The parent class of these classWassersteinSolver
don't throw an error on unrecognized args. The tests will pass after the ottjax PR merges.Here is the PR in
ott-jax
: ott-jax/ott#579Other things done:
CompoundProblem
or any other more abstract class. It's handled inGWSolver
as it should.Additionally closes:
solve
methods #720