Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

creating cache for faster access to results of previous cnmf extension calls #62

Merged
merged 34 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ebd2472
creating cache for faster access to results of previous cnmf extensio…
clewis7 Jun 28, 2022
22c4452
instantiating cache object in extension files, adding directives
clewis7 Jun 28, 2022
12e4ff1
reformatting files
clewis7 Jun 28, 2022
5bf7fd9
finishing up cache impl, adding clear_cache() and set_maxcache() capa…
clewis7 Jun 28, 2022
eb62b65
adding uuid column to cache so that different batch items can be deli…
clewis7 Jul 1, 2022
a40e6c4
updating kwarg comparison to handle numpy arrays
clewis7 Jul 1, 2022
3c5bd91
checking if args are equal when numpy array is passed as arg
clewis7 Jul 1, 2022
ed2fa18
updating check arg function so that arg equality is checked for every…
clewis7 Jul 1, 2022
52be0ae
returning func(*args, **kwargs) if len(cache)==0
clewis7 Jul 3, 2022
35897a8
adding ability for cache size to be controlled by memory size as oppo…
clewis7 Jul 3, 2022
fa68603
work in progress
clewis7 Jul 5, 2022
34ca881
fixing how size of item in cache are calculated for non-built in types
clewis7 Jul 5, 2022
98dcd3b
trying to make code more elegant for kushal
clewis7 Jul 5, 2022
19203fb
cache should now work with whether size is based on number of items o…
clewis7 Jul 5, 2022
7fd2630
debugging cache
clewis7 Jul 5, 2022
0157fc4
further debugging
clewis7 Jul 5, 2022
77d13f7
further changes to cache
clewis7 Jul 5, 2022
ee9b5d5
updates to cache and extensions which extensions use cache
clewis7 Jul 6, 2022
cc68e2f
final changes to cache, tests should pass except for linter
clewis7 Jul 8, 2022
398d952
kushal requested changes
clewis7 Jul 8, 2022
b4bf027
returning copies from cache for future downstream analysis, fixing cn…
clewis7 Jul 8, 2022
0da8b55
updates to returning a copy or original of extension outputs
clewis7 Jul 9, 2022
e0e80de
final changes to cache, need to write tests still
clewis7 Jul 9, 2022
34d3213
removing get_cache2()
clewis7 Jul 9, 2022
8de3b15
setting default copy = true
clewis7 Jul 9, 2022
62755a7
adding tests for cache, still need to debug
clewis7 Jul 9, 2022
7ee7df8
tests for cache, still need to fix issue with maxsize=0
clewis7 Jul 13, 2022
2d53a51
updating tests and cache to handle when cache is set to size 0
clewis7 Jul 13, 2022
6629216
updates to cache tests
clewis7 Jul 15, 2022
a1d30da
updates to cache tests
clewis7 Jul 18, 2022
6e34373
dumb kushal tests
clewis7 Jul 18, 2022
0928f24
insignificant merge conflict
clewis7 Jul 18, 2022
c13888c
fixing cache tests
clewis7 Jul 21, 2022
98b78d3
hopefully the last changes to cache as of now
clewis7 Jul 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions mesmerize_core/caiman_extensions/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from functools import wraps
from typing import Union, Optional

import pandas as pd
import time
import numpy as np
import sys
from caiman.source_extraction.cnmf import CNMF
import re
from sys import getsizeof
import copy


def _check_arg_equality(args, cache_args):
    """Return True if a single cached argument equals the incoming one.

    Handles numpy arrays explicitly (plain ``==`` on arrays yields an
    element-wise array whose truth value is ambiguous) and recurses into
    dicts, tuples and lists so containers holding arrays — e.g. the kwargs
    dict compared in ``Cache.use_cache`` — compare correctly instead of
    raising ``ValueError``.
    """
    # differing types are never equal (also distinguishes e.g. 1 from 1.0)
    if type(args) is not type(cache_args):
        return False
    if isinstance(cache_args, np.ndarray):
        return np.array_equal(cache_args, args)
    if isinstance(cache_args, dict):
        # differing key sets are unequal; otherwise compare value-by-value
        if set(cache_args.keys()) != set(args.keys()):
            return False
        return all(_check_arg_equality(args[k], cache_args[k]) for k in args)
    if isinstance(cache_args, (tuple, list)):
        if len(cache_args) != len(args):
            return False
        return all(
            _check_arg_equality(arg, cache_arg)
            for arg, cache_arg in zip(args, cache_args)
        )
    return cache_args == args


