Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt all internal models to new setup #1301

Merged
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class _CONSTANTS_NT(NamedTuple):
PROTEIN_EXP_KEY: str = "proteins"
CAT_COVS_KEY: str = "extra_categorical_covs"
CONT_COVS_KEY: str = "extra_continuous_covs"
INDICES_KEY: str = "ind_x"


_CONSTANTS = _CONSTANTS_NT()
92 changes: 71 additions & 21 deletions scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from copy import deepcopy

import numpy as np
from anndata import AnnData
from sklearn.utils import deprecated

from scvi import _CONSTANTS

from . import _constants
from ._manager import AnnDataManager
from .fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ProteinObsmField,
)

# Maps the registry keys used in legacy setup dictionaries (left) to the
# registry keys defined on ``_CONSTANTS`` (right). The conversion loops in this
# module look legacy keys up here and skip any key that is not listed.
LEGACY_REGISTRY_KEY_MAP = {
"X": _CONSTANTS.X_KEY,
"batch_indices": _CONSTANTS.BATCH_KEY,
"labels": _CONSTANTS.LABELS_KEY,
"cat_covs": _CONSTANTS.CAT_COVS_KEY,
"cont_covs": _CONSTANTS.CONT_COVS_KEY,
"protein_expression": _CONSTANTS.PROTEIN_EXP_KEY,
"ind_x": _CONSTANTS.INDICES_KEY,
}


