Skip to content

Commit

Permalink
Merge pull request #736 from quadbio/fix/spatial_mixins
Browse files Browse the repository at this point in the history
Fix/spatial mixins
  • Loading branch information
Marius1311 authored Jul 18, 2024
2 parents 03ffe72 + da54cea commit 76b259e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
44 changes: 38 additions & 6 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def correlate( # type: ignore[misc]
corr_method: Literal["pearson", "spearman"] = "pearson",
device: Optional[Device_t] = None,
groupby: Optional[str] = None,
batch_size: Optional[int] = None,
) -> Union[Mapping[Tuple[K, K], Mapping[Any, pd.Series]], Mapping[Tuple[K, K], pd.Series]]:
"""Correlate true and predicted gene expression.
Expand All @@ -414,10 +415,13 @@ def correlate( # type: ignore[misc]
- ``'pearson'`` - `Pearson correlation <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
- ``'spearman'`` - `Spearman's rank correlation
<https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.
groupby
Optional `obs` field in `AnnData` to compute correlations over categorical groups.
device
Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseSolverOutput.to`.
groupby
Optional key in :attr:`~anndata.AnnData.obs`, containing categorical annotations for grouping.
batch_size:
Number of features to process at once. If :obj:`None`, process all features at once.
Larger values will require more memory.
Returns
-------
Expand Down Expand Up @@ -462,7 +466,14 @@ def correlate( # type: ignore[misc]
gexp_sp = gexp_sp.toarray()

# predict spatial feature expression
gexp_pred_sp = val.to(device=device).pull(gexp_sc, scale_by_marginals=True)
n_splits = np.max([np.floor(gexp_sc.shape[1] / batch_size), 1]) if batch_size else 1
logger.debug(f"Processing {gexp_sc.shape[1]} features in {n_splits} batches.")
gexp_pred_sp = np.hstack(
[
val.to(device=device).pull(x, scale_by_marginals=True)
for x in np.array_split(gexp_sc, n_splits, axis=1)
],
)

# loop over groups and compute correlations
for group, group_mask in group_masks.items():
Expand All @@ -484,6 +495,7 @@ def impute( # type: ignore[misc]
self: SpatialMappingMixinProtocol[K, B],
var_names: Optional[Sequence[str]] = None,
device: Optional[Device_t] = None,
batch_size: Optional[int] = None,
) -> AnnData:
"""Impute the expression of specific genes.
Expand All @@ -493,6 +505,9 @@ def impute( # type: ignore[misc]
Genes in :attr:`~anndata.AnnData.var_names` to impute. If :obj:`None`, use all genes in :attr:`adata_sc`.
device
Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseSolverOutput.to`.
batch_size:
Number of features to process at once. If :obj:`None`, process all features at once.
Larger values will require more memory.
Returns
-------
Expand All @@ -505,12 +520,29 @@ def impute( # type: ignore[misc]
if sp.issparse(gexp_sc):
gexp_sc = gexp_sc.toarray()

predictions = [val.to(device=device).pull(gexp_sc, scale_by_marginals=True) for val in self.solutions.values()]
# predict spatial feature expression
n_splits = np.max([np.floor(gexp_sc.shape[1] / batch_size), 1]) if batch_size else 1
logger.debug(f"Processing {gexp_sc.shape[1]} features in {n_splits} batches.")

predictions = np.nan_to_num(
np.vstack(
[
np.hstack(
[
val.to(device=device).pull(x, scale_by_marginals=True)
for x in np.array_split(gexp_sc, n_splits, axis=1)
]
)
for val in self.solutions.values()
]
),
nan=0.0,
copy=False,
)

adata_pred = AnnData(np.nan_to_num(np.vstack(predictions), nan=0.0, copy=False))
adata_pred = AnnData(X=predictions, obsm=self.adata_sp.obsm.copy())
adata_pred.obs_names = self.adata_sp.obs_names
adata_pred.var_names = var_names
adata_pred.obsm = self.adata_sp.obsm.copy()

return adata_pred

Expand Down
6 changes: 4 additions & 2 deletions tests/problems/space/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,20 @@ class TestSpatialMappingAnalysisMixin:
@pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}])
@pytest.mark.parametrize("var_names", ["0", [str(i) for i in range(20)]])
@pytest.mark.parametrize("groupby", [None, "covariate"])
@pytest.mark.parametrize("batch_size", [None, 7, 10, 100])
def test_analysis(
self,
adata_mapping: AnnData,
sc_attr: Dict[str, str],
var_names: Optional[List[Optional[str]]],
groupby: Optional[str],
batch_size: Optional[int],
):
adataref, adatasp = _adata_spatial_split(adata_mapping)
mp = MappingProblem(adataref, adatasp).prepare(batch_key="batch", sc_attr=sc_attr).solve()

corr = mp.correlate(var_names, groupby=groupby)
imp = mp.impute()
corr = mp.correlate(var_names, groupby=groupby, batch_size=batch_size)
imp = mp.impute(batch_size=batch_size)

if groupby:
for key in adata_mapping.obs[groupby].cat.categories:
Expand Down

0 comments on commit 76b259e

Please sign in to comment.