Merge pull request #44 from rajewsky-lab/fix_imaging_preprocessing
Fix imaging preprocessing
danilexn authored Apr 19, 2024
2 parents 6b377e9 + 69eefc9 commit ea9a69e
Showing 20 changed files with 386 additions and 189 deletions.
4 changes: 2 additions & 2 deletions docs/computational/generate_expression_matrix.md
@@ -13,7 +13,7 @@ To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will

We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
We provide a model that we fine-tuned for segmentation of fresh-frozen, H&E-stained tissue,
[here](https://github.com/danilexn/openst/blob/main/models/HE_cellpose_rajewsky).
[here](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
You can specify any other model that works best for your data -
refer to the [cellpose](https://cellpose.readthedocs.io/en/latest/index.html) documentation.

@@ -31,7 +31,7 @@ By default, segmentation is extended radially by 10 pixels. This can be changed with
Make sure to populate the arguments with the values specific to your dataset. Here, we provide `--h5-in` consistent
with the previous steps, `--image-in` and `--mask-out` will read and write the staining and mask inside the Open-ST h5 object,
and `--model` is `HE_cellpose_rajewsky`, the default used in our manuscript. This is the model we recommend for H&E images, and
weights are automatically downloaded. It is also [provided in our repo](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
weights are automatically downloaded. It is also [provided in our repo](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
The rest of the parameters can be checked with `openst segment --help`.
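
For instance, a minimal invocation could look like the sketch below. The flag names match the description above; the input file name and the `--mask-out` key are illustrative placeholders, so adapt them to your dataset:

```bash
openst segment \
    --h5-in multimodal/spots_stitched.h5ad \
    --image-in uns/spatial/staining_image_restored \
    --mask-out uns/spatial/segmentation \
    --model HE_cellpose_rajewsky  # weights are downloaded automatically
```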

!!! tip
33 changes: 25 additions & 8 deletions docs/computational/preprocessing_imaging.md
@@ -63,16 +63,33 @@ tile-scan. You can do this by running the following command on the stitched image:

```bash
openst image_preprocess \
--input=<path_to_input_image> \
--CUT \
--CUT-model=<path_to_model> \
--output=<path_to_output>
--image-in Image_Stitched_Composite.tif \
--image-out Image_Stitched_Composite_Restored.tif
# --device cuda # in case you have a CUDA-compatible GPU
```

Make sure to replace the placeholders (`<...>`). For instance,
`<path_to_input_image>` is the full path and file name of the previously stitched image; `<path_to_model>`
is the filename of our pre-trained [CUT model](https://github.com/rajewsky-lab/openst/models/CUT.pth), and `<path_to_output>`
is the path to a (writeable) folder plus the desired filename for the output image.
If you ran `openst merge_modalities`, the imaging data is contained inside the Open-ST h5 object, and the command
can be adapted:

```bash
openst image_preprocess \
--h5-in multimodal/spots_stitched.h5ad # just a placeholder, adapt
# --device cuda # in case you have a CUDA-compatible GPU
```

By default, the image will be loaded from the key `uns/spatial/staining_image`, and the CUT-restored image will be saved
to `uns/spatial/staining_image_restored`. You can preview the image restoration results using:

```bash
openst preview \
    --h5-in multimodal/spots_stitched.h5ad \
--image-key uns/spatial/staining_image uns/spatial/staining_image_restored
```

This will load the two images and visualize them using `napari`. Later, you can run segmentation and pairwise alignment
using either the default merged image (`uns/spatial/staining_image`) or the restored image (`uns/spatial/staining_image_restored`).
Always assess these preprocessing choices (quantitatively and qualitatively) to decide whether they make sense for your data.


## Expected output
After running the stitching (and, optionally, the correction algorithms), you will have a single image file per sample. This, together with
2 changes: 1 addition & 1 deletion docs/examples/adult_mouse/generate_expression_matrix.md
@@ -12,7 +12,7 @@ into a cell-by-gene matrix, where cells are defined from the segmentation mask.
To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will first need a segmentation mask.

We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
For the H&E-stained tissue provided in this example, we used our [fine-tuned model](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
For the H&E-stained tissue provided in this example, we used our [fine-tuned model](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
Make sure to download it and save it into a new `models` folder that you need to create under the `openst_adult_demo` main folder.

You can run the segmentation on the previously created `openst_demo_adult_mouse_spatial_beads_puck_collection_aligned.h5ad` file, which
2 changes: 1 addition & 1 deletion docs/examples/e13_mouse/generate_expression_matrix.md
@@ -12,7 +12,7 @@ into a cell-by-gene matrix, where cells are defined from the segmentation mask.
To create such a spatial cell-by-gene ($M\times G$) expression matrix, you will first need a segmentation mask.

We efficiently segment cells (or nuclei) from staining images using [cellpose](https://github.com/MouseLand/cellpose).
For the H&E-stained tissue provided in this example, we used our [fine-tuned model](https://github.com/rajewsky-lab/openst/blob/main/models/HE_cellpose_rajewsky).
For the H&E-stained tissue provided in this example, we used our [fine-tuned model](http://bimsbstatic.mdc-berlin.de/rajewsky/openst-public-data/models/HE_cellpose_rajewsky).
Make sure to download it and save it into a new `models` folder that you need to create under the `openst_e13_demo` main folder.

You can run the segmentation on the previously created `openst_demo_e13_mouse_head_spatial_beads_puck_collection_aligned.h5ad` file, which
2 changes: 1 addition & 1 deletion openst/__init__.py
@@ -1 +1 @@
__version__ = '0.0.2'
__version__ = '0.0.11'
39 changes: 34 additions & 5 deletions openst/cli.py
@@ -680,17 +680,46 @@ def get_image_preprocess_parser():
allow_abbrev=False,
add_help=False,
)
parser.add_argument("--input_img", type=str, required=True, help="path to input image")
parser.add_argument("--cut_dir", type=str, required=True, help="path to CUT directory (to save patched images)")
parser.add_argument("--tile_size_px", type=int, required=True, help="size of the tile in pixels")
parser.add_argument(
"--h5-in",
type=str,
default="",
help="""If set, image is loaded from the Open-ST h5 object (key in --image-in),
and retored image is saved there (to the key --image-out)""",
)
parser.add_argument(
"--image-in",
type=str,
default="uns/spatial/staining_image",
help="Key or path to the input image",
)
parser.add_argument(
"--image-out",
type=str,
default="uns/spatial/staining_image_restored",
help="Key or path where the restored image will be written into",
)
parser.add_argument(
"--tile-size-px",
type=int,
default=512,
help="The input image is split into squared tiles of side `--tile-size-px`, for inference."+
"Larger values avoid boundary effects, but require more memory.",
)
parser.add_argument("--model", type=str, default="HE_CUT_rajewsky", help="CUT model used for image restoration")
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="Device used to run feature matching model. Can be ['cpu', 'cuda']",
help="Device used to run CUT restoration model. Can be ['cpu', 'cuda']",
)
parser.add_argument(
"--num-workers",
type=int,
default=-1,
help="Number of CPU workers for parallel processing",
)
parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")

return parser
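
Assuming this parser is registered under the `openst image_preprocess` subcommand (as the docs above suggest), a full invocation spelling out every flag could look like this sketch; the file name and the `--device`/`--num-workers` values are illustrative, the remaining values are the parser defaults:

```bash
openst image_preprocess \
    --h5-in multimodal/spots_stitched.h5ad \
    --image-in uns/spatial/staining_image \
    --image-out uns/spatial/staining_image_restored \
    --tile-size-px 512 \
    --model HE_CUT_rajewsky \
    --device cuda \
    --num-workers 4
```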

7 changes: 4 additions & 3 deletions openst/preprocessing/CUT/models/__init__.py
@@ -19,8 +19,9 @@
"""

import importlib
from models.base_model import BaseModel
import logging

from openst.preprocessing.CUT.models.base_model import BaseModel

def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
@@ -29,7 +30,7 @@ def find_model_using_name(model_name):
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
model_filename = "openst.preprocessing.CUT.models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
@@ -63,5 +64,5 @@ def create_model(opt):
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
logging.info("Model architecture `%s` was created" % type(instance).__name__)
return instance
15 changes: 9 additions & 6 deletions openst/preprocessing/CUT/models/base_model.py
@@ -1,8 +1,10 @@
import os
import torch
import logging
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks

from openst.preprocessing.CUT.models import networks


class BaseModel(ABC):
@@ -211,7 +213,7 @@ def load_networks(self, epoch):
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
logging.info(f'Loading model weights from {load_path}')
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
@@ -229,17 +231,18 @@ def print_networks(self, verbose):
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
message = '---------- (Start) Networks initialized -------------\n'
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
message += f"{net}\n"
message += '[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6)
message += '---------- (End) Networks initialized -------------'
logging.debug(message)

def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
19 changes: 8 additions & 11 deletions openst/preprocessing/CUT/models/cut_model.py
@@ -1,10 +1,10 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util

from openst.preprocessing.CUT.models.base_model import BaseModel
from openst.preprocessing.CUT.models import networks
from openst.preprocessing.CUT.models.patchnce import PatchNCELoss
from openst.preprocessing.CUT.util import util

class CUTModel(BaseModel):
""" This class implements CUT and FastCUT model, described in the paper
@@ -98,10 +98,10 @@ def data_dependent_initialize(self, data):
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
bs_per_gpu = self.opt.batch_size
self.set_input(data)
self.real_A = self.real_A[:bs_per_gpu]
self.real_B = self.real_B[:bs_per_gpu]
self.real_B = None
self.forward() # compute fake images: G(A)
if self.opt.isTrain:
self.compute_D_loss().backward() # calculate gradients for D
@@ -138,14 +138,11 @@ def set_input(self, input):
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
self.real_A = input['A'].to(self.device)

def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
self.real = self.real_A
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
14 changes: 0 additions & 14 deletions openst/preprocessing/CUT/models/networks.py
@@ -5,12 +5,6 @@
import functools
from torch.optim import lr_scheduler
import numpy as np
from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator

###############################################################################
# Helper Functions
###############################################################################


def get_filter(filt_size=3):
if(filt_size == 1):
@@ -256,10 +250,6 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256':
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'stylegan2':
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
elif netG == 'smallstylegan2':
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
elif netG == 'resnet_cat':
n_blocks = 8
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
@@ -323,8 +313,6 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
elif netD == 'pixel': # classify if each pixel is real or fake
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
elif 'stylegan2' in netD:
net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids,
@@ -1214,8 +1202,6 @@ def forward(self, input):

class UnetSkipConnectionBlock(nn.Module):
"""Defines the Unet submodule with skip connection.
X -------------------identity----------------------
|-- downsampling -- |submodule| -- upsampling --|
"""

def __init__(self, outer_nc, inner_nc, input_nc=None,
54 changes: 54 additions & 0 deletions openst/preprocessing/CUT/models/patchnce.py
@@ -0,0 +1,54 @@
from packaging import version
import torch
from torch import nn

class PatchNCELoss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

def forward(self, feat_q, feat_k):
num_patches = feat_q.shape[0]
dim = feat_q.shape[1]
feat_k = feat_k.detach()

# pos logit
l_pos = torch.bmm(
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
l_pos = l_pos.view(num_patches, 1)

# neg logit

# Should the negatives from the other samples of a minibatch be utilized?
# In CUT and FastCUT, we found that it's best to only include negatives
# from the same image. Therefore, we set
# --nce_includes_all_negatives_from_minibatch as False
# However, for single-image translation, the minibatch consists of
# crops from the "same" high-resolution image.
# Therefore, we will include the negatives from the entire minibatch.
if self.opt.nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = self.opt.batch_size

# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
npatches = feat_q.size(1)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # diagonal entries are the similarity between identical features, and hence meaningless;
        # fill the diagonal with a very small number, -10.0, whose exp is almost zero
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
l_neg_curbatch.masked_fill_(diagonal, -10.0)
l_neg = l_neg_curbatch.view(-1, npatches)

out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T

loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
device=feat_q.device))

return loss
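
As a quick sanity check, `PatchNCELoss` can be exercised standalone. The sketch below is illustrative: the `opt` fields are exactly the ones the class reads (`batch_size`, `nce_T`, `nce_includes_all_negatives_from_minibatch`), but the values and feature shapes are made up:

```python
from types import SimpleNamespace

import torch

# Hypothetical options object carrying only the attributes PatchNCELoss reads
opt = SimpleNamespace(
    batch_size=1,
    nce_T=0.07,  # temperature applied to the (num_patches, 1 + num_negatives) logits
    nce_includes_all_negatives_from_minibatch=False,
)
criterion = PatchNCELoss(opt)

feat_q = torch.randn(256, 64)  # 256 patches, 64-dim query features
feat_k = torch.randn(256, 64)  # corresponding key features
loss = criterion(feat_q, feat_k).mean()  # reduction='none', so reduce manually
```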
7 changes: 4 additions & 3 deletions openst/preprocessing/CUT/models/template_model.py
@@ -9,15 +9,16 @@
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
min_<netG> ||netG(data_A) - data_B||_1
You need to implement the following functions:
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<__init__>: Initialize this model class.
<set_input>: Unpack input data and perform data pre-processing.
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
<optimize_parameters>: Update network weights; it will be called in every training iteration.
"""
import torch
from .base_model import BaseModel
from . import networks

from openst.preprocessing.CUT.models.base_model import BaseModel
from openst.preprocessing.CUT.models import networks


class TemplateModel(BaseModel):