def _check_args_equality(args, cache_args):
    """Return True if two argument collections are equal.

    ``args``/``cache_args`` are either the positional-args tuples or the
    kwargs dicts of a cached call; each element is compared with
    ``_check_arg_equality`` so numpy arrays are handled correctly.
    """
    if len(args) != len(cache_args):
        return False
    if isinstance(args, tuple):
        return all(
            _check_arg_equality(arg, cache_arg)
            for arg, cache_arg in zip(args, cache_args)
        )
    # kwargs dicts: differing key sets are unequal (previously raised KeyError)
    if set(args.keys()) != set(cache_args.keys()):
        return False
    return all(_check_arg_equality(args[k], cache_args[k]) for k in args)


def _return_wrapper(output, copy_bool):
    """Return a deep copy of ``output`` when ``copy_bool`` is truthy.

    A deep copy protects the cached object from mutation by downstream
    analysis; passing ``copy_bool=False`` returns the cached object itself.
    """
    # truthiness test instead of the non-idiomatic ``copy_bool == True``
    return copy.deepcopy(output) if copy_bool else output


class Cache:
    """LRU cache for extension outputs, keyed on (batch item uuid, function, args, kwargs).

    The cache is bounded either by number of entries ("ITEMS" mode, when
    ``cache_size`` is an int) or by the approximate memory footprint of the
    cached return values ("RAM" mode, when ``cache_size`` is ``None`` or a
    string such as ``"2G"`` / ``"512M"``). Least-recently-used entries are
    evicted first.
    """

    def __init__(self, cache_size: Optional[Union[int, str]] = None):
        # one row per cached call; "time_stamp" records last access for LRU eviction
        self.cache = pd.DataFrame(
            data=None,
            columns=["uuid", "function", "args", "kwargs", "return_val", "time_stamp"],
        )
        self.set_maxsize(cache_size)

    def get_cache(self) -> pd.DataFrame:
        """Return the underlying cache DataFrame."""
        return self.cache

    def clear_cache(self):
        """Remove all entries from the cache."""
        # single in-place drop instead of popping rows one at a time
        self.cache.drop(index=self.cache.index, inplace=True)

    def set_maxsize(self, max_size: Optional[Union[int, str]]):
        """Set the maximum cache size.

        Parameters
        ----------
        max_size
            ``None`` -> "RAM" mode with a 1 GB default limit;
            ``"<n>G"`` / ``"<n>M"`` -> "RAM" mode with the given byte limit;
            int -> "ITEMS" mode, maximum number of cached entries.

        Raises
        ------
        ValueError
            If a string is given without a "G" or "M" suffix (previously this
            silently left ``self.size`` unset, causing a later AttributeError).
        """
        if max_size is None:
            self.storage_type = "RAM"
            self.size = 1024**3  # default to 1 GB
        elif isinstance(max_size, str):
            self.storage_type = "RAM"
            if max_size.endswith("G"):
                self.size = int(max_size[:-1]) * 1024**3
            elif max_size.endswith("M"):
                self.size = int(max_size[:-1]) * 1024**2
            else:
                raise ValueError(
                    "max_size string must end in 'G' (gigabytes) or 'M' (megabytes)"
                )
        else:
            self.storage_type = "ITEMS"
            self.size = max_size

    def _get_cache_size_bytes(self):
        """Return the approximate total size of cached return values, in bytes."""
        cache_size = 0
        for return_val in self.cache["return_val"]:
            if isinstance(return_val, np.ndarray):
                cache_size += return_val.data.nbytes
            elif isinstance(return_val, (tuple, list)):
                # e.g. get_spatial_contours returns a tuple of lists of arrays
                # NOTE(review): assumes every element is a sequence of ndarrays
                for seq in return_val:
                    for array in seq:
                        cache_size += array.data.nbytes
            elif isinstance(return_val, CNMF):
                # bug fix: attribute sizes were previously collected into a
                # list but never added to the total, so CNMF objects counted
                # as zero bytes
                for attr in return_val.estimates.__dict__.values():
                    if isinstance(attr, np.ndarray):
                        cache_size += attr.data.nbytes
                    else:
                        cache_size += getsizeof(attr)
            else:
                # getsizeof is shallow, but adequate for scalar-ish values
                cache_size += getsizeof(return_val)

        return cache_size

    def _add_entry(self, instance, func, args, kwargs, return_val):
        """Append a new cache row for this call."""
        self.cache.loc[len(self.cache.index)] = [
            instance._series["uuid"],
            func.__name__,
            args,
            kwargs,
            return_val,
            time.time(),
        ]

    def _drop_least_recently_used(self):
        """Evict the entry with the oldest access time and re-index the cache."""
        self.cache.drop(
            index=self.cache.sort_values(by=["time_stamp"], ascending=False).index[-1],
            axis=0,
            inplace=True,
        )
        self.cache = self.cache.reset_index(drop=True)

    def use_cache(self, func):
        """Decorator: memoize ``func(instance, *args, **kwargs)`` in this cache.

        ``instance`` is the pandas-series extensions accessor; its batch item
        uuid (``instance._series["uuid"]``) is part of the cache key.  The
        wrapped function accepts a ``return_copy`` kwarg (default True): when
        truthy a deep copy of the cached value is returned so downstream
        mutation cannot corrupt the cache.
        """

        @wraps(func)
        def _use_cache(instance, *args, **kwargs):
            if "return_copy" in kwargs.keys():
                return_copy = kwargs["return_copy"]
            else:
                return_copy = True

            # a max size of 0 disables caching entirely
            if self.size == 0:
                self.clear_cache()
                return _return_wrapper(func(instance, *args, **kwargs), return_copy)

            # if the cache is empty this is always a miss
            if len(self.cache.index) == 0:
                return_val = func(instance, *args, **kwargs)
                self._add_entry(instance, func, args, kwargs, return_val)
                return _return_wrapper(return_val, copy_bool=return_copy)

            # look for a cache hit: same batch item, function and arguments
            for i in range(len(self.cache.index)):
                if (
                    self.cache.iloc[i, 0] == instance._series["uuid"]
                    and self.cache.iloc[i, 1] == func.__name__
                    and _check_args_equality(args, self.cache.iloc[i, 2])
                    and _check_arg_equality(kwargs, self.cache.iloc[i, 3])
                ):
                    # refresh access time so this entry becomes most recently used
                    self.cache.iloc[i, 5] = time.time()
                    return _return_wrapper(self.cache.iloc[i, 4], copy_bool=return_copy)

            # cache miss: compute the value, evicting if a limit would be exceeded
            if self.storage_type == "ITEMS" and len(self.cache.index) >= self.size:
                # at the item limit: drop the least recently used, then add
                return_val = func(instance, *args, **kwargs)
                self._drop_least_recently_used()
                self._add_entry(instance, func, args, kwargs, return_val)
            elif self.storage_type == "RAM":
                # evict least-recently-used entries until under the byte limit,
                # then store the new value
                while self._get_cache_size_bytes() > self.size:
                    self._drop_least_recently_used()
                return_val = func(instance, *args, **kwargs)
                self._add_entry(instance, func, args, kwargs, return_val)
            else:
                # under the item limit: just add the new entry
                return_val = func(instance, *args, **kwargs)
                self._add_entry(instance, func, args, kwargs, return_val)

            return _return_wrapper(return_val, copy_bool=return_copy)

        return _use_cache
