Skip to content

Commit

Permalink
fix wsi
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldoucet committed Aug 9, 2024
1 parent f5c83cf commit a6dabc5
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions src/hest/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import openslide
from PIL import Image
from shapely import Polygon


class CucimWarningSingleton:
Expand Down Expand Up @@ -115,7 +116,17 @@ def read_region(self, location, level, size) -> np.ndarray:
img = self.img
x_start, y_start = location[0], location[1]
x_size, y_size = size[0], size[1]
return img[y_start:y_start + y_size, x_start:x_start + x_size]
x_end, y_end = x_start + x_size, y_start + y_size
padding_left = max(0 - x_start, 0)
padding_top = max(0 - y_start, 0)
padding_right = max(x_start + x_size - self.width, 0)
padding_bottom = max(y_start + y_size - self.height, 0)
x_start, y_start = max(x_start, 0), max(y_start, 0)
x_end, y_end = min(x_end, self.width), min(y_end, self.height)
tile = img[y_start:y_end, x_start:x_end]
padded_tile = np.pad(tile, ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), mode='constant', constant_values=0)

return padded_tile

def get_thumbnail(self, width, height) -> np.ndarray:
    """Return the whole image resized to the requested thumbnail size.

    Args:
        width: target width handed to ``cv2.resize``.
        height: target height handed to ``cv2.resize``.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    # cv2.resize takes its target size as a (width, height) pair.
    target_size = (width, height)
    thumbnail = cv2.resize(self.img, target_size)
    return thumbnail
Expand Down Expand Up @@ -239,7 +250,8 @@ def __init__(
overlap: int = 0,
mask: gpd.GeoDataFrame = None,
coords_only = False,
custom_coords = None
custom_coords = None,
threshold = 0.15
):
""" Initialize patcher, compute number of (masked) rows, columns.
Expand All @@ -251,6 +263,8 @@ def __init__(
overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0.
mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None.
coords_only (bool, optional): whether to extract only the coordinates instead of coordinates + tile. Defaults to False.
threshold (float, optional): minimum proportion of the patch under tissue to be kept. Defaults to 0.15.
This argument is ignored if mask=None. Passing threshold=0 will be faster.
"""
self.wsi = wsi
self.overlap = overlap
Expand Down Expand Up @@ -283,9 +297,8 @@ def __init__(
if round(custom_coords[0][0]) != custom_coords[0][0]:
raise ValueError("custom_coords must be a (N, 2) array of int")
coords = custom_coords

if self.mask is not None:
self.valid_patches_nb, self.valid_coords = self._compute_masked(coords)
self.valid_patches_nb, self.valid_coords = self._compute_masked(coords, threshold)
else:
self.valid_patches_nb, self.valid_coords = len(coords), coords

Expand All @@ -295,17 +308,43 @@ def _colrow_to_xy(self, col, row):
y = row * (self.patch_size_src) - self.overlap * np.clip(row - 1, 0, None)
return (x, y)

def _compute_masked(self, coords, threshold=0.15) -> tuple:
    """ Keep the patches that are sufficiently covered by the tissue mask.

    Each row of `coords` is the top-left corner of a square patch whose
    side is `self.patch_size_src`. A patch is kept when the share of its
    area covered by the union of the mask polygons is at least `threshold`.

    Args:
        coords: (N, 2) array of patch top-left (x, y) coordinates.
        threshold (float, optional): minimum proportion of the patch area
            that must lie inside the mask. With threshold=0, any patch that
            intersects the mask is kept (cheaper: no intersection area is
            computed). Defaults to 0.15.

    Returns:
        tuple: (number of kept patches, (M, 2) array of kept coordinates).
        Kept coordinates are de-duplicated and sorted row-wise, because the
        candidates go through `np.unique(..., axis=0)`.
    """
    coords = np.asarray(coords)
    size = self.patch_size_src

    # Cheap prefilter: only patches whose corner lies within one patch size
    # of some polygon's bounding box can overlap the mask. A single boolean
    # mask is accumulated instead of stacking one copy per polygon.
    near_mask = np.zeros(len(coords), dtype=bool)
    if len(coords) > 0:
        xs, ys = coords[:, 0], coords[:, 1]
        for _, bbox in self.mask.geometry.bounds.iterrows():
            near_mask |= (
                (xs >= bbox['minx'] - size) & (xs <= bbox['maxx'] + size)
                & (ys >= bbox['miny'] - size) & (ys <= bbox['maxy'] + size)
            )

    if not near_mask.any():
        # Nothing can overlap the mask: skip the geometry work and return
        # an empty array that keeps the (0, 2) shape callers expect.
        return 0, coords.reshape(-1, 2)[:0]

    candidates = np.unique(coords[near_mask], axis=0)

    union_mask = self.mask.union_all()

    squares = gpd.GeoSeries([
        Polygon([
            (x, y),
            (x + size, y),
            (x + size, y + size),
            (x, y + size)])
        for x, y in candidates
    ])

    if threshold == 0:
        # `area >= 0` would accept every candidate, so test for actual
        # contact with the mask instead.
        valid_mask = squares.intersects(union_mask).values
    else:
        covered = squares.intersection(union_mask).area
        valid_mask = (covered >= threshold * squares.area).values

    valid_patches_nb = int(valid_mask.sum())
    valid_coords = candidates[valid_mask]
    return valid_patches_nb, valid_coords
Expand Down

0 comments on commit a6dabc5

Please sign in to comment.