Skip to content

Commit

Permalink
Adapt all external models to new setup (#1302)
Browse files Browse the repository at this point in the history
* adapt cellassign

* adapt gimvi

* adapt solo model

* adapt stereoscope
  • Loading branch information
justjhong authored Jan 14, 2022
1 parent 14ec85e commit 1878ac5
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 131 deletions.
1 change: 1 addition & 0 deletions scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class _CONSTANTS_NT(NamedTuple):
CAT_COVS_KEY: str = "extra_categorical_covs"
CONT_COVS_KEY: str = "extra_continuous_covs"
INDICES_KEY: str = "ind_x"
SIZE_FACTOR_KEY: str = "size_factor"


_CONSTANTS = _CONSTANTS_NT()
65 changes: 34 additions & 31 deletions scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
from pytorch_lightning.callbacks import Callback

from scvi import _CONSTANTS
from scvi.data.anndata import register_tensor_from_anndata
from scvi.data.anndata._utils import _setup_anndata
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
)
from scvi.dataloaders import DataSplitter
from scvi.external.cellassign._module import CellAssignModule
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
Expand Down Expand Up @@ -72,8 +78,10 @@ def __init__(
self.cell_type_markers = cell_type_markers
rho = torch.Tensor(cell_type_markers.to_numpy())
n_cats_per_cov = (
self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
if "extra_categoricals" in self.scvi_setup_dict_
self.adata_manager.get_state_registry(_CONSTANTS.CAT_COVS_KEY)[
CategoricalJointObsField.N_CATS_PER_KEY
]
if _CONSTANTS.CAT_COVS_KEY in self.adata_manager.data_registry
else None
)

Expand All @@ -93,7 +101,7 @@ def __init__(
b_g_0=col_means_normalized,
n_batch=self.summary_stats["n_batch"],
n_cats_per_cov=n_cats_per_cov,
n_continuous_cov=self.summary_stats["n_continuous_covs"],
n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
**model_kwargs,
)
self._model_summary_string = (
Expand Down Expand Up @@ -197,7 +205,7 @@ def train(
plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()

data_splitter = DataSplitter(
self.adata,
self.adata_manager,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
Expand All @@ -214,50 +222,45 @@ def train(
)
return runner()

@staticmethod
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
size_factor_key: str,
batch_key: Optional[str] = None,
layer: Optional[str] = None,
categorical_covariate_keys: Optional[List[str]] = None,
continuous_covariate_keys: Optional[List[str]] = None,
copy: bool = False,
) -> Optional[AnnData]:
layer: Optional[str] = None,
**kwargs,
):
"""
%(summary)s.
Parameters
----------
%(param_adata)s
size_factor_key
key in `adata.obs` with continuous valued size factors.
%(param_batch_key)s
%(param_layer)s
%(param_cat_cov_keys)s
%(param_cat_cov_keys)s
%(param_copy)s
Returns
-------
%(returns)s
%(param_cont_cov_keys)s
"""
setup_data = _setup_anndata(
adata,
batch_key=batch_key,
layer=layer,
categorical_covariate_keys=categorical_covariate_keys,
continuous_covariate_keys=continuous_covariate_keys,
copy=copy,
)
register_tensor_from_anndata(
adata if setup_data is None else setup_data,
"_size_factor",
"obs",
size_factor_key,
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(_CONSTANTS.X_KEY, layer, is_count_data=True),
NumericalObsField(_CONSTANTS.SIZE_FACTOR_KEY, size_factor_key),
CategoricalObsField(_CONSTANTS.BATCH_KEY, batch_key),
CategoricalJointObsField(
_CONSTANTS.CAT_COVS_KEY, categorical_covariate_keys
),
NumericalJointObsField(_CONSTANTS.CONT_COVS_KEY, continuous_covariate_keys),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
return setup_data
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)


class ClampCallback(Callback):
Expand Down
2 changes: 1 addition & 1 deletion scvi/external/cellassign/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_inference_input(self, tensors):

def _get_generative_input(self, tensors, inference_outputs):
x = tensors[_CONSTANTS.X_KEY]
size_factor = tensors["_size_factor"]
size_factor = tensors[_CONSTANTS.SIZE_FACTOR_KEY]

to_cat = []
if self.n_batch > 0:
Expand Down
96 changes: 66 additions & 30 deletions scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@
from torch.utils.data import DataLoader

from scvi import _CONSTANTS
from scvi.data.anndata import transfer_anndata_setup
from scvi.data.anndata._utils import _setup_anndata
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata._compat import manager_from_setup_dict
from scvi.data.anndata._constants import (
_DR_ATTR_KEY,
_MODEL_NAME_KEY,
_SETUP_KWARGS_KEY,
)
from scvi.data.anndata.fields import CategoricalObsField, LayerField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import _init_library_size, parse_use_gpu_arg
from scvi.model.base import BaseModelClass, VAEMixin
Expand Down Expand Up @@ -82,10 +88,11 @@ def __init__(
):
super(GIMVI, self).__init__()
self.adatas = [adata_seq, adata_spatial]
self.scvi_setup_dicts_ = {
"seq": adata_seq.uns["_scvi"],
"spatial": adata_spatial.uns["_scvi"],
self.adata_managers = {
"seq": self.get_anndata_manager(adata_seq, required=True),
"spatial": self.get_anndata_manager(adata_spatial, required=True),
}
self.registries_ = [adm.registry for adm in self.adata_managers.values()]

seq_var_names = adata_seq.var_names
spatial_var_names = adata_spatial.var_names
Expand All @@ -98,23 +105,27 @@ def __init__(
]
spatial_gene_loc = np.concatenate(spatial_gene_loc)
gene_mappings = [slice(None), spatial_gene_loc]
sum_stats = [d.uns["_scvi"]["summary_stats"] for d in self.adatas]
sum_stats = [adm.summary_stats for adm in self.adata_managers.values()]
n_inputs = [s["n_vars"] for s in sum_stats]

total_genes = adata_seq.uns["_scvi"]["summary_stats"]["n_vars"]
total_genes = n_inputs[0]

# since we are combining datasets, we need to increment the batch_idx
# of one of the datasets
adata_seq_n_batches = adata_seq.uns["_scvi"]["summary_stats"]["n_batch"]
adata_spatial.obs["_scvi_batch"] += adata_seq_n_batches
adata_seq_n_batches = sum_stats[0]["n_batch"]
adata_spatial.obs[
self.adata_managers["spatial"].data_registry[_CONSTANTS.BATCH_KEY][
_DR_ATTR_KEY
]
] += adata_seq_n_batches

n_batches = sum(s["n_batch"] for s in sum_stats)

library_log_means = []
library_log_vars = []
for adata in self.adatas:
for adata_manager in self.adata_managers.values():
adata_library_log_means, adata_library_log_vars = _init_library_size(
adata, n_batches
adata_manager, n_batches
)
library_log_means.append(adata_library_log_means)
library_log_vars.append(adata_library_log_vars)
Expand Down Expand Up @@ -184,9 +195,9 @@ def train(
)
self.train_indices_, self.test_indices_, self.validation_indices_ = [], [], []
train_dls, test_dls, val_dls = [], [], []
for i, ad in enumerate(self.adatas):
for i, adm in enumerate(self.adata_managers.values()):
ds = DataSplitter(
ad,
adm,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
Expand Down Expand Up @@ -491,9 +502,32 @@ def load(
"need to be the same and in the same order as the adata used to train the model."
)

scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
transfer_anndata_setup(scvi_setup_dicts["seq"], adata_seq)
transfer_anndata_setup(scvi_setup_dicts["spatial"], adata_spatial)
if "scvi_setup_dicts_" in attr_dict:
scvi_setup_dicts = attr_dict.pop("scvi_setup_dicts_")
for adata, scvi_setup_dict in zip(adatas, scvi_setup_dicts):
cls.register_manager(
manager_from_setup_dict(cls, adata, scvi_setup_dict)
)
else:
registries = attr_dict.pop("registries_")
for adata, registry in zip(adatas, registries):
if (
_MODEL_NAME_KEY in registry
and registry[_MODEL_NAME_KEY] != cls.__name__
):
raise ValueError(
"It appears you are loading a model from a different class."
)

if _SETUP_KWARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
)

cls.setup_anndata(
adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
)

# get the parameters for the class init signiture
init_params = attr_dict.pop("init_params_")
Expand Down Expand Up @@ -523,34 +557,36 @@ def load(
model.to_device(device)
return model

@staticmethod
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
cls,
adata: AnnData,
batch_key: Optional[str] = None,
labels_key: Optional[str] = None,
copy: bool = False,
) -> Optional[AnnData]:
layer: Optional[str] = None,
**kwargs,
):
"""
%(summary)s.
Parameters
----------
%(param_adata)s
%(param_batch_key)s
%(param_labels_key)s
%(param_copy)s
Returns
-------
%(returns)s
%(param_layer)s
"""
return _setup_anndata(
adata,
batch_key=batch_key,
labels_key=labels_key,
copy=copy,
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(_CONSTANTS.X_KEY, layer, is_count_data=True),
CategoricalObsField(_CONSTANTS.BATCH_KEY, batch_key),
CategoricalObsField(_CONSTANTS.LABELS_KEY, labels_key),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)


class TrainDL(DataLoader):
Expand Down
Loading

0 comments on commit 1878ac5

Please sign in to comment.