# Source code for moosez.image_processing

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ----------------------------------------------------------------------------------------------------------------------
# Author: Lalith Kumar Shiyam Sundar
#         Sebastian Gutschmayer
# Institution: Medical University of Vienna
# Research Group: Quantitative Imaging and Medical Physics (QIMP) Team
# Date: 05.06.2023
# Version: 2.0.0
#
# Description:
# This module handles image processing for the moosez.
#
# Usage:
# The functions in this module can be imported and used in other modules within the moosez for image processing.
#
# ----------------------------------------------------------------------------------------------------------------------

import SimpleITK
import itertools
import dask.array as da
import numpy as np
import pandas as pd
import scipy.ndimage as ndimage
import nibabel
import os
from moosez.constants import CHUNK_THRESHOLD_RESAMPLING, CHUNK_THRESHOLD_INFERRING
from moosez import models
from moosez import system


def get_intensity_statistics(image: SimpleITK.Image, mask_image: SimpleITK.Image, model: models.Model, out_csv: str) -> None:
    """
    Get the intensity statistics of a NIFTI image file for every label of a multilabel mask and write them to CSV.

    One row is written per mask label that has an entry in ``model.organ_indices``; labels without such an
    entry are omitted so that the region-name column always lines up with the statistics rows.
    Columns: Regions-Present, Mean, Standard-Deviation, Median, Maximum, Minimum.

    :param image: The source image from which the intensity statistics are calculated.
    :type image: sitk.Image
    :param mask_image: The multilabel mask image.
    :type mask_image: sitk.Image
    :param model: The model; its ``organ_indices`` maps label values to region names.
    :type model: Model
    :param out_csv: The path to the output CSV file.
    :type out_csv: str
    :return: None
    """
    image_view = SimpleITK.GetArrayViewFromImage(image)
    # Subtract as Python floats so narrow integer pixel types (e.g. int16) cannot overflow.
    intensity_range = float(image_view.max()) - float(image_view.min())
    # One histogram bin per intensity unit, but never zero bins (e.g. for a constant image).
    bins = max(1, int(intensity_range))

    intensity_statistics = SimpleITK.LabelIntensityStatisticsImageFilter()
    intensity_statistics.SetNumberOfBins(bins)
    intensity_statistics.Execute(mask_image, image)

    organ_indices_dict = model.organ_indices
    labels = [label for label in intensity_statistics.GetLabels() if label in organ_indices_dict]

    stats_list = [(intensity_statistics.GetMean(label),
                   intensity_statistics.GetStandardDeviation(label),
                   intensity_statistics.GetMedian(label),
                   intensity_statistics.GetMaximum(label),
                   intensity_statistics.GetMinimum(label)) for label in labels]
    columns = ['Mean', 'Standard-Deviation', 'Median', 'Maximum', 'Minimum']
    stats_df = pd.DataFrame(data=stats_list, index=labels, columns=columns)

    stats_df.insert(0, 'Regions-Present', np.array([organ_indices_dict[label] for label in labels]))
    stats_df.to_csv(out_csv)
def get_shape_statistics(mask_image: SimpleITK.Image, model: models.Model, out_csv: str) -> None:
    """
    Get the shape statistics (physical volume) of every label of a multilabel mask and write them to CSV.

    One row is written per non-zero mask label that has an entry in ``model.organ_indices``; labels without
    such an entry are omitted so that the region-name column always lines up with the statistics rows.
    Columns: Regions-Present, Volume(mm3).

    :param mask_image: The multilabel mask image.
    :type mask_image: sitk.Image
    :param model: The model; its ``organ_indices`` maps label values to region names.
    :type model: Model
    :param out_csv: The path to the output CSV file.
    :type out_csv: str
    :return: None
    """
    label_shape_filter = SimpleITK.LabelShapeStatisticsImageFilter()
    label_shape_filter.Execute(mask_image)

    organ_indices_dict = model.organ_indices
    # Exclude the background label 0 and any label the model has no region name for.
    labels = [label for label in label_shape_filter.GetLabels()
              if label != 0 and label in organ_indices_dict]

    stats_list = [(label_shape_filter.GetPhysicalSize(label),) for label in labels]
    columns = ['Volume(mm3)']
    stats_df = pd.DataFrame(data=stats_list, index=labels, columns=columns)

    stats_df.insert(0, 'Regions-Present', np.array([organ_indices_dict[label] for label in labels]))
    stats_df.to_csv(out_csv)
