add new samples and aligned transcripts (#53)
* add xenium seg

* minor changes

* minor changes

* add iter_hest

* update readme

* update tutorial

* update metadata

* move barcode_coords to asset

* update tutorials

* update README

* fix problem

* put gene_db.parquet on HF and correct hardcoded path
pauldoucet authored Sep 24, 2024
1 parent bdf659e commit 56b2426
Showing 14 changed files with 1,470 additions and 1,191 deletions.
15 changes: 7 additions & 8 deletions README.md
@@ -21,15 +21,17 @@ HEST-1k, HEST-Library, and HEST-Benchmark are released under the Attribution-Non

## Updates

- **23.09.24**: 121 new samples released, including 27 Xenium and 7 Visium HD! We also release the aligned Xenium transcripts and the aligned DAPI-segmented cells/nuclei.

- **30.08.24**: HEST-Benchmark results updated. Includes H-Optimus-0, Virchow 2, Virchow, and GigaPath. New COAD task based on 4 Xenium samples. HuggingFace bench data have been updated.

- **28.08.24**: New set of helpers for batch-effect visualization and correction. Tutorial [here](https://github.com/mahmoodlab/HEST/blob/main/tutorials/5-Batch-effect-visualization.ipynb).

## Download/Query HEST-1k (>1TB)

To download/query HEST-1k, follow the tutorial [1-Downloading-HEST-1k.ipynb](https://github.com/mahmoodlab/HEST/blob/main/tutorials/1-Downloading-HEST-1k.ipynb) or follow instructions on [Hugging Face](https://huggingface.co/datasets/MahmoodLab/hest).

**NOTE:** The entire dataset weighs more than 1 TB, but you can easily download a subset by querying per id, organ, species, etc.
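
For example, a minimal sketch that pulls only two samples with `huggingface_hub` (the per-sample filename patterns are an assumption; see the tutorial for the supported query options):

```python
from huggingface_hub import snapshot_download

# Download only the files belonging to two samples instead of the full dataset
snapshot_download(
    repo_id="MahmoodLab/hest",
    repo_type="dataset",
    local_dir="hest_data",
    allow_patterns=["*TENX95*", "*TENX99*"],
)
```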


## HEST-Library installation
@@ -63,13 +65,10 @@ pip install \
You can then simply view the dataset as follows:

```python
from hest import iter_hest

for st in iter_hest('../hest_data', id_list=['TENX95']):
    print(st)
```
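
Extra keyword arguments to `iter_hest` are forwarded to the underlying reader. For instance, a sketch that also loads the aligned Xenium transcripts (assuming the sample's transcripts were downloaded and are exposed as `transcript_df`):

```python
from hest import iter_hest

# load_transcripts=True attaches the aligned transcript dataframe to Xenium samples
for st in iter_hest('../hest_data', id_list=['TENX95'], load_transcripts=True):
    print(st.transcript_df[['he_x', 'he_y']].head())
```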

## HEST-Library API
1,230 changes: 1,230 additions & 0 deletions assets/HEST_v1_1_0.csv

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1,109 changes: 0 additions & 1,109 deletions metadata/HEST_v1_0_0.csv

This file was deleted.

175 changes: 135 additions & 40 deletions src/hest/HESTData.py
@@ -4,7 +4,7 @@
import os
import shutil
import warnings
from typing import Dict, Iterator, List, Union

import cv2
import geopandas as gpd
@@ -23,7 +23,7 @@
print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/")
import pandas as pd
from hestcore.segmentation import (apply_otsu_thresholding, mask_to_gdf,
                                   save_pkl, segment_tissue_deep, get_path_relative)
from PIL import Image
from shapely import Point
from tqdm import tqdm
@@ -474,7 +474,7 @@ def to_spatial_data(self, lazy_img=True) -> SpatialData: # type: ignore
        from dask import delayed
        from dask.array import from_delayed
        from spatialdata import SpatialData
        from spatialdata.models import Image2DModel, ShapesModel, TableModel

        def read_hest_wsi(wsi: WSI):
            return wsi.numpy()
@@ -504,7 +504,11 @@ def read_hest_wsi(wsi: WSI):
            shape_names.append('tissue_contours')

        my_images = {"he": parsed_image}

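        # Key each AnnData obs row to the 'he' image region so the table
        # validates as a SpatialData annotation table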
        self.adata.obs['instance'] = self.adata.obs.index
        self.adata.obs['region'] = 'he'
        parsed_adata = TableModel.parse(self.adata, instance_key='instance', region_key='region', region='he')
        my_tables = {"anndata": parsed_adata}

        st = SpatialData(images=my_images, tables=my_tables, shapes=dict(zip(shape_names, shape_validated)))

@@ -651,7 +655,10 @@ def read_HESTData(
    mask_path_pkl: str = None,  # Deprecated
    mask_path_jpg: str = None,  # Deprecated
    cellvit_path: str = None,
    tissue_contours_path: str = None,
    xenium_cell_path: str = None,
    xenium_nucleus_path: str = None,
    transcripts_path: str = None
) -> HESTData:
""" Read a HEST sample from disk
@@ -663,8 +670,11 @@
        metrics_path (str): metadata dictionary containing information such as the pixel size, or QC metrics attached to that sample
        mask_path_pkl (str): *Deprecated* path to a .pkl file containing the tissue segmentation contours. Defaults to None.
        mask_path_jpg (str): *Deprecated* path to a .jpg file containing the greyscale tissue segmentation mask. Defaults to None.
        cellvit_path (str): path to a cell segmentation file in .geojson or .parquet. Defaults to None.
        tissue_contours_path (str): path to a .geojson tissue contours file. Defaults to None.
        xenium_cell_path (str): path to a .parquet Xenium cell segmentation file. Defaults to None.
        xenium_nucleus_path (str): path to a .parquet Xenium nucleus segmentation file. Defaults to None.
        transcripts_path (str): path to a .parquet transcript dataframe. Defaults to None.
Returns:
Expand Down Expand Up @@ -705,11 +715,40 @@ def read_HESTData(
    shapes = []
    if cellvit_path is not None:
        shapes.append(LazyShapes(cellvit_path, 'cellvit', 'he'))
    if xenium_cell_path is not None:
        shapes.append(LazyShapes(xenium_cell_path, 'xenium_cell', 'he'))
    if xenium_nucleus_path is not None:
        shapes.append(LazyShapes(xenium_nucleus_path, 'xenium_nucleus', 'he'))

    transcripts = None
    if transcripts_path is not None:
        transcripts = pd.read_parquet(transcripts_path)

    adata = sc.read_h5ad(adata_path)
    with open(metrics_path) as metrics_f:
        metrics = json.load(metrics_f)
    if transcripts is not None:
        return XeniumHESTData(
            adata,
            img,
            metrics['pixel_size_um_estimated'],
            metrics,
            tissue_seg=tissue_seg,
            shapes=shapes,
            tissue_contours=tissue_contours,
            transcript_df=transcripts
        )
    else:
        return HESTData(
            adata,
            img,
            metrics['pixel_size_um_estimated'],
            metrics,
            tissue_seg=tissue_seg,
            shapes=shapes,
            tissue_contours=tissue_contours
        )


def mask_and_patchify_bench(meta_df: pd.DataFrame, save_dir: str, use_mask=True, keep_largest=None):
@@ -796,6 +835,85 @@ def create_splits(dest_dir, splits, K):
        test_df.to_csv(os.path.join(dest_dir, f'test_{i}.csv'), index=False)


class HESTIterator:
    def __init__(self, hest_dir, id_list, **read_kwargs):
        if id_list is not None and not isinstance(id_list, (list, np.ndarray)):
            raise ValueError('id_list must be a list or a numpy array')
        self.id_list = id_list
        self.hest_dir = hest_dir
        self.i = 0
        self.read_kwargs = read_kwargs

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self) -> HESTData:
        if self.i < len(self):
            x = _read_st(self.hest_dir, self.id_list[self.i], **self.read_kwargs)
            self.i += 1
            return x
        else:
            raise StopIteration

    def __len__(self):
        return len(self.id_list)

def iter_hest(hest_dir: str, id_list: List[str] = None, **read_kwargs) -> HESTIterator:
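    """Lazily iterate over HEST samples read from `hest_dir`.

    Args:
        hest_dir (str): directory containing the downloaded HEST data.
        id_list (List[str], optional): ids of the samples to iterate over. Defaults to None.
        **read_kwargs: extra arguments forwarded to the reader (e.g. `load_transcripts=True`).

    Returns:
        HESTIterator: iterator yielding one HESTData object at a time.
    """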
    return HESTIterator(hest_dir, id_list, **read_kwargs)

def _read_st(hest_dir, st_filename, load_transcripts=False):
    id = st_filename.split('.')[0]
    adata_path = os.path.join(hest_dir, 'st', f'{id}.h5ad')
    img_path = os.path.join(hest_dir, 'wsis', f'{id}.tif')
    meta_path = os.path.join(hest_dir, 'metadata', f'{id}.json')

    masks_path_pkl = None
    masks_path_jpg = None
    tissue_contours_path = None
    verify_paths([adata_path, img_path, meta_path], suffix='\nHave you downloaded the dataset? (https://huggingface.co/datasets/MahmoodLab/hest)')

    if os.path.exists(os.path.join(hest_dir, 'tissue_seg')):
        masks_path_pkl = find_first_file_endswith(os.path.join(hest_dir, 'tissue_seg'), f'{id}_mask.pkl')
        masks_path_jpg = find_first_file_endswith(os.path.join(hest_dir, 'tissue_seg'), f'{id}_mask.jpg')
        tissue_contours_path = find_first_file_endswith(os.path.join(hest_dir, 'tissue_seg'), f'{id}_contours.geojson')

    cellvit_path = None
    if os.path.exists(os.path.join(hest_dir, 'cellvit_seg')):
        cellvit_path = find_first_file_endswith(os.path.join(hest_dir, 'cellvit_seg'), f'{id}_cellvit_seg.parquet')
        if cellvit_path is None:
            cellvit_path = find_first_file_endswith(os.path.join(hest_dir, 'cellvit_seg'), f'{id}_cellvit_seg.geojson')
            if cellvit_path is not None:
                warnings.warn('reading the cell segmentation as .geojson can be slow, download the .parquet cells for faster loading https://huggingface.co/datasets/MahmoodLab/hest')

    xenium_cell_path = None
    xenium_nucleus_path = None
    if os.path.exists(os.path.join(hest_dir, 'xenium_seg')):
        xenium_cell_path = find_first_file_endswith(os.path.join(hest_dir, 'xenium_seg'), f'{id}_xenium_cell_seg.parquet')
        xenium_nucleus_path = find_first_file_endswith(os.path.join(hest_dir, 'xenium_seg'), f'{id}_xenium_nucleus_seg.parquet')

    transcripts_path = None
    if load_transcripts:
        transcripts_path = find_first_file_endswith(os.path.join(hest_dir, 'transcripts'), f'{id}_transcripts.parquet')

    st = read_HESTData(
        adata_path,
        img_path,
        meta_path,
        masks_path_pkl,
        masks_path_jpg,
        cellvit_path=cellvit_path,
        tissue_contours_path=tissue_contours_path,
        xenium_cell_path=xenium_cell_path,
        xenium_nucleus_path=xenium_nucleus_path,
        transcripts_path=transcripts_path
    )
    return st



def load_hest(hest_dir: str, id_list: List[str] = None) -> List[HESTData]:
"""Read HEST-1k samples from a local directory
@@ -821,38 +939,7 @@ def load_hest(hest_dir: str, id_list: List[str] = None) -> List[HESTData]:
    st_filenames = os.listdir(os.path.join(hest_dir, 'st'))

    for st_filename in tqdm(st_filenames):
        st = _read_st(hest_dir, st_filename)
        hestdata_list.append(st)

    warnings.resetwarnings()
@@ -878,6 +965,14 @@ def get_gene_db(species, cache_dir='.genes') -> pd.DataFrame:


def _get_alias_to_parent_df():
    path_folder_assets = get_path_relative(__file__, '../../assets')
    path_gene_db = os.path.join(path_folder_assets, 'human_gene_db.parquet')

    if not os.path.exists(path_gene_db):
        from huggingface_hub import snapshot_download
        snapshot_download(repo_id="MahmoodLab/hest", repo_type='dataset', local_dir=path_folder_assets, allow_patterns=['human_gene_db.parquet'])

    df = pd.read_parquet(path_gene_db)
    df = df[['symbol', 'ensembl_gene_id', 'alias_symbol']].explode('alias_symbol')
    none_mask = df['alias_symbol'].isna()
2 changes: 1 addition & 1 deletion src/hest/__init__.py
@@ -3,7 +3,7 @@
from .utils import tiff_save, find_pixel_size_from_spot_coords, write_10X_h5, get_k_genes, SpotPacking
from .autoalign import autoalign_visium
from .readers import *
from .HESTData import HESTData, read_HESTData, load_hest, iter_hest
from .segmentation.cell_segmenters import segment_cellvit

__all__ = [
2 changes: 1 addition & 1 deletion src/hest/custom_readers.py
@@ -305,7 +305,7 @@ def align_her2(path, raw_count_path):


def infer_row_col_from_barcodes(barcodes_df, adata):
    barcode_path = './assets/barcode_coords/visium-v1_coordinates.txt'
    barcode_coords = pd.read_csv(barcode_path, sep='\t', header=None)
    barcode_coords = barcode_coords.rename(columns={
        0: 'barcode',
2 changes: 1 addition & 1 deletion src/hest/readers.py
@@ -586,7 +586,7 @@ def _find_alignment_barcodes(self, alignment_df: str, barcode_path: str) -> pd.D

    def _find_visium_slide_version(self, alignment_df: str, adata: sc.AnnData) -> str:  # type: ignore
        highest_nb_match = -1
        barcode_dir = get_path_relative(__file__, '../../assets/barcode_coords/')
        for barcode_path in os.listdir(barcode_dir):
            spatial_aligned = self._find_alignment_barcodes(alignment_df, os.path.join(barcode_dir, barcode_path))
            nb_match = len(pd.merge(spatial_aligned, adata.obs, left_index=True, right_index=True))
26 changes: 14 additions & 12 deletions tutorials/1-Downloading-HEST-1k.ipynb
@@ -94,7 +94,7 @@
"source": [
"import datasets\n",
"\n",
"local_dir='hest_data' # hest will be dowloaded to this folder\n",
"local_dir='../hest_data' # hest will be dowloaded to this folder\n",
"\n",
"# Note that the full dataset is around 1TB of data\n",
"dataset = datasets.load_dataset(\n",
@@ -119,7 +119,7 @@
"source": [
"import datasets\n",
"\n",
"local_dir='hest_data' # hest will be dowloaded to this folder\n",
"local_dir='../hest_data' # hest will be dowloaded to this folder\n",
"\n",
"ids_to_query = ['TENX95', 'TENX99'] # list of ids to query\n",
"\n",
@@ -147,7 +147,7 @@
"import datasets\n",
"import pandas as pd\n",
"\n",
"local_dir='hest_data' # hest will be dowloaded to this folder\n",
"local_dir='../hest_data' # hest will be dowloaded to this folder\n",
"\n",
"meta_df = pd.read_csv(\"hf://datasets/MahmoodLab/hest/HEST_v1_0_2.csv\")\n",
"\n",
@@ -182,10 +182,12 @@
" - `{id}_mask.jpg`: Downscaled or full resolution greyscale tissue mask\n",
" - `{id}_mask.pkl`: Tissue/holes contours in a pickle file\n",
" - `{id}_vis.jpg`: Visualization of the tissue mask on the downscaled WSI\n",
"- **cellvit_seg/**: Cellvit nuclei segmentation\n",
"- **pixel_size_vis/**: Visualization of the pixel size\n",
"- **patches/**: 256x256 H&E patches (0.5µm/px) extracted around ST spots in a .h5 object optimized for deep-learning. Each patch is matched to the corresponding ST profile (see **st/**) with a barcode.\n",
"- **patches_vis/**: Visualization of the mask and patches on a downscaled WSI.\n"
"- **patches_vis/**: Visualization of the mask and patches on a downscaled WSI.\n",
"- **transcripts/**: individual transcripts aligned to H&E for xenium samples; read with pandas.read_parquet; aligned coordinates in pixel are in columns `['he_x', 'he_y']`\n",
"- **cellvit_seg/**: Cellvit nuclei segmentation\n",
"- **xenium_seg**: xenium segmentation on DAPI and aligned to H&E\n"
]
},
{
@@ -194,18 +196,18 @@
"metadata": {},
"outputs": [],
"source": [
"from hest import load_hest\n",
"from hest import iter_hest\n",
"import pandas as pd\n",
"\n",
"# Ex: inspect all the Invasive Lobular Carcinoma samples (ILC)\n",
"meta_df = pd.read_csv('../metadata/HEST_v1_0_2.csv')\n",
"meta_df = pd.read_csv('../assets/HEST_v1_1_0.csv')\n",
"\n",
"id_list = meta_df[meta_df['oncotree_code'] == 'ILC']['id'].values\n",
"\n",
"print('load hest...')\n",
"hest_data = load_hest('../hest_data', id_list=id_list) # location of the data\n",
"for d in hest_data:\n",
" print(d)"
"# Iterate through a subset of hest\n",
"for st in iter_hest('../hest_data', id_list=id_list):\n",
" print(st)"
]
},
{
@@ -218,9 +220,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "hest",
"display_name": "cuml",
"language": "python",
"name": "hest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
