Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wsi #29

Closed
wants to merge 20 commits into from
Closed

Wsi #29

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Hest tests

on:
#push:
# branches: [ "main", "develop"]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest
env:
HF_READ_TOKEN_PAUL: ${{ secrets.HF_READ_TOKEN_PAUL }}

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"

- name: Install python dependencies
run: |
python -m pip install -e .
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install libvips libvips-dev openslide-tools

- name: Run tests
run: |
python tests/hest_tests.py
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,5 @@ hest_vis
vis
vis2
models/deeplabv3*
.github
htmlcov
models/CellViT-SAM-H-x40.pth
139 changes: 41 additions & 98 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,25 @@

import cv2
import geopandas as gpd
import matplotlib
import numpy as np

from hest.io.seg_readers import (TissueContourReader,
write_geojson)
from hest.LazyShapes import LazyShapes, convert_old_to_gpd
from hest.io.seg_readers import TissueContourReader
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask
from hest.wsi import WSI, CucimWarningSingleton, NumpyWSI, wsi_factory
from hest.wsi import (WSI, CucimWarningSingleton, NumpyWSI, contours_to_img,
get_tissue_vis, wsi_factory)

try:
import openslide
except Exception:
print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/")
import pandas as pd
from matplotlib.collections import PatchCollection
from PIL import Image
from shapely import Point
from tqdm import tqdm

