Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid clobber in setup_anndata using model instance and manager ID back references #1342

Merged
merged 7 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scvi/data/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Keys for UUIDs used for referencing model class manager stores.

_SCVI_UUID_KEY = "_scvi_uuid"
_SOURCE_SCVI_UUID_KEY = "_source_scvi_uuid"
_MANAGER_UUID_KEY = "_scvi_manager_uuid"

# scVI Registry Constants
# -----------------------
Expand Down
53 changes: 30 additions & 23 deletions scvi/data/anndata/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import sys
from collections import defaultdict
from typing import Optional, Sequence, Type
from copy import deepcopy
from typing import Optional, Sequence, Type, Union
from uuid import uuid4

import numpy as np
import pandas as pd
import rich
from anndata import AnnData

Expand Down Expand Up @@ -49,11 +52,11 @@ def __init__(
fields: Optional[Sequence[Type[BaseAnnDataField]]] = None,
setup_method_args: Optional[dict] = None,
) -> None:
self.id = str(uuid4())
self.adata = None
self.fields = fields or []
self._registry = {
_constants._SCVI_VERSION_KEY: scvi.__version__,
_constants._SOURCE_SCVI_UUID_KEY: None,
_constants._MODEL_NAME_KEY: None,
_constants._SETUP_KWARGS_KEY: None,
_constants._FIELD_REGISTRIES_KEY: defaultdict(dict),
Expand Down Expand Up @@ -97,23 +100,13 @@ def _assign_uuid(self):
scvi_uuid = self.adata.uns[_constants._SCVI_UUID_KEY]
self._registry[_constants._SCVI_UUID_KEY] = scvi_uuid

def _assign_source_uuid(self, source_registry: Optional[dict]):
def _assign_most_recent_manager_uuid(self):
justjhong marked this conversation as resolved.
Show resolved Hide resolved
"""
Assigns a source UUID to the AnnData object.

If setup not transferred from a source, set to current UUID.
Assigns a last manager UUID to the AnnData object for future validation.
"""
self._assert_anndata_registered()

if source_registry is None:
source_registry = self._registry
self._registry[_constants._SOURCE_SCVI_UUID_KEY] = self._registry[
_constants._SCVI_UUID_KEY
]

def _freeze_fields(self):
"""Freezes the fields associated with this instance."""
self.fields = tuple(self.fields)
self.adata.uns[_constants._MANAGER_UUID_KEY] = self.id

def register_fields(
self, adata: AnnData, source_registry: Optional[dict] = None, **transfer_kwargs
Expand All @@ -139,7 +132,6 @@ def register_fields(
)

self._validate_anndata_object(adata)
self.adata = adata
field_registries = self._registry[_constants._FIELD_REGISTRIES_KEY]

for field in self.fields:
Expand All @@ -160,23 +152,27 @@ def register_fields(
source_registry[_constants._FIELD_REGISTRIES_KEY][
field.registry_key
][_constants._STATE_REGISTRY_KEY],
self.adata,
adata,
**transfer_kwargs,
)
else:
field_registry[
_constants._STATE_REGISTRY_KEY
] = field.register_field(self.adata)
] = field.register_field(adata)

# Compute and set summary stats for the given field.
state_registry = field_registry[_constants._STATE_REGISTRY_KEY]
field_registry[_constants._SUMMARY_STATS_KEY] = field.get_summary_stats(
state_registry
)

self._freeze_fields()
# Save arguments for register_fields.
self._source_registry = deepcopy(source_registry)
self._transfer_kwargs = deepcopy(transfer_kwargs)

self.adata = adata
self._assign_uuid()
self._assign_source_uuid(source_registry)
self._assign_most_recent_manager_uuid()

def transfer_setup(self, adata_target: AnnData, **kwargs) -> AnnDataManager:
"""
Expand All @@ -202,7 +198,18 @@ def transfer_setup(self, adata_target: AnnData, **kwargs) -> AnnDataManager:
new_adata_manager.register_fields(adata_target, self._registry, **kwargs)
return new_adata_manager

def get_adata_uuid(self) -> str:
def validate(self) -> None:
"""Checks if AnnData was last setup with this AnnDataManager instance and reregisters it if not."""
self._assert_anndata_registered()
most_recent_manager_id = self.adata.uns[_constants._MANAGER_UUID_KEY]
# Re-register fields with same arguments if this AnnData object has been
# registered with a different AnnDataManager.
if most_recent_manager_id != self.id:
adata, self.adata = self.adata, None # Reset self.adata.
self.register_fields(adata, self._source_registry, **self._transfer_kwargs)

@property
def adata_uuid(self) -> str:
"""Returns the UUID for the AnnData object registered with this instance."""
self._assert_anndata_registered()

Expand Down Expand Up @@ -240,7 +247,7 @@ def summary_stats(self) -> attrdict:

return attrdict(summary_stats)

def get_from_registry(self, registry_key: str) -> np.ndarray:
def get_from_registry(self, registry_key: str) -> Union[np.ndarray, pd.DataFrame]:
"""
Returns the object in AnnData associated with the key in the data registry.

Expand All @@ -251,7 +258,7 @@ def get_from_registry(self, registry_key: str) -> np.ndarray:

Returns
-------
The requested data as a NumPy array.
The requested data.
"""
data_loc = self.data_registry[registry_key]
attr_name, attr_key = (
Expand Down
16 changes: 13 additions & 3 deletions scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,22 @@ def __init__(
**model_kwargs,
):
super(GIMVI, self).__init__()
if adata_seq is adata_spatial:
raise ValueError(
"`adata_seq` and `adata_spatial` cannot point to the same object. "
"If you would really like to do this, make a copy of the object and pass it in as `adata_spatial`."
)
self.adatas = [adata_seq, adata_spatial]
self.adata_managers = {
"seq": self.get_anndata_manager(adata_seq, required=True),
"spatial": self.get_anndata_manager(adata_spatial, required=True),
"seq": self._get_most_recent_anndata_manager(adata_seq, required=True),
"spatial": self._get_most_recent_anndata_manager(
adata_spatial, required=True
),
}
self.registries_ = [adm.registry for adm in self.adata_managers.values()]
self.registries_ = []
for adm in self.adata_managers.values():
self._register_manager_for_instance(adm)
self.registries_.append(adm.registry)

seq_var_names = adata_seq.var_names
spatial_var_names = adata_spatial.var_names
Expand Down
4 changes: 2 additions & 2 deletions scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def from_scvi_model(
).original_key

if adata is not None:
cls.register_manager(orig_adata_manager.transfer_setup(adata))
adata_manager = cls.get_anndata_manager(adata)
adata_manager = orig_adata_manager.transfer_setup(adata)
cls.register_manager(adata_manager)
else:
adata_manager = orig_adata_manager
adata = adata_manager.adata
Expand Down
5 changes: 2 additions & 3 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,15 @@ def load_query_data(
**registry[_SETUP_KWARGS_KEY]
)

adata_manager = cls.get_anndata_manager(adata, required=True)
model = _initialize_model(cls, adata, attr_dict)
adata_manager = model.get_anndata_manager(adata, required=True)

version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".")
if version_split[1] < "8" and version_split[0] == "0":
warnings.warn(
"Query integration should be performed using models trained with version >= 0.8"
)

model = _initialize_model(cls, adata, attr_dict)

model.to_device(device)

# model tweaking
Expand Down
Loading