Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/handle solve args #748

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _prepare(
cost_matrix_rank: Optional[int] = None,
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
# problem
alpha: float = 0.5,
alpha: Optional[float] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just make a comment behind this that default is 0.5

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually wondering whether we should have a default alpha whenever we don't want it to be 1.0

I.e. always set it explicitly in the classes which use GW, wdyt?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set it like this so that it doesn't change the current behaviour.

I.e. always set it explicitly in the classes which use GW, wdyt?

You mean to make non-optional? I personally prefer non-optional parameters, especially if the class is an internal solver. I also think we should rename GWSolver to FGWSolver (because it technically can solve fgw and gw) and just set alpha=1 when in a GWProblem.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense

**kwargs: Any,
) -> quadratic_problem.QuadraticProblem:
self._a = a
Expand All @@ -456,6 +456,13 @@ def _prepare(
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
if alpha is None:
alpha = 1.0 if xy is None else 0.5 # set defaults according to the data provided
if alpha <= 0.0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also raise error if alpha>1.0

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually

def alpha_to_fused_penalty(alpha: float) -> float:
gives the error for invalid range. I think it's better if one source only throws error. And thinking of it I think even the check in 461 is redundant so I can remove it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok makes sense!

raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.")

Comment on lines +459 to +465
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

related to above, let's throw an error if alpha is None and xy is not None?

if alpha == 1.0 or xy is None: # GW
# arbitrary fused penalty; must be positive
geom_xy, fused_penalty = None, 1.0
Expand Down
18 changes: 1 addition & 17 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,24 +430,8 @@ def solve(
solver_class = backends.get_solver(
self.problem_kind, solver_name=solver_name, backend=backend, return_class=True
)
init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
# if linear problem, then alpha is 0.0 by default
# if quadratic problem, then alpha is 1.0 by default
alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.")
if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)):
raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.")
if self.problem_kind == "quadratic":
if self.x is None or self.y is None:
raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.")
if alpha != 1.0 and self.xy is None: # means FGW case
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)
if alpha == 1.0 and self.xy is not None:
raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.")

init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
self._solver = solver_class(**init_kwargs)

# note that the solver call consists of solver._prepare and solver._solve
Expand Down
56 changes: 56 additions & 0 deletions tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Literal, Optional, Tuple

import pytest
Expand Down Expand Up @@ -29,6 +30,29 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData):

assert isinstance(prob.solution, BaseDiscreteSolverOutput)

@pytest.mark.parametrize(
("kind", "rank"),
[
("linear", -1),
("linear", 5),
("quadratic", -1),
("quadratic", 5),
],
)
def test_unrecognized_args(
self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"], rank: int
):
prob = OTProblem(adata_x, adata_y)
data = {
"xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"},
}
if "quadratic" in kind:
data["x"] = {"attr": "X"}
data["y"] = {"attr": "X"}

with pytest.raises(TypeError):
prob.prepare(**data).solve(epsilon=5e-1, rank=rank, dummy=42)

@pytest.mark.fast
def test_output(self, adata_x: AnnData, x: Geom_t):
problem = OTProblem(adata_x)
Expand Down Expand Up @@ -346,3 +370,35 @@ def test_set_graph_xy_test_t(self, adata_x: AnnData, adata_y: AnnData, t: float)
assert pushed_0.shape == pushed_1.shape
assert np.all(np.abs(pushed_0 - pushed_1).sum() > np.abs(pushed_2 - pushed_1).sum())
assert np.all(np.abs(pushed_0 - pushed_2).sum() > np.abs(pushed_1 - pushed_2).sum())

