Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load tiles in parallel on workers and add options to TissueDetectionHE #336

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
94 changes: 57 additions & 37 deletions pathml/core/slide_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from io import BytesIO

import dask
import numpy as np
import openslide
from javabridge.jutil import JavaException
Expand Down Expand Up @@ -62,8 +63,14 @@ class OpenSlideBackend(SlideBackend):
def __init__(self, filename):
logger.info(f"OpenSlideBackend loading file at: {filename}")
self.filename = filename
self.slide = openslide.open_slide(filename=filename)
self.level_count = self.slide.level_count

@property
def slide(self):
    """Open the slide file and return the resulting OpenSlide handle.

    The handle is created from ``self.filename`` on every access; nothing is
    stored on the instance and this property does not close the handle.

    Returns:
        The object returned by ``openslide.open_slide`` for ``self.filename``.
    """
    handle = openslide.open_slide(filename=self.filename)
    return handle

@property
def level_count(self):
    """Return the ``level_count`` attribute of the slide handle.

    The value is read through ``self.slide`` each time it is requested.
    """
    handle = self.slide
    return handle.level_count

def __repr__(self):
    """Return a short string naming the backend and the file it points to."""
    description = f"OpenSlideBackend('{self.filename}')"
    return description
Expand Down Expand Up @@ -211,9 +218,10 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, level=0):
for ix_i in range(n_tiles_i):
for ix_j in range(n_tiles_j):
coords = (int(ix_i * stride_i), int(ix_j * stride_j))
# get image for tile
tile_im = self.extract_region(location=coords, size=shape, level=level)
yield pathml.core.tile.Tile(image=tile_im, coords=coords)
image = dask.delayed(self.extract_region)(
location=coords, size=shape, level=level
)
yield pathml.core.tile.Tile(image, coords=coords)


def _init_logger():
Expand Down Expand Up @@ -421,7 +429,10 @@ def extract_region(
f"Multi-level images not supported with series_as_channels=True. Input 'level={level}' invalid. Use 'level=0'."
)

javabridge.start_vm(class_path=bioformats.JARS, max_heap_size="100G")
javabridge.start_vm(
class_path=bioformats.JARS, max_heap_size="100G", run_headless=True
)