[docs] def limit_fov(image_array: np.array, segmentation_array: np.array, fov_label: list[int] | int, largest_component_only: bool = False): if largest_component_only: segmentation_array = largest_connected_component(segmentation_array, fov_label) if type(fov_label) is list: z_indices = np.where((segmentation_array >= fov_label[0]) & (segmentation_array <= fov_label[1]))[0] else: z_indices = np.where(segmentation_array == fov_label)[0] z_min, z_max = np.min(z_indices), np.max(z_indices) # Crop the CT data along the z-axis limited_fov_array = image_array[z_min:z_max + 1, :, :] return limited_fov_array, {"z_min": z_min, "z_max": z_max, "original_shape": image_array.shape}
def expand_segmentation_fov(limited_fov_segmentation_array: np.ndarray, original_fov_info: dict, dtype=np.uint8) -> np.ndarray:
    """
    Place a segmentation computed on a cropped slab back into a zero-filled array of the original shape.

    This is the inverse of ``limit_fov``'s cropping: the slab is written into rows
    ``z_min .. z_max`` (inclusive) of the first axis.

    :param limited_fov_segmentation_array: Segmentation of the cropped slab.
    :param original_fov_info: Dict with ``z_min``, ``z_max`` and ``original_shape``.
    :param dtype: Output dtype. Defaults to ``np.uint8``. Label values that do not fit into it wrap
                  around silently (numpy unsafe assignment), so pass a wider dtype for more than 255 labels.
    :return: Array of shape ``original_shape`` and type ``dtype``, zero outside the slab.
    """
    z_min = original_fov_info["z_min"]
    z_max = original_fov_info["z_max"]
    original_shape = original_fov_info["original_shape"]

    # Initialize an array of zeros with the shape of the original image.
    filled_segmentation_array = np.zeros(original_shape, dtype)
    # Place the cropped segmentation back into its original position.
    filled_segmentation_array[z_min:z_max + 1, :, :] = limited_fov_segmentation_array
    return filled_segmentation_array
def largest_connected_component(segmentation_array, intensities):
    """
    Extracts the largest connected component for one or more specific intensities from a multilabel
    segmentation array.

    Connectivity is that of ``scipy.ndimage.label``'s default structuring element.

    Parameters:
    - segmentation_array: 3D or 2D numpy array with multiple labels.
    - intensities: A single intensity, or a list / tuple / numpy array of intensities, for which the largest
      component(s) should be extracted.

    Returns:
    - largest_components_multilabel: An array with the same shape and dtype as `segmentation_array` in which
      the largest connected component of each requested intensity keeps its original value and all other
      voxels are 0. Intensities that do not occur in `segmentation_array` contribute nothing.
    """
    # Ensure intensities is iterable (even if only one intensity is provided).
    if not isinstance(intensities, (list, tuple, np.ndarray)):
        intensities = [intensities]

    largest_components_multilabel = np.zeros_like(segmentation_array)

    for intensity in intensities:
        labeled_array, num_features = ndimage.label(segmentation_array == intensity)
        if num_features == 0:
            # Intensity absent: without this guard argmax() below would select component 0 (the
            # background) and paint every other voxel with this intensity.
            continue

        component_sizes = np.bincount(labeled_array.ravel())
        component_sizes[0] = 0  # ignore the background component
        largest_component_label = component_sizes.argmax()

        largest_components_multilabel[labeled_array == largest_component_label] = intensity

    return largest_components_multilabel