from .segmentation.segmentation import (apply_otsu_thresholding,
contours_to_img, get_tissue_vis,
mask_to_contours, save_pkl,
segment_tissue_deep)
from .segmentation.segmentation import (apply_otsu_thresholding, mask_to_gdf,
save_pkl, segment_tissue_deep)
from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated,
find_first_file_endswith, get_path_from_meta_row,
plot_verify_pixel_size, tiff_save, verify_paths)
Expand Down Expand Up @@ -137,7 +133,7 @@ def save_spatial_plot(self, save_path: str, name: str='', key='total_counts', pl
filename = f"{name}spatial_plots.png"

# Save the figure
fig.savefig(os.path.join(save_path, filename))
fig.savefig(os.path.join(save_path, filename), dpi=400)
print(f"H&E overlay spatial plots saved in {save_path}")


Expand Down Expand Up @@ -261,21 +257,14 @@ def segment_tissue(
tissue_mask = np.round(cv2.resize(mask, (width, height))).astype(np.uint8)

#TODO directly convert to gpd
gdf_contours = mask_to_contours(tissue_mask, pixel_size=self.pixel_size)
gdf_contours = mask_to_gdf(tissue_mask, pixel_size=self.pixel_size)
self._tissue_contours = gdf_contours

return self.tissue_contours


def save_tissue_contours(self, save_dir: str, name: str) -> None:
write_geojson(
self.tissue_contours,
os.path.join(save_dir, name + '_contours.geojson'),
'tissue_id',
extra_prop=True,
index_key='hole'
)

self.tissue_contours.to_file(os.path.join(save_dir, name + '_contours.geojson'), driver="GeoJSON")

@deprecated
def get_tissue_mask(self) -> np.ndarray:
Expand Down Expand Up @@ -316,6 +305,7 @@ def dump_patches(
"""

import matplotlib.pyplot as plt
dst_pixel_size = target_pixel_size

adata = self.adata.copy()

Expand All @@ -326,91 +316,40 @@ def dump_patches(

src_pixel_size = self.pixel_size

# minimum intersection percecentage with the tissue mask to keep a patch
TISSUE_INTER_THRESH = 0.7
TARGET_VIS_SIZE = 1000

scale_factor = target_pixel_size / src_pixel_size
patch_size_pxl = round(target_patch_size * scale_factor)
patch_count = 0
output_datafile = os.path.join(patch_save_dir, name + '.h5')

assert len(adata.obs) == len(adata.obsm['spatial'])

_, ax = plt.subplots()

mode_HE = 'w'
i = 0
img_width, img_height = self.wsi.get_dimensions()
patch_rectangles = [] # lower corner (x, y) + (widht, height)
downscale_vis = TARGET_VIS_SIZE / img_width

if use_mask:
tissue_mask = np.zeros((img_height, img_width, 3), dtype=np.uint8)
tissue_mask = contours_to_img(
self.tissue_contours,
tissue_mask,
draw_contours=False,
line_color=(1, 1, 1)
)[:, :, 0]
else:
tissue_mask = np.ones((img_height, img_width)).astype(np.uint8)

mask_plot = self.get_tissue_vis()

ax.imshow(mask_plot)
for _, row in tqdm(adata.obs.iterrows(), total=len(adata.obs)):

barcode_spot = row.name

xImage = int(adata.obsm['spatial'][i][0])
yImage = int(adata.obsm['spatial'][i][1])

i += 1

if not(0 <= xImage and xImage < img_width and 0 <= yImage and yImage < img_height):
if verbose:
print('Warning, spot is out of the image, skipping')
continue

if not(0 <= yImage - patch_size_pxl // 2 and yImage + patch_size_pxl // 2 < img_height and \
0 <= xImage - patch_size_pxl // 2 and xImage + patch_size_pxl // 2 < img_width):
if verbose:
print('Warning, patch is out of the image, skipping')
continue

## TODO reimplement now that we use the pyramidal level
image_patch = self.wsi.read_region((xImage - patch_size_pxl // 2, yImage - patch_size_pxl // 2), 0, (patch_size_pxl, patch_size_pxl))
rect_x = (xImage - patch_size_pxl // 2) * downscale_vis
rect_y = (yImage - patch_size_pxl // 2) * downscale_vis
rect_width = patch_size_pxl * downscale_vis
rect_height = patch_size_pxl * downscale_vis

image_patch = np.array(image_patch)
if image_patch.shape[2] == 4:
image_patch = image_patch[:, :, :3]


if use_mask:
patch_mask = tissue_mask[yImage - patch_size_pxl // 2: yImage + patch_size_pxl // 2,
xImage - patch_size_pxl // 2: xImage + patch_size_pxl // 2]
patch_area = patch_mask.shape[0] ** 2
pixel_count = patch_mask.sum()

if pixel_count / patch_area < TISSUE_INTER_THRESH:
continue
patch_size_src = target_patch_size * (dst_pixel_size / src_pixel_size)
coords_center = adata.obsm['spatial']
coords_topleft = coords_center - patch_size_src // 2
len_tmp = len(coords_topleft)
coords_topleft = coords_topleft[(0 <= coords_topleft[:, 0] + patch_size_src) & (coords_topleft[:, 0] < self.wsi.width) & (0 <= coords_topleft[:, 1] + patch_size_src) & (coords_topleft[:, 1] < self.wsi.height)]
if len(coords_topleft) < len_tmp:
warnings.warn(f"Filtered {len_tmp - len(coords_topleft)} spots outside the WSI")

barcodes = np.array(adata.obs.index)
mask = self.tissue_contours if use_mask else None
coords_topleft = np.array(coords_topleft).astype(int)
patcher = self.wsi.create_patcher(target_patch_size, src_pixel_size, dst_pixel_size, mask=mask, custom_coords=coords_topleft)

patch_rectangles.append(matplotlib.patches.Rectangle((rect_x, rect_y), rect_width, rect_height))

patch_count += 1
image_patch = cv2.resize(image_patch, (target_patch_size, target_patch_size), interpolation=cv2.INTER_CUBIC)
i = 0
for tile, x, y in tqdm(patcher):

center_x = x + patch_size_src // 2
center_y = y + patch_size_src // 2

# Save ref patches
assert image_patch.shape == (target_patch_size, target_patch_size, 3)
asset_dict = { 'img': np.expand_dims(image_patch, axis=0), # (1 x w x h x 3)
'coords': np.expand_dims([yImage, xImage], axis=0), # (1 x 2)
'barcode': np.expand_dims([barcode_spot], axis=0)
assert tile.shape == (target_patch_size, target_patch_size, 3)
asset_dict = { 'img': np.expand_dims(tile, axis=0), # (1 x w x h x 3)
'coords': np.expand_dims([center_y, center_x], axis=0), # (1 x 2)
'barcode': np.expand_dims([barcodes[i]], axis=0)
}

attr_dict = {}
Expand All @@ -419,13 +358,10 @@ def dump_patches(

initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE)
mode_HE = 'a'
i += 1


if dump_visualization:
ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3))
ax.set_axis_off()
plt.tight_layout()
plt.savefig(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400, bbox_inches = 'tight')
patcher.save_visualization(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400)

if verbose:
print(f'found {patch_count} valid patches')
Expand Down Expand Up @@ -773,8 +709,15 @@ def read_HESTData(
tissue_contours = None
tissue_seg = None
if tissue_contours_path is not None:
tissue_contours = TissueContourReader().read_gdf(tissue_contours_path)
tissue_contours['tissue_id'] = tissue_contours['tissue_id'].astype(int)
with open(tissue_contours_path) as f:
lines = f.read()
if 'hole' in lines:
warnings.warn("this type of .geojson tissue contour file is deprecated, please download the new `tissue_seg` folder on huggingface: https://huggingface.co/datasets/MahmoodLab/hest/tree/main")
gdf = TissueContourReader().read_gdf(tissue_contours_path)
tissue_contours = old_geojson_to_new(gdf)
else:
tissue_contours = gpd.read_file(tissue_contours_path)

elif mask_path_pkl is not None and mask_path_jpg is not None:
tissue_seg = load_tissue_mask(mask_path_pkl, mask_path_jpg, width, height)

Expand Down
33 changes: 21 additions & 12 deletions src/hest/LazyShapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,28 @@ def convert_old_to_gpd(contours_holes, contours_tissue) -> gpd.GeoDataFrame:
types = []
for i in range(len(contours_holes)):
tissue = contours_tissue[i]
shapes.append(Polygon(tissue[:, 0, :]))
tissue_ids.append(i)
types.append('tissue')
holes = contours_holes[i]
if len(holes) > 0:
for hole in holes:
shapes.append(Polygon(hole[:, 0, :]))
tissue_ids.append(i)
types.append('hole')

holes = contours_holes[i] if len(contours_holes[i]) > 0 else None
shapes.append(Polygon(tissue[:, 0, :]), holes=holes)

df = pd.DataFrame(tissue_ids, columns=['tissue_id'])
df['hole'] = types
df['hole'] = df['hole'] == 'hole'

return gpd.GeoDataFrame(df, geometry=shapes)



def old_geojson_to_new(gdf):
polygons = []
keys = []
for key, group in gdf.groupby('tissue_id'):
holes = []
for row in group.values:
if row[2]:
holes.append([coord for coord in row[0].exterior.coords])
else:
exterior = [coord for coord in row[0].exterior.coords]
polygons.append(Polygon(exterior, holes))
keys.append(key)

gdf = gpd.GeoDataFrame(geometry=polygons)
gdf['tissue_id'] = keys
return gdf
6 changes: 2 additions & 4 deletions src/hest/io/seg_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def read_gdf(self, path) -> gpd.GeoDataFrame:

class TissueContourReader(GDFReader):

def read_gdf(self, path) -> gpd.GeoDataFrame:

gdf = _read_geojson(path, class_name='tissue_id', index_key='hole')

def read_gdf(self, path) -> gpd.GeoDataFrame:
gdf = _read_geojson(path, 'tissue_id', extra_props=False, index_key='hole')
return gdf


Expand Down
12 changes: 2 additions & 10 deletions src/hest/segmentation/SegDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,18 @@ def __getitem__(self, index):


class SegWSIDataset(Dataset):
masks = []
patches = []
coords = []

def __init__(self, patcher: WSIPatcher, transform):
self.patcher = patcher

self.cols, self.rows = self.patcher.get_cols_rows()
self.size = self.cols * self.rows

self.transform = transform


def __len__(self):
return self.size
return len(self.patcher)

def __getitem__(self, index):
col = index % self.cols
row = index // self.cols
tile, x, y = self.patcher.get_tile(col, row)
tile, x, y = self.patcher[index]

if self.transform:
tile = self.transform(tile)
Expand Down
Loading
Loading