Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass manager to data loader #1280

Merged
merged 2 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from anndata import AnnData

from . import _constants
from ._manager import AnnDataManager
from .fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
)
from .manager import AnnDataManager


def manager_from_setup_dict(
Expand Down
4 changes: 2 additions & 2 deletions scvi/data/anndata/fields/_layer_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
else _constants._ADATA_ATTRS.LAYERS
)
self._attr_key = layer
self._is_count_data = is_count_data
self.is_count_data = is_count_data

@property
def registry_key(self):
Expand All @@ -56,7 +56,7 @@ def validate_field(self, adata: AnnData) -> None:
super().validate_field(adata)
x = self.get_field(adata)

if self._is_count_data and not _check_nonnegative_integers(x):
if self.is_count_data and not _check_nonnegative_integers(x):
logger_data_loc = (
"adata.X" if self.attr_key is None else f"adata.layers[{self.attr_key}]"
)
Expand Down
10 changes: 3 additions & 7 deletions scvi/data/anndata/_manager.py → scvi/data/anndata/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def add_field(self, field: Type[BaseAnnDataField]) -> None:
), "Fields have been frozen. Create a new AnnDataManager object for additional fields."
self.fields.add(field)

def _register_fields(
def register_fields(
self,
adata: AnnData,
source_setup_dict: Optional[dict] = None,
**transfer_kwargs
):
"""
Helper function with registers each field associated with this instance.
Registers each field associated with this instance with the AnnData object.

Either registers or transfers the setup from `source_setup_dict` if passed in.

Expand Down Expand Up @@ -114,10 +114,6 @@ def _register_fields(

self._assign_uuid()

def register_fields(self, adata: AnnData):
"""Registers each field associated with this instance with the AnnData object."""
return self._register_fields(adata)

def transfer_setup(
self, adata_target: AnnData, source_setup_dict: Optional[dict] = None, **kwargs
) -> AnnDataManager:
Expand All @@ -143,7 +139,7 @@ def transfer_setup(
)
fields = self.fields
new_adata_manager = self.__class__(fields)
new_adata_manager._register_fields(adata_target, setup_dict, **kwargs)
new_adata_manager.register_fields(adata_target, setup_dict, **kwargs)
return new_adata_manager

def get_adata_uuid(self) -> UUID:
Expand Down
17 changes: 10 additions & 7 deletions scvi/dataloaders/_ann_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
from typing import Optional, Union

import anndata
import numpy as np
import torch
from torch.utils.data import DataLoader

from scvi.data.anndata.manager import AnnDataManager

from ._anntorchdataset import AnnTorchDataset

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -91,8 +92,8 @@ class AnnDataLoader(DataLoader):

Parameters
----------
adata
An anndata objects
adata_manager
AnnDataManager object that has been created via setup_anndata.
shuffle
Whether the data should be shuffled
indices
Expand All @@ -109,7 +110,7 @@ class AnnDataLoader(DataLoader):

def __init__(
self,
adata: anndata.AnnData,
adata_manager: AnnDataManager,
shuffle=False,
indices=None,
batch_size=128,
Expand All @@ -118,11 +119,11 @@ def __init__(
**data_loader_kwargs,
):

if "_scvi" not in adata.uns.keys():
if adata_manager.adata is None:
raise ValueError("Please run setup_anndata() on your anndata object first.")

if data_and_attributes is not None:
data_registry = adata.uns["_scvi"]["data_registry"]
data_registry = adata_manager.get_data_registry()
for key in data_and_attributes.keys():
if key not in data_registry.keys():
raise ValueError(
Expand All @@ -131,7 +132,9 @@ def __init__(
)
)

self.dataset = AnnTorchDataset(adata, getitem_tensors=data_and_attributes)
self.dataset = AnnTorchDataset(
adata_manager.adata, getitem_tensors=data_and_attributes
)

sampler_kwargs = {
"batch_size": batch_size,
Expand Down
1 change: 0 additions & 1 deletion scvi/dataloaders/_anntorchdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def setup_getitem(self):
----------
getitem_tensors:
Either a list of keys in the scvi data registry to return when getitem is called
or

Examples
--------
Expand Down
11 changes: 6 additions & 5 deletions scvi/dataloaders/_concat_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import List, Optional, Union

import numpy as np
from anndata import AnnData
from torch.utils.data import DataLoader

from scvi.data.anndata.manager import AnnDataManager

from ._ann_dataloader import AnnDataLoader


Expand All @@ -14,8 +15,8 @@ class ConcatDataLoader(DataLoader):

Parameters
----------
adata
AnnData object that has been registered via setup_anndata.
adata_manager
AnnDataManager object that has been created via setup_anndata.
indices_list
List where each element is a list of indices in the adata to load
shuffle
Expand All @@ -32,7 +33,7 @@ class ConcatDataLoader(DataLoader):

def __init__(
self,
adata: AnnData,
adata_manager: AnnDataManager,
indices_list: List[List[int]],
shuffle: bool = False,
batch_size: int = 128,
Expand All @@ -44,7 +45,7 @@ def __init__(
for indices in indices_list:
self.dataloaders.append(
AnnDataLoader(
adata,
adata_manager,
indices=indices,
shuffle=shuffle,
batch_size=batch_size,
Expand Down
Loading