@pytest.mark.parametrize(
("attrs", "alpha", "raise_msg"),
[
({"xy"}, 0.5, "type-error"),
({"xy", "x", "y"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")),
({"xy", "x", "y"}, 1.1, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")),
({"xy", "x", "y"}, 0.5, None),
({"x", "y"}, 1.0, None),
({"x", "y"}, 0.5, re.escape("Expected `xy` to be `None` if `alpha` is not 1.0, found")),
],
)
def test_xy_alpha_raises(self, adata_x: AnnData, adata_y: AnnData, attrs, alpha, raise_msg):
prob = OTProblem(adata_x, adata_y)
data = {
"xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"} if "xy" in attrs else {},
"x": {"attr": "X"} if "x" in attrs else {},
"y": {"attr": "X"} if "y" in attrs else {},
}
prob = prob.prepare(
**data,
)
if raise_msg is not None:
if raise_msg == "type-error":
with pytest.raises(TypeError):
prob.solve(epsilon=5e-1, alpha=alpha)
else:
with pytest.raises(ValueError, match=raise_msg):
prob.solve(epsilon=5e-1, alpha=alpha)
else:
prob.solve(epsilon=5e-1, alpha=alpha)
assert isinstance(prob.solution, BaseDiscreteSolverOutput)
46 changes: 27 additions & 19 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,16 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str):
assert ref == reference
assert isinstance(ap[prob_key], ap._base_problem_type)

@pytest.mark.skip(reason="See https:/theislab/moscot/issues/678")
@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
[(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)],
("epsilon", "alpha", "rank", "initializer", "should_raise"),
[
(1, 0.9, -1, None, False),
(1, 0.5, 10, "random", False),
(1, 0.5, 10, "rank2", False),
(0.1, 0.1, -1, None, False),
(0.1, -0.1, -1, None, True), # Invalid alpha
(0.1, 1.1, -1, None, True), # Invalid alpha
],
)
def test_solve_balanced(
self,
Expand All @@ -87,6 +93,7 @@ def test_solve_balanced(
alpha: float,
rank: int,
initializer: Optional[Literal["random", "rank2"]],
should_raise: bool,
):
kwargs = {}
if rank > -1:
Expand All @@ -95,22 +102,23 @@ def test_solve_balanced(
# kwargs["kwargs_init"] = {"key": 0}
# kwargs["key"] = 0
return # TODO(@MUCDK) fix after refactoring
ap = (
AlignmentProblem(adata=adata_space_rotate)
.prepare(batch_key="batch")
.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
)
for prob_key in ap:
assert ap[prob_key].solution.rank == rank
if initializer != "random": # TODO: is this valid?
assert ap[prob_key].solution.converged

# TODO(michalk8): use np.testing
assert np.allclose(*(sol.cost for sol in ap.solutions.values()))
assert np.all([sol.converged for sol in ap.solutions.values()])
np.testing.assert_array_equal(
[np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True
)
ap = AlignmentProblem(adata=adata_space_rotate).prepare(batch_key="batch")
if should_raise:
with pytest.raises(ValueError, match=r"Expected `alpha`"):
ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
else:
ap = ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
for prob_key in ap:
assert ap[prob_key].solution.rank == rank
if initializer != "random": # TODO: is this valid?
assert ap[prob_key].solution.converged

# TODO(michalk8): use np.testing
assert np.allclose(*(sol.cost for sol in ap.solutions.values()))
assert np.all([sol.converged for sol in ap.solutions.values()])
np.testing.assert_array_equal(
[np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True
)

def test_solve_unbalanced(self, adata_space_rotate: AnnData):
tau_a, tau_b = [0.8, 1]
Expand Down
11 changes: 6 additions & 5 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from pathlib import Path
from typing import Any, List, Literal, Mapping, Optional, Union

Expand Down Expand Up @@ -301,20 +302,20 @@ def test_problem_type(
assert isinstance(sol._output, solution_kind)

@pytest.mark.parametrize(
("sc_attr", "alpha"),
("sc_attr", "alpha", "raise_msg"),
[
(None, 0.5),
({"attr": "X"}, 0),
(None, 0.5, re.escape("Expected `alpha` to be 0 for a `linear problem`.")),
({"attr": "X"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found `0`.")),
],
)
def test_problem_type_corner_cases(
self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float]
self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float], raise_msg: str
):
# initialize and prepare the MappingProblem
adataref, adatasp = _adata_spatial_split(adata_mapping)
mp = MappingProblem(adataref, adatasp)
mp = mp.prepare(batch_key="batch", sc_attr=sc_attr)

# we test two incompatible combinations of `sc_attr` and `alpha`
with pytest.raises(ValueError, match=r"^Expected `alpha`"):
with pytest.raises(ValueError, match=raise_msg):
mp.solve(alpha=alpha)
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData):
assert isinstance(subsol, BaseDiscreteSolverOutput)
assert key in expected_keys

@pytest.mark.skip(reason="unbalanced does not work yet")
@pytest.mark.skip(reason="unbalanced does not work yet: https:/ott-jax/ott/issues/519")
def test_solve_unbalanced(self, adata_spatio_temporal: AnnData):
taus = [9e-1, 1e-2]
problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal)
Expand Down
Loading