20 changes: 14 additions & 6 deletions mesmerize_core/caiman_extensions/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from caiman.utils.visualization import get_contours as caiman_get_contours

from .common import validate
from .cache import Cache

cache = Cache()
kushalkolar marked this conversation as resolved.
Show resolved Hide resolved


@pd.api.extensions.register_series_accessor("cnmf")
Expand All @@ -21,6 +24,7 @@ class CNMFExtensions:
def __init__(self, s: pd.Series):
self._series = s

@validate("cnmf")
def get_cnmf_memmap(self) -> np.ndarray:
"""
Get the CNMF memmap
Expand Down Expand Up @@ -69,7 +73,8 @@ def get_output_path(self) -> Path:
return self._series.paths.resolve(self._series["outputs"]["cnmf-hdf5-path"])

@validate("cnmf")
def get_output(self) -> CNMF:
@cache.use_cache
def get_output(self, return_copy=True) -> CNMF:
"""
Returns
-------
Expand All @@ -82,8 +87,9 @@ def get_output(self) -> CNMF:

# TODO: Make the ``ixs`` parameter for spatial stuff optional
@validate("cnmf")
@cache.use_cache
def get_spatial_masks(
self, ixs_components: Optional[np.ndarray] = None, threshold: float = 0.01
self, ixs_components: Optional[np.ndarray] = None, threshold: float = 0.01, return_copy=True
) -> np.ndarray:
"""
Get binary masks of the spatial components at the given `ixs`
Expand Down Expand Up @@ -148,8 +154,9 @@ def _get_spatial_contours(
return contours

@validate("cnmf")
@cache.use_cache
def get_spatial_contours(
self, ixs_components: Optional[np.ndarray] = None
self, ixs_components: Optional[np.ndarray] = None, return_copy=True
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Get the contour and center of mass for each spatial footprint
Expand Down Expand Up @@ -181,8 +188,9 @@ def get_spatial_contours(
return coordinates, coms

@validate("cnmf")
@cache.use_cache
def get_temporal_components(
self, ixs_components: Optional[np.ndarray] = None, add_background: bool = False
self, ixs_components: Optional[np.ndarray] = None, add_background: bool = False, return_copy=True
) -> np.ndarray:
"""
Get the temporal components for this CNMF item
Expand Down Expand Up @@ -250,8 +258,8 @@ def get_reconstructed_movie(
if isinstance(ixs_frames, int):
ixs_frames = (ixs_frames, ixs_frames + 1)

dn = cnmf_obj.estimates.A[:,idx_components].dot(
cnmf_obj.estimates.C[idx_components, ixs_frames[0] : ixs_frames[1]]
dn = cnmf_obj.estimates.A[:, idx_components].dot(
cnmf_obj.estimates.C[idx_components, ixs_frames[0]: ixs_frames[1]]
)

if add_background:
Expand Down
4 changes: 2 additions & 2 deletions mesmerize_core/caiman_extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ def get_input_movie_path(self) -> Path:
def get_input_movie(self) -> Union[np.ndarray, pims.FramesSequence]:
extension = self.get_input_movie_path().suffixes[-1]

if extension in ['.tiff', '.tif', '.btf']:
if extension in [".tiff", ".tif", ".btf"]:
return pims.open(str(self.get_input_movie_path()))

elif extension in ['.mmap', '.memmap']:
elif extension in [".mmap", ".memmap"]:
Yr, dims, T = load_memmap(str(self.get_input_movie_path()))
return np.reshape(Yr.T, [T] + list(dims), order="F")

Expand Down
4 changes: 4 additions & 0 deletions mesmerize_core/caiman_extensions/mcorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from .common import validate
from typing import *
from .cache import Cache

cache = Cache()


@pd.api.extensions.register_series_accessor("mcorr")
Expand Down Expand Up @@ -45,6 +48,7 @@ def get_output(self) -> np.ndarray:
return mc_movie

@validate("mcorr")
@cache.use_cache
def get_shifts(
self, pw_rigid: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
Expand Down
Loading