def registry_from_setup_dict(setup_dict: dict) -> dict:
"""
Expand All @@ -37,15 +52,19 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:
registry_key,
adata_mapping,
) in data_registry.items(): # Note: this does not work for empty fields.
if registry_key not in LEGACY_REGISTRY_KEY_MAP:
continue
new_registry_key = LEGACY_REGISTRY_KEY_MAP[registry_key]

attr_name = adata_mapping[_constants._DR_ATTR_NAME]
attr_key = adata_mapping[_constants._DR_ATTR_KEY]

field_registries[registry_key] = {
field_registries[new_registry_key] = {
_constants._DATA_REGISTRY_KEY: adata_mapping,
_constants._STATE_REGISTRY_KEY: dict(),
_constants._SUMMARY_STATS_KEY: dict(),
}
field_registry = field_registries[registry_key]
field_registry = field_registries[new_registry_key]
field_state_registry = field_registry[_constants._STATE_REGISTRY_KEY]
field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY]

Expand All @@ -58,21 +77,32 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:
field_state_registry[
CategoricalObsField.CATEGORICAL_MAPPING_KEY
] = categorical_mapping["mapping"]
if attr_key == "_scvi_batch":
field_summary_stats[f"n_{registry_key}"] = summary_stats["n_batch"]
elif attr_key == "_scvi_labels":
field_summary_stats[f"n_{registry_key}"] = summary_stats["n_labels"]
if new_registry_key == _CONSTANTS.BATCH_KEY:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_batch"]
elif new_registry_key == _CONSTANTS.LABELS_KEY:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_labels"]
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if attr_key == "_scvi_extra_continuous":
if new_registry_key == _CONSTANTS.CONT_COVS_KEY:
columns = setup_dict["extra_continuous_keys"].copy()
field_state_registry[NumericalJointObsField.COLUMNS_KEY] = columns
field_summary_stats[f"n_{registry_key}"] = columns.shape[0]
elif attr_key == "_scvi_extra_categoricals":
field_summary_stats[f"n_{new_registry_key}"] = columns.shape[0]
elif new_registry_key == _CONSTANTS.CAT_COVS_KEY:
extra_categoricals_mapping = deepcopy(setup_dict["extra_categoricals"])
field_state_registry.update(deepcopy(setup_dict["extra_categoricals"]))
field_summary_stats[f"n_{registry_key}"] = len(
field_summary_stats[f"n_{new_registry_key}"] = len(
extra_categoricals_mapping["keys"]
)
elif new_registry_key == _CONSTANTS.PROTEIN_EXP_KEY:
field_state_registry[ProteinObsmField.COLUMN_NAMES_KEY] = setup_dict[
"protein_names"
].copy()
if "totalvi_batch_mask" in setup_dict:
field_state_registry[
ProteinObsmField.PROTEIN_BATCH_MASK
] = setup_dict["totalvi_batch_mask"].copy()
field_summary_stats[f"n_{new_registry_key}"] = len(
setup_dict["protein_names"]
)
return registry


Expand Down Expand Up @@ -106,37 +136,57 @@ def manager_from_setup_dict(
data_registry = setup_dict[_constants._DATA_REGISTRY_KEY]
categorical_mappings = setup_dict["categorical_mappings"]
for registry_key, adata_mapping in data_registry.items():
if registry_key not in LEGACY_REGISTRY_KEY_MAP:
continue
new_registry_key = LEGACY_REGISTRY_KEY_MAP[registry_key]

field = None
attr_name = adata_mapping[_constants._DR_ATTR_NAME]
attr_key = adata_mapping[_constants._DR_ATTR_KEY]
if attr_name == _constants._ADATA_ATTRS.X:
field = LayerField(registry_key, None)
field = LayerField(_CONSTANTS.X_KEY, None)
setup_kwargs["layer"] = None
elif attr_name == _constants._ADATA_ATTRS.LAYERS:
field = LayerField(registry_key, attr_key)
field = LayerField(_CONSTANTS.X_KEY, attr_key)
setup_kwargs["layer"] = attr_key
elif attr_name == _constants._ADATA_ATTRS.OBS:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(registry_key, original_key)
setup_kwargs[f"{registry_key}_key"] = original_key
if new_registry_key in {_CONSTANTS.BATCH_KEY, _CONSTANTS.LABELS_KEY}:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(new_registry_key, original_key)
setup_kwargs[f"{new_registry_key}_key"] = original_key
elif new_registry_key == _CONSTANTS.INDICES_KEY:
adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64")
field = NumericalObsField(new_registry_key, attr_key)
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if attr_key == "_scvi_extra_continuous":
if new_registry_key == _CONSTANTS.CONT_COVS_KEY:
obs_keys = setup_dict["extra_continuous_keys"]
field = NumericalJointObsField(registry_key, obs_keys)
field = NumericalJointObsField(_CONSTANTS.CONT_COVS_KEY, obs_keys)
setup_kwargs["continuous_covariate_keys"] = obs_keys
elif attr_key == "_scvi_extra_categoricals":
elif new_registry_key == _CONSTANTS.CAT_COVS_KEY:
obs_keys = setup_dict["extra_categoricals"]["keys"]
field = CategoricalJointObsField(registry_key, obs_keys)
field = CategoricalJointObsField(_CONSTANTS.CAT_COVS_KEY, obs_keys)
setup_kwargs["categorical_covariate_keys"] = obs_keys
elif new_registry_key == _CONSTANTS.PROTEIN_EXP_KEY:
protein_names = setup_dict["protein_names"]
adata.uns["_protein_names"] = protein_names
field = ProteinObsmField(
_CONSTANTS.PROTEIN_EXP_KEY,
attr_key,
"_scvi_batch",
colnames_uns_key="_protein_names",
)
setup_kwargs["protein_expression_obsm_key"] = attr_key
setup_kwargs["protein_names_uns_key"] = "_protein_names"
else:
raise NotImplementedError(
f"Unrecognized .obsm attribute {attr_key}. Backwards compatibility unavailable."
f"Unrecognized .obsm attribute {attr_key} registered as {new_registry_key}. Backwards compatibility unavailable."
)
else:
raise NotImplementedError(
f"Backwards compatibility for attribute {attr_name} is not implemented yet."
f"Backwards compatibility for attribute {attr_name} is not implemented."
)
fields.append(field)

setup_method_args = {
_constants._MODEL_NAME_KEY: cls.__name__,
_constants._SETUP_KWARGS_KEY: setup_kwargs,
Expand Down
44 changes: 0 additions & 44 deletions scvi/data/anndata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,50 +44,6 @@ def get_anndata_attribute(
return field


def get_from_registry(
    adata: anndata.AnnData, key: str
) -> Union[np.ndarray, pd.DataFrame]:
    """
    Fetch the AnnData object that a data-registry entry points to.

    The entry for ``key`` is read from the data registry stored in the setup
    dictionary under ``adata.uns`` (e.g. ``adata.uns['_scvi']['data_registry']``).
    Its attribute name and attribute key are then resolved with
    :func:`get_anndata_attribute`.

    Parameters
    ----------
    adata
        anndata object already setup with setup_anndata
    key
        key of the entry to look up in ``adata.uns['_scvi']['data_registry']``

    Returns
    -------
    The requested data

    Examples
    --------
    >>> import scvi
    >>> adata = scvi.data.cortex()
    >>> adata.uns['_scvi']['data_registry']
    {'X': ['_X', None],
    'batch': ['obs', 'batch'],
    'labels': ['obs', 'labels']}
    >>> batch = get_from_registry(adata, "batch")
    >>> batch
    array([[0],
    [0],
    [0],
    ...,
    [0],
    [0],
    [0]])
    """
    setup_dict = adata.uns[_constants._SETUP_DICT_KEY]
    registry_entry = setup_dict[_constants._DATA_REGISTRY_KEY][key]

    # The entry records which AnnData attribute holds the data and under
    # which key inside that attribute.
    return get_anndata_attribute(
        adata,
        registry_entry[_constants._DR_ATTR_NAME],
        registry_entry[_constants._DR_ATTR_KEY],
    )


@deprecated(
extra="Please use the model-specific setup_anndata methods instead. The global method will be removed in version 0.15.0."
)
Expand Down
2 changes: 2 additions & 0 deletions scvi/data/anndata/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ._layer_field import LayerField
from ._obs_field import CategoricalObsField, NumericalObsField
from ._obsm_field import CategoricalJointObsField, NumericalJointObsField, ObsmField
from ._scanvi import LabelsWithUnlabeledObsField
from ._totalvi import ProteinObsmField

__all__ = [
Expand All @@ -13,4 +14,5 @@
"CategoricalJointObsField",
"ObsmField",
"ProteinObsmField",
"LabelsWithUnlabeledObsField",
]
11 changes: 9 additions & 2 deletions scvi/data/anndata/fields/_obs_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CategoricalObsField(BaseObsField):
"""

