Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnnData Manager refactor #1238

Merged
merged 17 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/api/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ Setting up an AnnData object is a prerequisite for running any ``scvi-tools`` mo
data.setup_anndata
data.transfer_anndata_setup
data.register_tensor_from_anndata
data.view_anndata_setup


Basic preprocessing
Expand Down
2 changes: 1 addition & 1 deletion scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class _CONSTANTS_NT(NamedTuple):
X_KEY: str = "X"
BATCH_KEY: str = "batch_indices"
BATCH_KEY: str = "batch"
LABELS_KEY: str = "labels"
PROTEIN_EXP_KEY: str = "protein_expression"
CAT_COVS_KEY: str = "cat_covs"
Expand Down
12 changes: 0 additions & 12 deletions scvi/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from anndata import read_csv, read_h5ad, read_loom, read_text

from ._anndata import (
get_from_registry,
register_tensor_from_anndata,
setup_anndata,
transfer_anndata_setup,
view_anndata_setup,
)
from ._datasets import (
annotation_simulation,
brainlarge_dataset,
Expand Down Expand Up @@ -36,9 +29,6 @@
from ._read import read_10x_atac, read_10x_multiome

__all__ = [
"setup_anndata",
"get_from_registry",
"view_anndata_setup",
"poisson_gene_selection",
"organize_cite_seq_10x",
"pbmcs_10x_cite_seq",
Expand All @@ -58,8 +48,6 @@
"prefrontalcortex_starmap",
"frontalcortex_dropseq",
"annotation_simulation",
"transfer_anndata_setup",
"register_tensor_from_anndata",
"read_h5ad",
"read_csv",
"read_loom",
Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_brain_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import scipy.sparse as sp_sparse

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_cite_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pandas as pd

from scvi import settings
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata


def _load_pbmcs_10x_cite_seq(
Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_cortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import pandas as pd

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import anndata
import numpy as np

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_heartcellatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import anndata

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata


def _load_heart_cell_atlas_subsampled(
Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_loom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd
from anndata import AnnData

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_pbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import pandas as pd

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._dataset_10x import _load_dataset_10x
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata


def _load_purified_pbmc_dataset(
Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_seqfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import pandas as pd

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_smfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import pandas as pd

from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/_built_in_data/_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from anndata import AnnData

from scvi.data._anndata import _setup_anndata
from scvi.data.anndata._utils import _setup_anndata

logger = logging.getLogger(__name__)

Expand Down
13 changes: 13 additions & 0 deletions scvi/data/anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ._utils import (
get_from_registry,
register_tensor_from_anndata,
setup_anndata,
transfer_anndata_setup,
)

# Names exported by ``from scvi.data.anndata import *``; each one is imported
# from ``._utils`` at the top of this module.
__all__ = [
    "setup_anndata",
    "get_from_registry",
    "transfer_anndata_setup",
    "register_tensor_from_anndata",
]
48 changes: 48 additions & 0 deletions scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from anndata import AnnData

from . import _constants
from ._fields import CategoricalObsField, LayerField
from ._manager import AnnDataManager


def manager_from_setup_dict(
    adata: AnnData, setup_dict: dict, **transfer_kwargs
) -> AnnDataManager:
    """
    Rebuilds an AnnDataManager from nothing but a scvi-tools setup dictionary.

    This exists purely as a backwards-compatibility shim for loading the setup
    dictionaries saved alongside models. One AnnDataField is inferred per entry of the
    setup dictionary's data registry and added to a fresh manager; the new AnnData
    object is then registered through `AnnDataManager.transfer_setup(...)`.

    Parameters
    ----------
    adata
        AnnData object to be registered.
    setup_dict
        Setup dictionary created after registering an AnnData using an AnnDataManager object.
    **transfer_kwargs
        Keyword arguments forwarded to `AnnDataManager.transfer_setup(...)` to modify
        transfer behavior.

    Returns
    -------
    The result of calling `transfer_setup` on the reconstructed manager.

    Raises
    ------
    NotImplementedError
        If a data registry entry points at an AnnData attribute other than
        ``X``, ``layers`` or ``obs``.
    """
    manager = AnnDataManager()
    registry = setup_dict[_constants._DATA_REGISTRY_KEY]
    cat_mappings = setup_dict[_constants._CATEGORICAL_MAPPINGS_KEY]

    for registry_key, entry in registry.items():
        attr_name = entry[_constants._DR_ATTR_NAME]
        attr_key = entry[_constants._DR_ATTR_KEY]

        # Data stored directly in X is modelled as a layer field with no layer key.
        if attr_name == _constants._ADATA_ATTRS.X:
            manager.add_field(LayerField(registry_key, None))
            continue

        if attr_name == _constants._ADATA_ATTRS.LAYERS:
            manager.add_field(LayerField(registry_key, attr_key))
            continue

        if attr_name == _constants._ADATA_ATTRS.OBS:
            # The field is rebuilt from the original obs column name recorded
            # in the categorical mappings, not from the registry's attr_key.
            obs_key = cat_mappings[attr_key][_constants._CM_ORIGINAL_KEY]
            manager.add_field(CategoricalObsField(registry_key, obs_key))
            continue

        raise NotImplementedError(
            f"Backwards compatibility for attribute {attr_name} is not implemented yet."
        )

    return manager.transfer_setup(
        adata, source_setup_dict=setup_dict, **transfer_kwargs
    )
46 changes: 46 additions & 0 deletions scvi/data/anndata/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import NamedTuple

################################
# scVI Manager Store Constants #
################################

# NOTE(review): presumably the key under which a per-AnnData UUID is stored so
# a manager can be looked up again — not shown in this view, confirm.
_SCVI_UUID_KEY = "_scvi_uuid"

#############################
# scVI Setup Dict Constants #
#############################

# NOTE(review): presumably the key under which the setup dictionary itself is
# stored on the AnnData object — confirm against the manager implementation.
_SETUP_DICT_KEY = "_scvi"
# Top-level keys of the setup dictionary. The data registry and the categorical
# mappings are both read by the backwards-compatibility loader.
_DATA_REGISTRY_KEY = "data_registry"
_CATEGORICAL_MAPPINGS_KEY = "categorical_mappings"
_SUMMARY_STATS_KEY = "summary_stats"

################################
# scVI Data Registry Constants #
################################

# Keys of each data registry entry, i.e.
# {registry_key: {"attr_name": <AnnData attribute>, "attr_key": <key within it>}}.
_DR_ATTR_NAME = "attr_name"
_DR_ATTR_KEY = "attr_key"

#######################################
# scVI Categorical Mappings Constants #
#######################################

# Keys of each categorical mapping entry. "original_key" holds the obs key the
# field was originally created from.
_CM_ORIGINAL_KEY = "original_key"
# NOTE(review): presumably holds the category -> code mapping; confirm.
_CM_MAPPING_KEY = "mapping"

############################
# AnnData Object Constants #
############################


class _ADATA_ATTRS_NT(NamedTuple):
    """Names of the AnnData attributes that a data registry entry can point at."""

    # Every field defaults to the attribute's own name as a string, so the
    # no-argument instance below can be used as a set of named constants.
    X: str = "X"
    LAYERS: str = "layers"
    OBS: str = "obs"
    OBSM: str = "obsm"
    VAR: str = "var"
    VARM: str = "varm"


# Shared instance carrying the default values; these are compared against the
# "attr_name" value of data registry entries.
_ADATA_ATTRS = _ADATA_ATTRS_NT()
5 changes: 5 additions & 0 deletions scvi/data/anndata/_fields/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._base_field import BaseAnnDataField
from ._layer_field import LayerField
from ._obs_field import CategoricalObsField

# Names exported by ``from scvi.data.anndata._fields import *``.
__all__ = ["BaseAnnDataField", "LayerField", "CategoricalObsField"]
88 changes: 88 additions & 0 deletions scvi/data/anndata/_fields/_base_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from abc import ABC, abstractmethod

import numpy as np
from anndata import AnnData

from scvi.data.anndata import _constants
from scvi.data.anndata._utils import _get_field


class BaseAnnDataField(ABC):
    """
    Abstract description of one data field of an AnnData object.

    A concrete AnnDataField tells scvi-tools how a data field consumed by a model
    maps onto an attribute of an AnnData object. Subclasses must supply the three
    identifying properties (``registry_key``, ``attr_name``, ``attr_key``) as well as
    the validate / register / transfer hooks.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    @abstractmethod
    def registry_key(self):
        """Key under which models reference this field via a data loader."""

    @property
    @abstractmethod
    def attr_name(self):
        """Name of the AnnData attribute holding the data (e.g. obs)."""

    @property
    @abstractmethod
    def attr_key(self):
        """Key of the data field inside the relevant AnnData attribute."""

    @abstractmethod
    def validate_field(self, adata: AnnData) -> None:
        """Checks that an AnnData object is compatible with this field definition."""

    @abstractmethod
    def register_field(self, adata: AnnData) -> None:
        """
        Sets up the AnnData object and creates a mapping for scvi-tools models to use.

        Although abstract, this base implementation runs ``validate_field`` so that
        overriding subclasses can invoke it through ``super()``.
        """
        self.validate_field(adata)

    @abstractmethod
    def transfer_field(self, setup_dict: dict, adata_target: AnnData, **kwargs) -> None:
        """
        Applies the setup recorded in an existing setup dictionary to a target AnnData.

        Needed when a pretrained model is run on a new AnnData object: the mapping
        established on the original data has to be applied to the new object too.

        Parameters
        ----------
        setup_dict
            Setup dictionary created after registering an AnnData using an AnnDataManager object.
        adata_target
            AnnData object that is being registered.
        **kwargs
            Keyword arguments to modify transfer behavior.
        """

    def data_registry_mapping(self) -> dict:
        """
        Describes, as a nested dictionary, where this field lives in the AnnData object.

        The result has the form {registry_key: {"attr_name": attr_name, "attr_key": attr_key}}
        and is merged with the mappings of the other fields to form the data registry.
        """
        key = self.registry_key
        location = {
            _constants._DR_ATTR_NAME: self.attr_name,
            _constants._DR_ATTR_KEY: self.attr_key,
        }
        return {key: location}

    def get_field(self, adata: AnnData) -> np.ndarray:
        """Fetches this field's data from the given AnnData object as a NumPy array."""
        return _get_field(adata, self.attr_name, self.attr_key)

    def compute_summary_stats(self, adata: AnnData) -> dict:
        """Returns summary statistics relevant to the field; empty by default."""
        return {}
Loading