with bioformats.ImageReader(str(self.filename), perform_init=True) as reader:
# expand size
logger.info(f"extracting region with input size = {size}")
Expand Down Expand Up @@ -593,27 +604,28 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, level=0, **kwargs):
for ix_j in range(n_tiles_j):
coords = (int(ix_i * stride_i), int(ix_j * stride_j))
if coords[0] + shape[0] < i and coords[1] + shape[1] < j:
# get image for tile
tile_im = self.extract_region(
image = dask.delayed(self.extract_region)(
location=coords, size=shape, level=level, **kwargs
)
yield pathml.core.tile.Tile(image=tile_im, coords=coords)
# Image on edge and needs to be padded with 0s
else:
unpaddedshape = (
unpadded_shape = (
i - coords[0] if coords[0] + shape[0] > i else shape[0],
j - coords[1] if coords[1] + shape[1] > j else shape[1],
)
tile_im = self.extract_region(
location=coords, size=unpaddedshape, level=level, **kwargs
edge_image = dask.delayed(self.extract_region)(
location=coords, size=unpadded_shape, level=level, **kwargs
)
zeroarrayshape = list(tile_im.shape)
zeroarrayshape[0], zeroarrayshape[1] = (
list(shape)[0],
list(shape)[1],
)
padded_im = np.zeros(zeroarrayshape)
padded_im[: tile_im.shape[0], : tile_im.shape[1], ...] = tile_im
yield pathml.core.tile.Tile(image=padded_im, coords=coords)

def pad(image):
    """Pads edge tiles with zeros.

    Builds a zero-filled array whose first two axes have the requested tile
    size (``shape``, read from the enclosing scope) and whose remaining axes
    match those of ``image``, then copies ``image`` into the top-left corner.

    Args:
        image (np.ndarray): Edge tile with at least two axes; its first two
            axes must not exceed ``shape[0]`` and ``shape[1]``.

    Returns:
        np.ndarray: Zero-padded array of shape
        ``(shape[0], shape[1], *image.shape[2:])``.
    """
    # Keep every axis after the first two (e.g. channels). Using the leading
    # axes here would produce a buffer that does not match the image layout.
    trailing_axes = image.shape[2:]
    padded = np.zeros((shape[0], shape[1], *trailing_axes))
    padded[: image.shape[0], : image.shape[1], ...] = image
    return padded

# Need to delay to use shape of edge_image
image = dask.delayed(pad)(edge_image)
yield pathml.core.tile.Tile(image=image, coords=coords)


class DICOMBackend(SlideBackend):
Expand Down Expand Up @@ -653,19 +665,25 @@ def __init__(self, filename):
f"DICOM metadata: frame_shape={self.frame_shape}, nrows = {self.n_rows}, ncols = {self.n_cols}"
)

# actual file
self.fp = DicomFile(self.filename, mode="rb")
self.fp.is_little_endian = self.transfer_syntax_uid.is_little_endian
self.fp.is_implicit_VR = self.transfer_syntax_uid.is_implicit_VR
fp = self.fp

# need to do this to advance the file to the correct point, at the beginning of the pixels
self.metadata = dcmread(self.fp, stop_before_pixels=True)
self.pixel_data_offset = self.fp.tell()
self.fp.seek(self.pixel_data_offset, 0)
self.metadata = dcmread(fp, stop_before_pixels=True)
pixel_data_offset = fp.tell()
fp.seek(pixel_data_offset, 0)
# note that reading this tag is necessary to advance the file to correct position
_ = TupleTag(self.fp.read_tag())
_ = TupleTag(fp.read_tag())
# get basic offset table, to enable reading individual frames without loading entire image
self.bot = self.get_bot(self.fp)
self.first_frame = self.fp.tell()
self.bot = self.get_bot(fp)
self.first_frame = fp.tell()

@property
def fp(self):
    """Open ``self.filename`` as a binary ``DicomFile`` and return it.

    A new file object is constructed on every access. Before it is returned,
    its ``is_little_endian`` and ``is_implicit_VR`` flags are copied from
    ``self.transfer_syntax_uid``. This property does not close the file.
    """
    syntax = self.transfer_syntax_uid
    handle = DicomFile(self.filename, mode="rb")
    handle.is_little_endian = syntax.is_little_endian
    handle.is_implicit_VR = syntax.is_implicit_VR
    return handle

def __repr__(self):
out = f"DICOMBackend('{self.filename}')\n"
Expand Down Expand Up @@ -807,7 +825,9 @@ def _read_frame(self, frame_ix):
np.ndarray: pixel data of that frame
"""
frame_offset = self.bot[int(frame_ix)]
self.fp.seek(self.first_frame + frame_offset, 0)
# self.fp refers to a different filelike object each time it is accessed
fp = self.fp
fp.seek(self.first_frame + frame_offset, 0)
try:
stop_at = self.bot[frame_ix + 1] - frame_offset
except IndexError:
Expand All @@ -816,11 +836,11 @@ def _read_frame(self, frame_ix):
# A frame may be comprised of multiple chunks
chunks = []
while True:
tag = TupleTag(self.fp.read_tag())
tag = TupleTag(fp.read_tag())
if n == stop_at or int(tag) == SequenceDelimiterTag:
break
length = self.fp.read_UL()
chunks.append(self.fp.read(length))
length = fp.read_UL()
chunks.append(fp.read(length))
n += 8 + length

frame_bytes = b"".join(chunks)
Expand Down Expand Up @@ -899,7 +919,7 @@ def generate_tiles(self, shape, stride, pad, level=0, **kwargs):
if i >= (self.n_frames - self.n_cols):
continue

frame_im = self.extract_region(location=i)
im = dask.delayed(self.extract_region)(location=i)
coords = self._index_to_coords(i)
frame_tile = pathml.core.tile.Tile(image=frame_im, coords=coords)
yield frame_tile
tile = pathml.core.tile.Tile(image=im, coords=coords)
yield tile
68 changes: 45 additions & 23 deletions pathml/core/slide_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pathml.core
import pathml.preprocessing.pipeline
from pathml.core.slide_types import SlideType
from pathml.preprocessing.transforms import DropTileException


def infer_backend(path):
Expand Down Expand Up @@ -309,31 +310,47 @@ def run(
)

# map pipeline application onto each tile
processed_tile_futures = []
futures = [
client.submit(pipeline.apply, tile)
for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
)
]

for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
):
if not tile.slide_type:
tile.slide_type = self.slide_type
# explicitly scatter data, i.e. send the tile data out to the cluster before applying the pipeline
# according to dask, this can reduce scheduler burden and keep data on workers
big_future = client.scatter(tile)
f = client.submit(pipeline.apply, big_future)
processed_tile_futures.append(f)

# as tiles are processed, add them to h5
for future, tile in dask.distributed.as_completed(
processed_tile_futures, with_results=True
# After a worker processes a tile, add the tile to h5
for future, result in dask.distributed.as_completed(
futures, with_results=True, raise_errors=False
):
self.tiles.add(tile)
if future.status == "finished":
self.tiles.add(result)
if future.status == "error":
typ, exc, tb = result
if typ is DropTileException:
pass
else:
raise exc.with_traceback(tb)
# TODO: Free memory used for tile
# Each in-memory future holding a Tile shows a size of 48 bytes on the Dask dashboard
# which clearly does not include image data.
# Could it be that loaded image data is somehow not being garbage collected with Tiles?

# # all of these still leave unmanaged memory on each worker
# future.release()
# future.cancel()
# del result
# del future
# del futures

if shutdown_after:
client.shutdown()
else:
pass
# Stopgap to free unmanaged memory on client before processing another slide
client.restart()

else:
for tile in self.generate_tiles(
Expand All @@ -343,8 +360,6 @@ def run(
pad=tile_pad,
**kwargs,
):
if not tile.slide_type:
tile.slide_type = self.slide_type
pipeline.apply(tile)
self.tiles.add(tile)

Expand Down Expand Up @@ -410,14 +425,19 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
pathml.core.tile.Tile: Extracted Tile object
"""
for tile in self.slide.generate_tiles(shape, stride, pad, **kwargs):
# TODO: move to worker!! (forces loading data on main thread)

# add masks for tile, if possible
# i.e. if the SlideData has a Masks object, and the tile has coordinates
if self.masks is not None and tile.coords is not None:
# masks not supported if pad=True
# to implement, need to update Mask.slice to support slices that go beyond the full mask
if not pad:
i, j = tile.coords
di, dj = tile.image.shape[0:2]
# Accessing image loads data on main thread
# dask.delayed waits until compute is called on worker
shape = dask.delayed(tile).image.shape[0:2]
di, dj = shape[0], shape[1]
# add the Masks object for the masks corresponding to the tile
# this assumes that the tile didn't already have any masks
# this should work since the backend reads from image only
Expand All @@ -430,6 +450,8 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
tile_slices = [slice(i, i + di), slice(j, j + dj)]
tile.masks = self.masks.slice(tile_slices)

# TODO: end move to worker

# add slide-level labels to each tile, if possible
if self.labels is not None:
tile.labels = self.labels
Expand Down
45 changes: 31 additions & 14 deletions pathml/core/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from collections import OrderedDict

import anndata
import dask
import h5py
import matplotlib.pyplot as plt
import numpy as np
from dask.delayed import Delayed

import pathml.core.masks

Expand All @@ -21,7 +23,7 @@ class Tile:
on labelling the top-leftmost pixel as (0, 0)

Args:
image (np.ndarray): Image array of tile
image (np.ndarray or dask.delayed.Delayed): Tile image or dask.delayed.Delayed object to load image
coords (tuple): Coordinates of tile relative to the whole-slide image.
The (i,j) coordinate system is based on labelling the top-leftmost pixel of the WSI as (0, 0).
name (str, optional): Name of tile
Expand Down Expand Up @@ -60,9 +62,9 @@ def __init__(
time_series=None,
):
# check inputs
assert isinstance(
assert isinstance(image, Delayed) or isinstance(
image, np.ndarray
), f"image of type {type(image)} must be a np.ndarray"
), f"image of type {type(image)} must be a np.ndarray or a dask.delayed.Delayed object"
assert masks is None or isinstance(
masks, dict
), f"masks is of type {type(masks)} but must be of type dict"
Expand Down Expand Up @@ -115,23 +117,38 @@ def __init__(
counts, anndata.AnnData
), f"counts is of type {type(counts)} but must be of type anndata.AnnData or None"

if masks:
for val in masks.values():
if val.shape[:2] != image.shape[:2]:
raise ValueError(
f"mask is of shape {val.shape} but must match tile shape {image.shape}"
)
self.masks = masks
else:
self.masks = OrderedDict()

self.image = image
self._image = image
self.masks = masks if masks else OrderedDict()
self.name = name
self.coords = coords
self.slide_type = slide_type
self.labels = labels
self.counts = counts

@property
def image(self):
    """Return the tile image, computing it first if it is still delayed.

    When ``self._image`` is a ``Delayed`` object it is computed with the
    single-threaded dask scheduler, checked to be a ``np.ndarray``, compared
    against the first two dimensions of every mask in ``self.masks``, and
    then stored back on ``self._image`` so later accesses reuse the array.

    Raises:
        AssertionError: if the computed result is not a ``np.ndarray``.
        ValueError: if a mask's first two dimensions differ from the image's.
    """
    pending = self._image
    if not isinstance(pending, Delayed):
        return pending

    loaded = dask.compute(pending, scheduler="single-threaded")
    if isinstance(loaded, tuple):
        # dask.compute hands back a tuple of results; unwrap the single entry
        loaded = loaded[0]
    assert isinstance(
        loaded, np.ndarray
    ), f"image of type {type(loaded)} must be a np.ndarray"

    for mask in self.masks.values():
        if mask.shape[:2] != loaded.shape[:2]:
            raise ValueError(
                f"mask is of shape {mask.shape} but must match tile shape {loaded.shape}"
            )

    self._image = loaded
    return loaded

@image.setter
def image(self, image):
    """Replace the stored image; only ``np.ndarray`` values are accepted."""
    is_array = isinstance(image, np.ndarray)
    assert is_array, f"image of type {type(image)} must be a np.ndarray"
    self._image = image

def __repr__(self):
out = []
out.append(f"Tile(coords={self.coords}")
Expand Down
Loading