-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/handle solve args #748
base: main
Are you sure you want to change the base?
Changes from all commits
310b273
6c8d42d
0b06cc4
feca7ec
c7cceb1
9edaafb
376eccf
01c89a2
49c8bbc
f303a14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -435,7 +435,7 @@ def _prepare( | |||
cost_matrix_rank: Optional[int] = None, | ||||
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None, | ||||
# problem | ||||
alpha: float = 0.5, | ||||
alpha: Optional[float] = None, | ||||
**kwargs: Any, | ||||
) -> quadratic_problem.QuadraticProblem: | ||||
self._a = a | ||||
|
@@ -456,6 +456,13 @@ def _prepare( | |||
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank | ||||
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs) | ||||
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs) | ||||
if alpha is None: | ||||
alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided | ||||
if alpha <= 0.0: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also raise error if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually moscot/src/moscot/backends/ott/_utils.py Line 91 in f303a14
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok makes sense! |
||||
raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.") | ||||
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None): | ||||
raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.") | ||||
|
||||
Comment on lines
+459
to
+465
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. related to above, let's throw an error if |
||||
if alpha == 1.0 or xy is None: # GW | ||||
# arbitrary fused penalty; must be positive | ||||
geom_xy, fused_penalty = None, 1.0 | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we just make a comment behind this that default is 0.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually wondering whether we should have a default
alpha
whenever we don't want it to be 1.0I.e. always set it explicitly in the classes which use GW, wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I set it like this so that it doesn't change the current behaviour.
You mean to make non-optional? I personally prefer non-optional parameters, especially if the class is an internal solver. I also think we should rename
GWSolver
toFGWSolver
(because it technically can solve fgw and gw) and just setalpha=1
when in aGWProblem
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that makes sense