[docs] class ImageChunker: @staticmethod def __compute_interior_indices(axis_length: int, number_of_chunks: int) -> (list[int], list[int]): start = [int(round(k * axis_length / number_of_chunks)) for k in range(number_of_chunks)] end = [int(round((k + 1) * axis_length / number_of_chunks)) for k in range(number_of_chunks)] return start, end @staticmethod def __chunk_array_with_overlap(array_shape: list[int] | tuple[int, ...], splits_per_dimension: list[int] | tuple[int, ...], overlap_per_dimension: list[int] | tuple[int, ...]) -> list[dict]: dims = array_shape num_dims = len(array_shape) starts_list = [] ends_list = [] for dimension_index in range(num_dims): axis_length = dims[dimension_index] number_of_chunks = splits_per_dimension[dimension_index] start_index, end_index = ImageChunker.__compute_interior_indices(axis_length, number_of_chunks) starts_list.append(start_index) ends_list.append(end_index) chunk_info = [] for idx in itertools.product(*(range(len(s)) for s in starts_list)): chunk_slice = [] interior_slice = [] dest_slice = [] for dimension_index, chunk_index in enumerate(idx): start_index = starts_list[dimension_index] end_index = ends_list[dimension_index] axis_length = dims[dimension_index] number_of_chunks = splits_per_dimension[dimension_index] overlap = overlap_per_dimension[dimension_index] start = max(0, start_index[chunk_index] - overlap if chunk_index > 0 else start_index[chunk_index]) end = min(axis_length, end_index[chunk_index] + overlap if chunk_index < number_of_chunks - 1 else end_index[chunk_index]) start_in_chunk = start_index[chunk_index] - start end_in_chunk = start_in_chunk + (end_index[chunk_index] - start_index[chunk_index]) start_in_full = start_index[chunk_index] end_in_full = end_index[chunk_index] chunk_slice.append(slice(start, end)) interior_slice.append(slice(start_in_chunk, end_in_chunk)) dest_slice.append(slice(start_in_full, end_in_full)) chunk_info.append({ 'chunk_slice': tuple(chunk_slice), 'interior_slice': 
tuple(interior_slice), 'dest_slice': tuple(dest_slice) }) return chunk_info
[docs] @staticmethod def array_to_chunks(image_array: np.ndarray, splits_per_dimension: list[int] | tuple[int, ...], overlap_per_dimension: list[int] | tuple[int, ...]) -> (list[np.ndarray], list[dict]): chunk_info = ImageChunker.__chunk_array_with_overlap(image_array.shape, splits_per_dimension, overlap_per_dimension) image_chunks = [] positions = [] for info in chunk_info: image_chunk = image_array[info['chunk_slice']] positions.append({ 'interior_slice': info['interior_slice'], 'dest_slice': info['dest_slice'] }) image_chunks.append(image_chunk) return image_chunks, positions
[docs] @staticmethod def chunks_to_array(image_chunks: list[np.ndarray], image_chunk_positions: dict, final_shape: list[int] | tuple[int, ...]) -> np.ndarray: final_arr = np.empty(final_shape, dtype=image_chunks[0].dtype) for image_chunk, image_chunk_position in zip(image_chunks, image_chunk_positions): interior_region = image_chunk[image_chunk_position['interior_slice']] final_arr[image_chunk_position['dest_slice']] = interior_region return final_arr
[docs] @staticmethod def determine_splits(image_array: np.ndarray) -> tuple: image_shape = image_array.shape splits = [] for axis in image_shape: if axis == 1: splits.append(1) continue split = round((axis // CHUNK_THRESHOLD_INFERRING) + 0.5) if split == 0: split = 1 splits.append(split) return tuple(splits)
class ImageResampler:
    """Resampling utilities based on SimpleITK, with a Dask-parallelised variant for large images."""

    @staticmethod
    def chunk_along_axis(axis: int) -> int:
        """
        Determines the maximum number of evenly-sized chunks that the axis can be split into.
        Each chunk is at least of size CHUNK_THRESHOLD.

        :param axis: Length of the axis.
        :type axis: int
        :return: The maximum number of evenly-sized chunks.
        :rtype: int
        :raises ValueError: If axis is negative or if CHUNK_THRESHOLD is less than or equal to 0.
        """
        # Check for negative input values
        if axis < 0:
            raise ValueError('Axis must be non-negative')
        if CHUNK_THRESHOLD_RESAMPLING <= 0:
            raise ValueError('CHUNK_THRESHOLD must be greater than 0')

        # If the axis is smaller than the threshold, it cannot be split into smaller chunks
        if axis < CHUNK_THRESHOLD_RESAMPLING:
            return 1

        # Start from the largest possible chunk count and reduce it until the axis divides evenly
        split = axis // CHUNK_THRESHOLD_RESAMPLING
        while axis % split != 0:
            split -= 1
        return split

    @staticmethod
    def resample_chunk_SimpleITK(image_chunk: np.ndarray, input_spacing: tuple, interpolation_method: int,
                                 output_spacing: tuple, output_size: tuple) -> np.ndarray:
        """
        Resamples one block of an image (called per block by ``da.map_blocks``).

        :param image_chunk: The chunk (part of an image) to be resampled.
        :type image_chunk: np.ndarray
        :param input_spacing: The original spacing of the chunk (part of an image).
        :type input_spacing: tuple
        :param interpolation_method: SimpleITK interpolation type.
        :type interpolation_method: int
        :param output_spacing: Spacing of the newly resampled chunk.
        :type output_spacing: tuple
        :param output_size: Size of the newly resampled chunk.
        :type output_size: tuple
        :return: The resampled chunk (part of an image), with the chunk's pixel type.
        :rtype: np.ndarray
        """
        sitk_image_chunk = SimpleITK.GetImageFromArray(image_chunk)
        sitk_image_chunk.SetSpacing(input_spacing)

        resampled_sitk_image = SimpleITK.Resample(sitk_image_chunk, output_size, SimpleITK.Transform(),
                                                  interpolation_method, sitk_image_chunk.GetOrigin(), output_spacing,
                                                  sitk_image_chunk.GetDirection(), 0.0,
                                                  sitk_image_chunk.GetPixelIDValue())

        return SimpleITK.GetArrayFromImage(resampled_sitk_image)

    @staticmethod
    def resample_image_SimpleITK_DASK(sitk_image: SimpleITK.Image, interpolation: str,
                                      output_spacing: tuple = (1.5, 1.5, 1.5),
                                      output_size: tuple = None) -> SimpleITK.Image:
        """
        Resamples a sitk_image using Dask and SimpleITK.

        :param sitk_image: The SimpleITK image to be resampled.
        :type sitk_image: sitk.Image
        :param interpolation: nearest|linear|bspline.
        :type interpolation: str
        :param output_spacing: The desired output spacing of the resampled sitk_image. Ignored when
                               ``output_size`` is given.
        :type output_spacing: tuple
        :param output_size: The new size to use. When given, the output spacing is derived from it.
        :type output_size: tuple
        :return: The resampled sitk_image as SimpleITK.Image.
        :rtype: sitk.Image
        :raises ValueError: If the interpolation method is not supported.
        """
        resample_result = ImageResampler.resample_image_SimpleITK_DASK_array(sitk_image, interpolation,
                                                                              output_spacing, output_size)

        if output_size is not None:
            # The array was resampled with the spacing implied by output_size (same formula as in
            # resample_image_SimpleITK_DASK_array); the image must carry that spacing, not the
            # (possibly default) output_spacing argument.
            input_spacing = sitk_image.GetSpacing()
            input_size = sitk_image.GetSize()
            output_spacing = [input_spacing[i] * (input_size[i] / output_size[i]) for i in range(len(input_size))]

        resampled_image = SimpleITK.GetImageFromArray(resample_result)
        resampled_image.SetSpacing(output_spacing)
        resampled_image.SetOrigin(sitk_image.GetOrigin())
        resampled_image.SetDirection(sitk_image.GetDirection())

        return resampled_image

    @staticmethod
    def reslice_identity(reference_image: SimpleITK.Image, moving_image: SimpleITK.Image,
                         output_image_path: str = None, is_label_image: bool = False) -> SimpleITK.Image:
        """
        Reslices an image to the same space as another image.

        Note: the result is always cast to ``sitkInt32``, also for non-label images, so fractional
        intensities produced by linear interpolation are truncated.

        :param reference_image: The reference image.
        :type reference_image: SimpleITK.Image
        :param moving_image: The image to reslice to the reference image.
        :type moving_image: SimpleITK.Image
        :param output_image_path: Path to the resliced image. Default is None.
        :type output_image_path: str
        :param is_label_image: Determines if the image is a label image. Default is False.
        :type is_label_image: bool
        :return: The resliced image as SimpleITK.Image.
        :rtype: SimpleITK.Image
        """
        resampler = SimpleITK.ResampleImageFilter()
        resampler.SetReferenceImage(reference_image)

        if is_label_image:
            resampler.SetInterpolator(SimpleITK.sitkNearestNeighbor)
        else:
            resampler.SetInterpolator(SimpleITK.sitkLinear)

        resampled_image = resampler.Execute(moving_image)
        resampled_image = SimpleITK.Cast(resampled_image, SimpleITK.sitkInt32)
        if output_image_path is not None:
            SimpleITK.WriteImage(resampled_image, output_image_path)
        return resampled_image

    @staticmethod
    def resample_image_SimpleITK_DASK_array(sitk_image: SimpleITK.Image, interpolation: str,
                                            output_spacing: tuple = (1.5, 1.5, 1.5),
                                            output_size: tuple = None) -> np.array:
        """
        Resample ``sitk_image`` block-wise with Dask and return the result as a numpy array.

        Each block is resampled independently with ``resample_chunk_SimpleITK``. The total output
        size is the sum of the rounded per-block output sizes.

        :param sitk_image: The SimpleITK image to be resampled.
        :param interpolation: nearest|linear|bspline.
        :param output_spacing: The desired output spacing. Ignored when ``output_size`` is given.
        :param output_size: The new size to use. When given, the output spacing is derived from it.
        :return: The resampled image data in numpy (z, y, x) order.
        :raises ValueError: If the interpolation method is not supported.
        """
        if interpolation == 'nearest':
            interpolation_method = SimpleITK.sitkNearestNeighbor
        elif interpolation == 'linear':
            interpolation_method = SimpleITK.sitkLinear
        elif interpolation == 'bspline':
            interpolation_method = SimpleITK.sitkBSpline
        else:
            raise ValueError('The interpolation method is not supported.')

        input_spacing = sitk_image.GetSpacing()
        input_size = sitk_image.GetSize()
        # chunk_along_axis returns an exact divisor, so integer division gives exact integer chunk sizes.
        input_chunks = [axis // ImageResampler.chunk_along_axis(axis) for axis in input_size]
        # SimpleITK sizes are (x, y, z), numpy/Dask arrays are (z, y, x).
        input_chunks_reversed = list(reversed(input_chunks))

        image_dask = da.from_array(SimpleITK.GetArrayViewFromImage(sitk_image), chunks=input_chunks_reversed)

        if output_size is not None:
            output_spacing = [input_spacing[i] * (input_size[i] / output_size[i]) for i in range(len(input_size))]

        output_chunks = [round(input_chunks[i] * (input_spacing[i] / output_spacing[i])) for i in
                         range(len(input_chunks))]
        output_chunks_reversed = list(reversed(output_chunks))

        result = da.map_blocks(ImageResampler.resample_chunk_SimpleITK, image_dask, input_spacing,
                               interpolation_method, output_spacing, output_chunks,
                               chunks=output_chunks_reversed, meta=np.array(()), dtype=np.float32)

        return result.compute()

    @staticmethod
    def resample_segmentation(reference_image: SimpleITK.Image, segmentation_image: SimpleITK.Image):
        """
        Resample a segmentation onto the grid (size, origin, spacing, direction) of ``reference_image``.

        Uses nearest-neighbour interpolation with an identity transform, default value 0, and keeps the
        segmentation's pixel type.

        :param reference_image: Image that defines the output grid.
        :param segmentation_image: Label image to resample.
        :return: The resampled segmentation as SimpleITK.Image.
        """
        resampled_sitk_image = SimpleITK.Resample(segmentation_image, reference_image.GetSize(),
                                                  SimpleITK.Transform(), SimpleITK.sitkNearestNeighbor,
                                                  reference_image.GetOrigin(), reference_image.GetSpacing(),
                                                  reference_image.GetDirection(), 0.0,
                                                  segmentation_image.GetPixelIDValue())
        return resampled_sitk_image
def determine_orientation_code(image: nibabel.Nifti1Image) -> tuple[tuple[str, ...], str]:
    """
    Return the axis orientation codes of a NIfTI image, derived from its affine.

    :param image: The NIfTI image whose affine is inspected.
    :return: ``(codes, code_string)``. ``codes`` is the per-axis tuple returned by
             ``nibabel.orientations.aff2axcodes`` (e.g. ``('R', 'A', 'S')``) and ``code_string`` is the same
             codes joined into one string (e.g. ``'RAS'``).
    """
    orthonormal_orientation = nibabel.orientations.aff2axcodes(image.affine)
    return orthonormal_orientation, ''.join(orthonormal_orientation)
def confirm_orthonormality(image: nibabel.Nifti1Image) -> tuple[nibabel.Nifti1Image, bool]:
    """
    Ensure the direction part of the image affine is orthonormal, repairing it if necessary.

    The 3x3 part of the affine is split into per-axis spacing (column norms) and a direction matrix
    (columns scaled to unit length). If that direction matrix is not orthonormal, it is replaced by the Q
    factor of its QR decomposition, sign-corrected so every axis keeps its original direction. It is then
    re-scaled by the spacing, and a new image is built with the repaired affine set as both qform and sform.
    The translation is kept.

    :param image: The input NIfTI image.
    :return: ``(image, orthonormalized)``: the original image and False, or a new image and True.
    """
    affine = image.affine
    rotation_matrix = affine[:3, :3]
    spacing = np.linalg.norm(rotation_matrix, axis=0)
    ortho_rotation_matrix = rotation_matrix / spacing

    is_orthonormal = np.allclose(ortho_rotation_matrix.T @ ortho_rotation_matrix, np.eye(3))
    if is_orthonormal:
        return image, False

    q, r = np.linalg.qr(ortho_rotation_matrix)
    # np.linalg.qr does not guarantee a positive diagonal in R. A negative entry means the matching column
    # of Q points opposite to the original axis, which would mirror the image. Flip those columns back
    # (Q R == (Q S)(S R) for S = diag(+-1)).
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1
    ortho_rotation_matrix = (q * signs) * spacing

    orthonormal_affine = np.eye(4)
    orthonormal_affine[:3, :3] = ortho_rotation_matrix
    orthonormal_affine[:3, 3] = affine[:3, 3]

    orthonormal_header = image.header.copy()
    orthonormal_header.set_qform(orthonormal_affine)
    orthonormal_header.set_sform(orthonormal_affine)

    # The voxel data is only loaded when a new image actually has to be built.
    data = image.get_fdata()
    return nibabel.Nifti1Image(data, orthonormal_affine, orthonormal_header), True
def confirm_orientation(image: nibabel.Nifti1Image) -> tuple[nibabel.Nifti1Image, bool]:
    """
    Reorient an image whose first voxel axis has orientation code 'R' so that it becomes 'L'.

    The voxel data is flipped along that axis and the affine is updated so that every voxel keeps its
    world position. The other two axes are left unchanged.

    :param image: The input NIfTI image.
    :return: ``(image, reoriented)``: the original image and False, or a new image and True.
    """
    affine = image.affine
    original_orientation = nibabel.orientations.aff2axcodes(affine)
    if original_orientation[0] != 'R':
        return image, False

    data = image.get_fdata()
    current_orientation = nibabel.orientations.axcodes2ornt(original_orientation)
    target_orientation = nibabel.orientations.axcodes2ornt(('L', original_orientation[1], original_orientation[2]))
    orientation_transform = nibabel.orientations.ornt_transform(current_orientation, target_orientation)

    reoriented_data = nibabel.orientations.apply_orientation(data, orientation_transform)
    # inv_ornt_aff maps voxel indices of the reoriented array to voxel indices of the original array, so it is
    # applied before (i.e. to the right of) the original voxel-to-world affine, as in nibabel's as_reoriented.
    reoriented_affine = affine.dot(nibabel.orientations.inv_ornt_aff(orientation_transform, data.shape))

    reoriented_header = image.header.copy()
    reoriented_header.set_qform(reoriented_affine)
    reoriented_header.set_sform(reoriented_affine)

    return nibabel.Nifti1Image(reoriented_data, reoriented_affine, reoriented_header), True
def convert_to_sitk(image: nibabel.Nifti1Image) -> SimpleITK.Image:
    """
    Convert a nibabel NIfTI image into a SimpleITK image.

    The voxel array's first and third axes are swapped before handing it to SimpleITK. The spacing comes
    from the header zooms. The origin and direction come from the affine with its x and y components
    negated (RAS+ -> LPS+), and the direction columns are divided by the zooms.

    :param image: The NIfTI image to convert.
    :return: The equivalent SimpleITK image.
    """
    voxel_data = image.get_fdata()
    affine = image.affine
    zooms = image.header.get_zooms()

    # Reverse the axis order for SimpleITK.GetImageFromArray (which expects z, y, x).
    sitk_image = SimpleITK.GetImageFromArray(voxel_data.swapaxes(0, 2))

    # Negate the x and y world axes (nibabel RAS+ -> ITK LPS+).
    ras_to_lps = np.diag([-1, -1, 1])
    origin = np.dot(ras_to_lps, affine[:3, 3])
    direction = np.dot(ras_to_lps, affine[:3, :3]) / np.absolute(zooms)

    sitk_image.SetSpacing([zoom.item() for zoom in zooms])
    sitk_image.SetOrigin(origin)
    sitk_image.SetDirection(direction.flatten())
    return sitk_image
def standardize_image(image_path: str, output_manager: system.OutputManager, standardization_output_path: str | None) -> SimpleITK.Image:
    """
    Load a NIfTI image, orthonormalize and reorient it when needed, and convert it to SimpleITK.

    Each step is logged through ``output_manager``. If ``standardization_output_path`` is given and at least
    one correction was applied, the standardized image is written there. Its file name is the original base
    name prefixed with the applied steps (``orthonormal_`` and/or ``reoriented_``).

    :param image_path: Path to the NIfTI image.
    :param output_manager: Receives progress messages via ``log_update``.
    :param standardization_output_path: Directory for the standardized image, or None to skip writing.
    :return: The standardized image as SimpleITK.Image.
    """
    nifti_image = nibabel.load(image_path)
    output_manager.log_update(f" Image loaded. Orientation: {determine_orientation_code(nifti_image)[1]}")

    nifti_image, orthonormalized = confirm_orthonormality(nifti_image)
    if orthonormalized:
        output_manager.log_update(f" Image orthonormalized. Orientation: {determine_orientation_code(nifti_image)[1]}")

    nifti_image, reoriented = confirm_orientation(nifti_image)
    if reoriented:
        output_manager.log_update(f" Image reoriented. Orientation: {determine_orientation_code(nifti_image)[1]}")

    sitk_image = convert_to_sitk(nifti_image)
    output_manager.log_update(" Image converted to SimpleITK.")

    applied_steps = [name for name, applied in (("orthonormal", orthonormalized), ("reoriented", reoriented)) if applied]
    if standardization_output_path is not None and applied_steps:
        output_manager.log_update(" Writing standardized image.")
        file_name = f"{'_'.join(applied_steps)}_{os.path.basename(image_path)}"
        SimpleITK.WriteImage(sitk_image, os.path.join(standardization_output_path, file_name))

    return sitk_image