Parallel histogram_standardization.train #970

Open
hsyang1222 opened this issue Sep 22, 2022 · 1 comment

@hsyang1222
Contributor

🚀 Feature
A parallel version of torchio.transforms.preprocessing.intensity.histogram_standardization.train.

Motivation
Currently, this method runs in a single thread. Deep learning models are usually trained with torchio on servers with many CPU cores, so the landmark training could be parallelized across those cores to finish faster.

Code
I wrote the code below using multiprocessing.Pool; I think it would be a useful addition. A usage sketch follows the code.

import torchio
import tqdm
import numpy as np
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from torchio.typing import TypePath

from torchio.transforms.preprocessing.intensity import histogram_standardization

import multiprocessing

DEFAULT_CUTOFF = 0.01, 0.99
STANDARD_RANGE = 0, 100
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]

# Parallel counterpart of histogram_standardization.train: the per-image
# percentile computation is distributed over a multiprocessing pool.
def train(
        images_paths: Sequence[TypePath],
        cutoff: Optional[Tuple[float, float]] = None,
        mask_path: Optional[Union[Sequence[TypePath], TypePath]] = None,
        masking_function: Optional[Callable] = None,
        output_path: Optional[TypePath] = None,
        num_workers: int = 32
) -> np.ndarray:
    is_masks_list = isinstance(mask_path, Sequence)
    if is_masks_list and len(mask_path) != len(images_paths):  # type: ignore[arg-type]  # noqa: E501
        message = (
            f'Different number of images ({len(images_paths)})'  # type: ignore[arg-type]  # noqa: E501
            f' and mask ({len(mask_path)}) paths found'  # type: ignore[arg-type]  # noqa: E501
        )
        raise ValueError(message)
    quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles_database = []
    a, b = percentiles_cutoff  # for mypy
    percentiles = histogram_standardization._get_percentiles((a, b))

    mask_path_list = [None] * len(images_paths)
    masking_function_list = [None] * len(images_paths)

    if masking_function is not None:
        masking_function_list = [masking_function] * len(images_paths)
    else:
        if is_masks_list:
            mask_path_list = mask_path
        else:
            mask_path_list = [mask_path] * len(images_paths)


    # At this point at least one of masking_function_list and mask_path_list contains only None entries

    percentiles_list = [percentiles] * len(images_paths)

    # Compute the per-image percentiles in parallel; the context manager closes the pool
    with multiprocessing.Pool(num_workers) as pool:
        with tqdm.tqdm(total=len(images_paths), desc='make histogram') as pbar:
            args_zipped = zip(images_paths, masking_function_list, mask_path_list, percentiles_list)
            for percentile_values in pool.imap_unordered(img_to_percentiles_value, args_zipped):
                percentiles_database.append(percentile_values)
                pbar.update()

    percentiles_database_array = np.vstack(percentiles_database)
    mapping = histogram_standardization._get_average_mapping(percentiles_database_array)

    if output_path is not None:
        output_path = Path(output_path).expanduser()
        extension = output_path.suffix
        if extension == '.txt':
            modality = 'image'
            text = f'{modality} {" ".join(map(str, mapping))}'
            output_path.write_text(text)
        elif extension == '.npy':
            np.save(output_path, mapping)
    return mapping

def img_to_percentiles_value(args):
    # Worker run in each pool process: compute the cutoff percentile values for one image
    image_file_path, masking_function, mask_path, percentiles = args
    tensor, _ = histogram_standardization.read_image(image_file_path)

    if masking_function is not None:
        mask = masking_function(tensor)
    else:
        if mask_path is None:
            mask = np.ones_like(tensor, dtype=bool)
        else:
            path = mask_path  # type: ignore[assignment]
            mask, _ = histogram_standardization.read_image(path)
            mask = mask.numpy() > 0
    array = tensor.numpy()
    percentile_values = np.percentile(array[mask], percentiles)
    return percentile_values
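
For reference, here is a minimal usage sketch of the function above. The image file names, the output path, the image key 't1', and num_workers=8 are all hypothetical; adjust them to your data. Because the function uses multiprocessing, the call should sit under an if __name__ == '__main__': guard on platforms that spawn worker processes.

# Hypothetical usage of the parallel train() above
if __name__ == '__main__':
    image_paths = ['sub-01_T1w.nii.gz', 'sub-02_T1w.nii.gz']  # hypothetical files
    landmarks = train(
        image_paths,
        output_path='landmarks.npy',  # hypothetical output file
        num_workers=8,
    )
    # The trained landmarks can then be passed to the standard transform
    transform = torchio.transforms.HistogramStandardization({'t1': landmarks})
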
@hsyang1222 added the enhancement (New feature or request) label on Sep 22, 2022
@fepegar
Owner

fepegar commented Oct 9, 2022

Hi, @hsyang1222. If you have tried this successfully, feel free to open a pull request with your changes.
