Skip to content

Commit

Permalink
Merge pull request #62 from nel-lab/cache
Browse files Browse the repository at this point in the history
creating cache for faster access to results of previous cnmf extension calls
  • Loading branch information
kushalkolar authored Jul 21, 2022
2 parents 6ba1229 + 98b78d3 commit 88bb217
Show file tree
Hide file tree
Showing 5 changed files with 604 additions and 156 deletions.
188 changes: 188 additions & 0 deletions mesmerize_core/caiman_extensions/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from functools import wraps
from typing import Union, Optional

import pandas as pd
import time
import numpy as np
import sys
from caiman.source_extraction.cnmf import CNMF
import re
from sys import getsizeof
import copy


def _check_arg_equality(args, cache_args):
if not type(args) == type(cache_args):
return False
if isinstance(cache_args, np.ndarray):
return np.array_equal(cache_args, args)
else:
return cache_args == args


def _check_args_equality(args, cache_args):
    """Return True if a positional-args tuple or kwargs dict matches the cached one.

    Per-element comparison is delegated to ``_check_arg_equality`` so numpy
    arrays are handled correctly.
    """
    if len(args) != len(cache_args):
        return False
    if isinstance(args, tuple):
        return all(
            _check_arg_equality(arg, cache_arg)
            for arg, cache_arg in zip(args, cache_args)
        )
    # kwargs dict: same-length dicts can still have different keys, which
    # previously raised KeyError on ``cache_args[k]`` — treat as unequal
    if args.keys() != cache_args.keys():
        return False
    return all(_check_arg_equality(args[k], cache_args[k]) for k in args)


def _return_wrapper(output, copy_bool):
if copy_bool == True:
return copy.deepcopy(output)
else:
return output


class Cache:
    """LRU cache for the results of CNMF/mcorr extension method calls.

    Entries are keyed by (series uuid, function name, args, kwargs) and held
    in a pandas DataFrame. The cache is bounded either by item count
    (``storage_type == "ITEMS"``) or by the total byte size of the stored
    return values (``storage_type == "RAM"``); the least recently used
    entry is evicted first.
    """

    def __init__(self, cache_size: Optional[Union[int, str]] = None):
        # columns: uuid | function | args | kwargs | return_val | time_stamp
        self.cache = pd.DataFrame(
            data=None,
            columns=["uuid", "function", "args", "kwargs", "return_val", "time_stamp"],
        )
        self.set_maxsize(cache_size)

    def get_cache(self):
        """Return the underlying cache DataFrame."""
        return self.cache

    def clear_cache(self):
        """Remove every entry from the cache."""
        # dropping all rows at once replaces the original one-row-at-a-time loop
        self.cache.drop(index=self.cache.index, inplace=True)

    def set_maxsize(self, max_size: Union[int, str]):
        """Set the cache size limit.

        Parameters
        ----------
        max_size
            ``None``: 1 GB RAM limit (default). A ``str`` ending in ``"G"``
            or ``"M"``: RAM limit in gigabytes or megabytes. An ``int``:
            maximum number of cached items.

        Raises
        ------
        ValueError
            If a string ``max_size`` does not end in ``"G"`` or ``"M"``.
        """
        if max_size is None:
            self.storage_type = "RAM"
            self.size = 1024**3
        elif isinstance(max_size, str):
            self.storage_type = "RAM"
            if max_size.endswith("G"):
                self.size = int(max_size[:-1]) * 1024**3
            elif max_size.endswith("M"):
                self.size = int(max_size[:-1]) * 1024**2
            else:
                # previously fell through silently, leaving ``self.size``
                # unset and causing a confusing AttributeError later
                raise ValueError(
                    "str `max_size` must end in 'G' (gigabytes) or 'M' (megabytes)"
                )
        else:
            self.storage_type = "ITEMS"
            self.size = max_size

    def _get_cache_size_bytes(self):
        """Return the approximate total size in bytes of all cached return values."""
        cache_size = 0
        for return_val in self.cache["return_val"]:
            if isinstance(return_val, np.ndarray):
                cache_size += return_val.data.nbytes
            elif isinstance(return_val, (tuple, list)):
                # assumes a container of containers of ndarrays
                # (e.g. contours/coms pairs) — TODO confirm for all callers
                for seq in return_val:
                    for array in seq:
                        cache_size += array.data.nbytes
            elif isinstance(return_val, CNMF):
                # BUGFIX: the original collected these sizes into a list but
                # never added them to the total, so CNMF objects counted as 0
                for attr in return_val.estimates.__dict__.values():
                    if isinstance(attr, np.ndarray):
                        cache_size += attr.data.nbytes
                    else:
                        cache_size += getsizeof(attr)
            else:
                # shallow size only; does not follow contained references
                cache_size += getsizeof(return_val)

        return cache_size

    def _add_entry(self, instance, func, args, kwargs, return_val):
        """Append one cache row for this call (timestamped for LRU ordering)."""
        self.cache.loc[len(self.cache.index)] = [
            instance._series["uuid"],
            func.__name__,
            args,
            kwargs,
            return_val,
            time.time(),
        ]

    def _drop_lru(self):
        """Drop the least recently used entry and reindex the DataFrame."""
        lru_index = self.cache.sort_values(by=["time_stamp"], ascending=False).index[-1]
        self.cache.drop(index=lru_index, axis=0, inplace=True)
        self.cache = self.cache.reset_index(drop=True)

    def use_cache(self, func):
        """Decorator that memoizes an extension method in this cache.

        The wrapped method gains a ``return_copy`` keyword (default True):
        pass ``return_copy=False`` to receive the cached object itself
        instead of a deep copy.
        """

        @wraps(func)
        def _use_cache(instance, *args, **kwargs):
            return_copy = kwargs.get("return_copy", True)

            # a zero-size cache disables caching entirely
            if self.size == 0:
                self.clear_cache()
                return _return_wrapper(func(instance, *args, **kwargs), return_copy)

            # cache hit: refresh the timestamp (LRU bookkeeping) and return
            for i in range(len(self.cache.index)):
                if (
                    self.cache.iloc[i, 0] == instance._series["uuid"]
                    and self.cache.iloc[i, 1] == func.__name__
                    and _check_args_equality(args, self.cache.iloc[i, 2])
                    and _check_arg_equality(kwargs, self.cache.iloc[i, 3])
                ):
                    self.cache.iloc[i, 5] = time.time()
                    return _return_wrapper(self.cache.iloc[i, 4], copy_bool=return_copy)

            # cache miss: compute, evict as needed, then store the new entry
            return_val = func(instance, *args, **kwargs)

            if self.storage_type == "ITEMS":
                # drop the LRU entry so the new one fits within the item limit
                if len(self.cache.index) >= self.size:
                    self._drop_lru()
            else:  # "RAM": evict LRU entries until under the byte limit
                while self._get_cache_size_bytes() > self.size:
                    self._drop_lru()

            self._add_entry(instance, func, args, kwargs, return_val)
            return _return_wrapper(return_val, copy_bool=return_copy)

        return _use_cache