CATEGORICAL_MAPPING_KEY = "categorical_mapping"
ORIGINAL_ATTR_KEY = "original_key"

def __init__(self, registry_key: str, obs_key: Optional[str]) -> None:
self.is_default = obs_key is None
Expand Down Expand Up @@ -114,7 +115,10 @@ def register_field(self, adata: AnnData) -> dict:
categorical_mapping = _make_obs_column_categorical(
adata, self._original_attr_key, self.attr_key, return_mapping=True
)
return {self.CATEGORICAL_MAPPING_KEY: categorical_mapping}
return {
self.CATEGORICAL_MAPPING_KEY: categorical_mapping,
self.ORIGINAL_ATTR_KEY: self._original_attr_key,
}

def transfer_field(
self,
Expand Down Expand Up @@ -150,7 +154,10 @@ def transfer_field(
categorical_dtype=cat_dtype,
return_mapping=True,
)
return {self.CATEGORICAL_MAPPING_KEY: new_mapping}
return {
self.CATEGORICAL_MAPPING_KEY: new_mapping,
self.ORIGINAL_ATTR_KEY: self._original_attr_key,
}

def get_summary_stats(self, state_registry: dict) -> dict:
categorical_mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
Expand Down
3 changes: 2 additions & 1 deletion scvi/data/anndata/fields/_obsm_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def is_empty(self) -> bool:

def validate_field(self, adata: AnnData) -> None:
super().validate_field(adata)
assert self.attr_key in adata.obsm, f"{self.attr_key} not found in adata.obsm."
if self.attr_key not in adata.obsm:
raise KeyError(f"{self.attr_key} not found in adata.obsm.")

obsm_data = self.get_field(adata)

Expand Down
89 changes: 89 additions & 0 deletions scvi/data/anndata/fields/_scanvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional, Union

import numpy as np
from anndata import AnnData
from pandas.api.types import CategoricalDtype

from scvi.data.anndata._utils import _make_obs_column_categorical

from ._obs_field import CategoricalObsField


class LabelsWithUnlabeledObsField(CategoricalObsField):
"""
An AnnDataField for labels which include explicitly unlabeled cells.

Remaps unlabeled category to the final index if present in labels.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just better explain that unlabelled cells are labelled with a special category name that is user defined.

Copy link
Contributor Author

@justjhong justjhong Jan 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Labeled and labelled are both correct spellings. Labeled is the preferred spelling in American English. Labelled is the preferred spelling in British English.

Already integrating yourself into British culture I see


Parameters
----------
registry_key
Key to register field under in data registry.
obs_key
Key to access the field in the AnnData obs mapping. If None, defaults to `registry_key`.
unlabeled_category
Value assigned to unlabeled cells.
"""

UNLABELED_CATEGORY = "unlabeled_category"
WAS_REMAPPED = "was_remapped"

def __init__(
    self,
    registry_key: str,
    obs_key: Optional[str],
    unlabeled_category: Union[str, int, float],
) -> None:
    """
    Initialize the categorical obs field and remember the unlabeled marker.

    ``registry_key`` and ``obs_key`` are forwarded unchanged to the parent
    initializer; ``unlabeled_category`` is the value that marks a cell as
    unlabeled.
    """
    super().__init__(registry_key=registry_key, obs_key=obs_key)
    self._unlabeled_category = unlabeled_category

def _remap_unlabeled_to_final_category(
    self, adata: AnnData, mapping: np.ndarray
) -> dict:
    """
    Move the unlabeled category to the last slot of the categorical mapping.

    If the unlabeled category is found in the original labels column, it is
    swapped (in place) with the last entry of ``mapping`` and the categorical
    obs column is rebuilt with that category order. Otherwise ``mapping`` is
    returned untouched.

    Parameters
    ----------
    adata
        AnnData object whose labels column is inspected and, if needed, re-encoded.
    mapping
        Categorical mapping produced by the parent field; modified in place
        when a remap takes place.

    Returns
    -------
    State registry dictionary holding the (possibly reordered) mapping, the
    original obs key, the unlabeled category and whether a remap happened.
    """
    original_labels = self._get_original_column(adata)
    was_remapped = self._unlabeled_category in original_labels

    if was_remapped:
        # Position of the unlabeled category within the current mapping.
        position = np.where(mapping == self._unlabeled_category)[0][0]
        # Swap so that the unlabeled category sits in the final position.
        mapping[position], mapping[-1] = mapping[-1], mapping[position]
        ordered_dtype = CategoricalDtype(categories=mapping, ordered=True)
        # Re-encode the labels column so its codes follow the new order.
        mapping = _make_obs_column_categorical(
            adata,
            self._original_attr_key,
            self.attr_key,
            categorical_dtype=ordered_dtype,
            return_mapping=True,
        )

    state = {}
    state[self.CATEGORICAL_MAPPING_KEY] = mapping
    state[self.ORIGINAL_ATTR_KEY] = self._original_attr_key
    state[self.UNLABELED_CATEGORY] = self._unlabeled_category
    state[self.WAS_REMAPPED] = was_remapped
    return state

def register_field(self, adata: AnnData) -> dict:
    """
    Register the labels field, then push the unlabeled category to the end.

    Runs the parent registration to obtain the categorical mapping and hands
    that mapping to ``_remap_unlabeled_to_final_category``, whose state
    registry is returned.
    """
    if self.is_default:
        # No obs key was supplied, so set up the default column first.
        self._setup_default_attr(adata)

    parent_state = super().register_field(adata)
    return self._remap_unlabeled_to_final_category(
        adata, parent_state[self.CATEGORICAL_MAPPING_KEY]
    )

def transfer_field(
    self,
    state_registry: dict,
    adata_target: AnnData,
    extend_categories: bool = False,
    **kwargs,
) -> dict:
    """
    Transfer the labels field to ``adata_target`` and re-apply the remap.

    Delegates the transfer to the parent class, then passes the resulting
    categorical mapping through ``_remap_unlabeled_to_final_category`` for
    the target AnnData and returns that state registry.
    """
    parent_state = super().transfer_field(
        state_registry,
        adata_target,
        extend_categories=extend_categories,
        **kwargs,
    )
    return self._remap_unlabeled_to_final_category(
        adata_target, parent_state[self.CATEGORICAL_MAPPING_KEY]
    )
9 changes: 5 additions & 4 deletions scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from scvi import _CONSTANTS, settings
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata.fields import LabelsWithUnlabeledObsField
from scvi.dataloaders._ann_dataloader import AnnDataLoader, BatchSampler
from scvi.dataloaders._semi_dataloader import SemiSupervisedDataLoader
from scvi.model._utils import parse_use_gpu_arg
Expand Down Expand Up @@ -212,10 +213,10 @@ def __init__(
self.data_loader_kwargs = kwargs
self.n_samples_per_label = n_samples_per_label

setup_dict = adata_manager.get_setup_dict()
key = setup_dict["data_registry"][_CONSTANTS.LABELS_KEY]["attr_key"]
original_key = setup_dict["categorical_mappings"][key]["original_key"]
labels = np.asarray(adata_manager.obs[original_key]).ravel()
original_key = adata_manager.get_state_registry(_CONSTANTS.LABELS_KEY)[
LabelsWithUnlabeledObsField.ORIGINAL_ATTR_KEY
]
labels = np.asarray(adata_manager.adata.obs[original_key]).ravel()
self._unlabeled_indices = np.argwhere(labels == unlabeled_category).ravel()
self._labeled_indices = np.argwhere(labels != unlabeled_category).ravel()

Expand Down
Loading