16 changes: 12 additions & 4 deletions mesmerize_core/caiman_extensions/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from caiman.utils.visualization import get_contours as caiman_get_contours

from .common import validate
from .cache import Cache

cache = Cache()


@pd.api.extensions.register_series_accessor("cnmf")
Expand All @@ -22,6 +25,7 @@ class CNMFExtensions:
def __init__(self, s: pd.Series):
self._series = s

@validate("cnmf")
def get_cnmf_memmap(self) -> np.ndarray:
"""
Get the CNMF memmap
Expand Down Expand Up @@ -70,7 +74,8 @@ def get_output_path(self) -> Path:
return self._series.paths.resolve(self._series["outputs"]["cnmf-hdf5-path"])

@validate("cnmf")
def get_output(self) -> CNMF:
@cache.use_cache
def get_output(self, return_copy=True) -> CNMF:
"""
Returns
-------
Expand All @@ -83,8 +88,9 @@ def get_output(self) -> CNMF:

# TODO: Make the ``ixs`` parameter for spatial stuff optional
@validate("cnmf")
@cache.use_cache
def get_spatial_masks(
self, ixs_components: Optional[np.ndarray] = None, threshold: float = 0.01
self, ixs_components: Optional[np.ndarray] = None, threshold: float = 0.01, return_copy=True
) -> np.ndarray:
"""
Get binary masks of the spatial components at the given `ixs`
Expand Down Expand Up @@ -149,8 +155,9 @@ def _get_spatial_contours(
return contours

@validate("cnmf")
@cache.use_cache
def get_spatial_contours(
self, ixs_components: Optional[np.ndarray] = None
self, ixs_components: Optional[np.ndarray] = None, return_copy=True
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Get the contour and center of mass for each spatial footprint
Expand Down Expand Up @@ -182,8 +189,9 @@ def get_spatial_contours(
return coordinates, coms

@validate("cnmf")
@cache.use_cache
def get_temporal_components(
self, ixs_components: Optional[np.ndarray] = None, add_background: bool = False
self, ixs_components: Optional[np.ndarray] = None, add_background: bool = False, return_copy=True
) -> np.ndarray:
"""
Get the temporal components for this CNMF item
Expand Down
4 changes: 2 additions & 2 deletions mesmerize_core/caiman_extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,10 @@ def get_input_movie_path(self) -> Path:
def get_input_movie(self) -> Union[np.ndarray, pims.FramesSequence]:
extension = self.get_input_movie_path().suffixes[-1]

if extension in ['.tiff', '.tif', '.btf']:
if extension in [".tiff", ".tif", ".btf"]:
return pims.open(str(self.get_input_movie_path()))

elif extension in ['.mmap', '.memmap']:
elif extension in [".mmap", ".memmap"]:
Yr, dims, T = load_memmap(str(self.get_input_movie_path()))
return np.reshape(Yr.T, [T] + list(dims), order="F")

Expand Down
4 changes: 4 additions & 0 deletions mesmerize_core/caiman_extensions/mcorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from .common import validate
from typing import *
from .cache import Cache

cache = Cache()


@pd.api.extensions.register_series_accessor("mcorr")
Expand Down Expand Up @@ -45,6 +48,7 @@ def get_output(self) -> np.ndarray:
return mc_movie

@validate("mcorr")
@cache.use_cache
def get_shifts(
self, pw_rigid: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
Expand Down
Loading

0 comments on commit 88bb217

Please sign in to comment.