Source code for atommic.collections.multitask.rs.data.mrirs_loader

# coding=utf-8
__author__ = "Dimitris Karkalousos"

import os
import warnings
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import h5py
import nibabel as nib
import numpy as np

from atommic.collections.common.data.mri_loader import MRIDataset
from atommic.collections.common.parts.utils import is_none


class RSMRIDataset(MRIDataset):
    """A dataset class for accelerated-MRI reconstruction and MRI segmentation.

    Every item is the 9-tuple ``(kspace, imspace, sensitivity_map, mask, initial_prediction,
    segmentation_labels, attrs, filename, slice_num)``, or the result of calling ``transform`` with those nine
    values when a transform is set.

    Examples
    --------
    >>> from atommic.collections.multitask.rs.data.mrirs_loader import RSMRIDataset
    >>> dataset = RSMRIDataset(root='data/train', sample_rate=0.1)
    >>> print(len(dataset))
    100
    >>> (kspace, imspace, coil_sensitivities, mask, initial_prediction, segmentation_labels, attrs, filename,
    ...  slice_num) = dataset[0]
    >>> print(kspace.shape)
    (30, 640, 368)

    .. note::
        Extends :class:`atommic.collections.common.data.MRIDataset`.
    """

    # Keys probed, in order, when loading magnitude-only (non-complex) images.
    _MAGNITUDE_IMAGE_KEYS = ("reconstruction_rss", "reconstruction_sense", "reconstruction", "target")
    # Keys probed, in order, when loading an initial prediction.
    _INITIAL_PREDICTION_KEYS = ("reconstruction", "initial_prediction")

    def __init__(
        self,
        root: Union[str, Path, os.PathLike],
        coil_sensitivity_maps_root: Union[str, Path, os.PathLike] = None,
        mask_root: Union[str, Path, os.PathLike] = None,
        noise_root: Union[str, Path, os.PathLike] = None,
        initial_predictions_root: Union[str, Path, os.PathLike] = None,
        dataset_format: str = None,
        sample_rate: Optional[float] = None,
        volume_sample_rate: Optional[float] = None,
        use_dataset_cache: bool = False,
        dataset_cache_file: Union[str, Path, os.PathLike] = None,
        num_cols: Optional[Tuple[int]] = None,
        consecutive_slices: int = 1,
        data_saved_per_slice: bool = False,
        n2r_supervised_rate: Optional[float] = 0.0,
        complex_target: bool = False,
        log_images_rate: Optional[float] = 1.0,
        transform: Optional[Callable] = None,
        segmentations_root: Union[str, Path, os.PathLike] = None,
        segmentation_classes: int = 2,
        segmentation_classes_to_remove: Optional[Tuple[int]] = None,
        segmentation_classes_to_combine: Optional[Tuple[int]] = None,
        segmentation_classes_to_separate: Optional[Tuple[int]] = None,
        segmentation_classes_thresholds: Optional[Tuple[float]] = None,
        complex_data: bool = True,
        **kwargs,
    ):
        """Inits :class:`RSMRIDataset`.

        Parameters
        ----------
        root : Union[str, Path, os.PathLike]
            Path to the dataset.
        coil_sensitivity_maps_root : Union[str, Path, os.PathLike], optional
            Path to the coil sensitivities maps dataset, if stored separately.
        mask_root : Union[str, Path, os.PathLike], optional
            Path to stored masks, if stored separately.
        noise_root : Union[str, Path, os.PathLike], optional
            Path to stored noise, if stored separately (in json format).
        initial_predictions_root : Union[str, Path, os.PathLike], optional
            Path to the dataset containing the initial predictions. If provided, the initial predictions will be
            used as the input of the reconstruction network. Default is ``None``.
        dataset_format : str, optional
            The format of the dataset. For example, ``'custom_dataset'`` or ``'public_dataset_name'``.
            Default is ``None``.
        sample_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the slices should be loaded. When creating
            subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by
            volumes) but not both.
        volume_sample_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the volumes should be loaded. When creating
            subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by
            volumes) but not both.
        use_dataset_cache : bool, optional
            Whether to cache dataset metadata. This is very useful for large datasets.
        dataset_cache_file : Union[str, Path, os.PathLike], optional
            A file in which to cache dataset information for faster load times.
        num_cols : Optional[Tuple[int]], optional
            If provided, only slices with the desired number of columns will be considered.
        consecutive_slices : int, optional
            An int (>0) that determine the amount of consecutive slices of the file to be loaded at the same time.
            Default is ``1``, loading single slices.
        data_saved_per_slice : bool, optional
            Whether the data is saved per slice or per volume.
        n2r_supervised_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the subjects should be loaded for Noise to
            Reconstruction (N2R) supervised loss, if N2R is enabled. Default is ``0.0``.
        complex_target : bool, optional
            Whether to use a complex target or not. Default is ``False``.
        log_images_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the subjects should be logged as images.
            Default is ``1.0``.
        transform : Optional[Callable], optional
            A callable that preprocesses the raw data into appropriate form. It is called with ``kspace``,
            ``imspace``, ``coil sensitivity maps``, ``mask``, ``initial prediction``, ``segmentation labels``,
            ``attributes``, ``filename``, and ``slice number``, in that order. Default is ``None``.
        segmentations_root : Union[str, Path, os.PathLike], optional
            Path to the dataset containing the segmentations.
        segmentation_classes : int, optional
            The number of segmentation classes. Default is ``2``.
        segmentation_classes_to_remove : Optional[Tuple[int]], optional
            A tuple of segmentation classes to remove. For example, if the dataset contains segmentation classes
            0, 1, 2, 3, and 4, and you want to remove classes 1 and 3, set this to ``(1, 3)``.
            Default is ``None``.
        segmentation_classes_to_combine : Optional[Tuple[int]], optional
            A tuple of segmentation classes to combine. For example, if the dataset contains segmentation classes
            0, 1, 2, 3, and 4, and you want to combine classes 1 and 3, set this to ``(1, 3)``.
            Default is ``None``.
        segmentation_classes_to_separate : Optional[Tuple[int]], optional
            A tuple of segmentation classes to separate. For example, if the dataset contains segmentation classes
            0, 1, 2, 3, and 4, and you want to separate class 1 into 2 classes, set this to ``(1, 2)``.
            Default is ``None``.
        segmentation_classes_thresholds : Optional[Tuple[float]], optional
            A tuple of thresholds for the segmentation classes, one entry per class channel; ``None`` entries leave
            the corresponding channel untouched. For example, if the dataset contains segmentation classes
            0, 1, 2, 3, and 4, and you want to set the threshold for every class to 0.5, set this to
            ``(0.5, 0.5, 0.5, 0.5, 0.5)``. Default is ``None``.
        complex_data : bool, optional
            Whether the data is complex. If ``False``, the data is assumed to be magnitude only.
            Default is ``True``.
        **kwargs : dict
            Additional keyword arguments.
        """
        super().__init__(
            root,
            coil_sensitivity_maps_root,
            mask_root,
            noise_root,
            initial_predictions_root,
            dataset_format,
            sample_rate,
            volume_sample_rate,
            use_dataset_cache,
            dataset_cache_file,
            num_cols,
            consecutive_slices,
            data_saved_per_slice,
            n2r_supervised_rate,
            complex_target,
            log_images_rate,
            transform,
            **kwargs,
        )
        self.segmentations_root = segmentations_root
        # Stored here as well because process_segmentation_labels reads it to decide where the class axis goes.
        self.consecutive_slices = consecutive_slices
        self.segmentation_classes = segmentation_classes
        self.segmentation_classes_to_remove = segmentation_classes_to_remove
        self.segmentation_classes_to_combine = segmentation_classes_to_combine
        self.segmentation_classes_to_separate = segmentation_classes_to_separate
        self.segmentation_classes_thresholds = segmentation_classes_thresholds
        self.complex_data = complex_data

    def process_segmentation_labels(self, segmentation_labels: np.ndarray) -> np.ndarray:  # noqa: MC0001
        """Processes segmentation labels to remove, combine, and separate classes.

        The class axis is located (the last axis whose size equals ``self.segmentation_classes``, falling back to
        the last axis), moved to the end for processing, and finally moved to axis 0 (single slice) or axis 1
        (consecutive slices).

        .. note::
            Separation and thresholding assign in place. When no class is removed or combined, the array being
            written to is a view of ``segmentation_labels``, so the caller's array is modified.

        Parameters
        ----------
        segmentation_labels : np.ndarray
            The segmentation labels. The shape should be (num_slices, height, width) or (height, width).

        Returns
        -------
        np.ndarray
            The processed segmentation labels.
        """
        # find the dimension with the segmentation classes; the last matching axis wins
        segmentation_labels_dim = segmentation_labels.ndim - 1
        for dim, size in enumerate(segmentation_labels.shape):
            if size == self.segmentation_classes:
                segmentation_labels_dim = dim

        # move it to the last dimension
        segmentation_labels = np.moveaxis(segmentation_labels, segmentation_labels_dim, -1)

        # if we have a single slice, add a new dimension
        if segmentation_labels.ndim == 2:
            segmentation_labels = np.expand_dims(segmentation_labels, axis=0)

        # check if we need to remove any classes, e.g. background
        if self.segmentation_classes_to_remove is not None:
            segmentation_labels = np.delete(segmentation_labels, self.segmentation_classes_to_remove, axis=-1)

        # check if we need to combine any classes, e.g. White Matter and Gray Matter
        if self.segmentation_classes_to_combine is not None:
            segmentation_labels_to_combine = np.sum(
                segmentation_labels[..., self.segmentation_classes_to_combine], axis=-1, keepdims=True
            )
            segmentation_labels_to_keep = np.delete(segmentation_labels, self.segmentation_classes_to_combine, axis=-1)

            if self.segmentation_classes_to_remove is not None and 0 in self.segmentation_classes_to_remove:
                # if background is removed, we can stack the combined labels with the rest straight away
                segmentation_labels = np.concatenate(
                    [segmentation_labels_to_combine, segmentation_labels_to_keep], axis=-1
                )
            else:
                # if background is not removed, we need to add it back as new background channel
                segmentation_labels = np.concatenate(
                    [segmentation_labels[..., 0:1], segmentation_labels_to_combine, segmentation_labels_to_keep],
                    axis=-1,
                )

        # check if we need to separate any classes, e.g. pathologies from White Matter and Gray Matter
        if self.segmentation_classes_to_separate is not None:
            for x in self.segmentation_classes_to_separate:
                segmentation_class_to_separate = segmentation_labels[..., x]
                for i in range(segmentation_labels.shape[-1]):
                    if i == x:
                        continue
                    # zero every other class wherever the separated class is present
                    segmentation_labels[..., i][segmentation_class_to_separate > 0] = 0

        # threshold probability maps if any threshold is given
        if self.segmentation_classes_thresholds is not None:
            for i, voxel_thres in enumerate(self.segmentation_classes_thresholds):
                if voxel_thres is not None:
                    segmentation_labels[..., i][segmentation_labels[..., i] < voxel_thres] = 0
                    segmentation_labels[..., i][segmentation_labels[..., i] >= voxel_thres] = 1

        if self.consecutive_slices == 1:
            # bring the segmentation classes dimension back to the first dimension
            segmentation_labels = np.moveaxis(segmentation_labels, -1, 0)
        elif self.consecutive_slices > 1:
            # bring the segmentation classes dimension back to the second dimension
            segmentation_labels = np.moveaxis(segmentation_labels, -1, 1)

        return segmentation_labels

    def _read_coil_sensitivity_maps(self, hf: "h5py.File", fname: Path, dataslice: int) -> np.ndarray:
        """Return the coil sensitivity maps for ``dataslice``, or an empty array when none can be found.

        The data file is probed first (``sensitivity_map``, then ``maps``); otherwise the maps are looked up
        under ``self.coil_sensitivity_maps_root`` when that is set.
        """
        if "sensitivity_map" in hf:
            return self.get_consecutive_slices(hf, "sensitivity_map", dataslice).astype(np.complex64)
        if "maps" in hf:
            return self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
        if self.coil_sensitivity_maps_root is None or self.coil_sensitivity_maps_root == "None":
            return np.array([])

        coil_sensitivity_maps_root = self.coil_sensitivity_maps_root
        split_dir = str(fname).split("/")
        # check if the maps live directly under <root>/<parent dir of fname>/<file name>
        if not os.path.exists(Path(f"{coil_sensitivity_maps_root}/{split_dir[-2]}/{fname.name}")):
            # find to what depth the coil_sensitivity_maps_root directory is nested
            for j in range(len(split_dir)):
                # get the coil_sensitivity_maps_root directory name
                coil_sensitivity_maps_root = Path(f"{self.coil_sensitivity_maps_root}/{split_dir[-j]}/")
                if os.path.exists(coil_sensitivity_maps_root / Path(split_dir[-2]) / fname.name):
                    break

        # load coil sensitivity maps
        with h5py.File(Path(coil_sensitivity_maps_root) / Path(split_dir[-2]) / fname.name, "r") as sf:
            if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())):
                return self.get_consecutive_slices(sf, "sensitivity_map", dataslice).squeeze().astype(np.complex64)
        return np.array([])

    def _read_sampling_mask(self, hf: "h5py.File", fname: Path, dataslice: int) -> Optional[np.ndarray]:
        """Return the mask stored in the data file or under ``self.mask_root``; ``None`` when neither exists."""
        if "mask" in hf:
            mask = np.asarray(self.get_consecutive_slices(hf, "mask", dataslice))
            if mask.ndim == 3:
                mask = mask[dataslice]
            return mask
        if self.mask_root is not None and self.mask_root != "None":
            with h5py.File(Path(self.mask_root) / fname.name, "r") as mf:
                return np.asarray(self.get_consecutive_slices(mf, "mask", dataslice))
        return None

    def _read_magnitude_image(self, hf: "h5py.File", dataslice: int) -> np.ndarray:
        """Return the image stored under the first available magnitude key.

        Raises
        ------
        ValueError
            If the file contains none of the supported image keys.
        """
        for key in self._MAGNITUDE_IMAGE_KEYS:
            if key in hf:
                return self.get_consecutive_slices(hf, key, dataslice)
        supported_keys = ", ".join(f"'{key}'" for key in self._MAGNITUDE_IMAGE_KEYS)
        raise ValueError(
            "Complex data has not been selected but no reconstruction data found in file. "
            f"Expected one of the keys: {supported_keys}."
        )

    def _read_segmentation_labels(self, hf: "h5py.File", fname: Path, dataslice: int) -> np.ndarray:
        """Return processed segmentation labels, or a 0-d placeholder array when no labels are available.

        Labels under ``self.segmentations_root`` take precedence over a ``segmentation`` key in the data file.
        """
        if self.segmentations_root is not None and self.segmentations_root != "None":
            with h5py.File(Path(self.segmentations_root) / fname.name, "r") as sf:
                raw_labels = np.asarray(self.get_consecutive_slices(sf, "segmentation", dataslice))
                return self.process_segmentation_labels(raw_labels)
        if "segmentation" in hf:
            raw_labels = np.asarray(self.get_consecutive_slices(hf, "segmentation", dataslice))
            return self.process_segmentation_labels(raw_labels)
        return np.empty([])

    def _initial_prediction_from(self, source: "h5py.File", dataslice: int) -> np.ndarray:
        """Read the initial prediction from ``source``; return a 0-d placeholder array if it has no such key."""
        for key in self._INITIAL_PREDICTION_KEYS:
            # The membership test must target the same file the data is read from.
            if key in source:
                return self.get_consecutive_slices(source, key, dataslice).squeeze().astype(np.complex64)
        return np.empty([])

    def _read_initial_prediction(self, hf: "h5py.File", fname: Path, dataslice: int) -> np.ndarray:
        """Return the initial prediction from ``self.initial_predictions_root`` if set, else from the data file."""
        if is_none(self.initial_predictions_root):
            return self._initial_prediction_from(hf, dataslice)
        with h5py.File(Path(self.initial_predictions_root) / fname.name, "r") as ipf:  # type: ignore
            return self._initial_prediction_from(ipf, dataslice)

    def __getitem__(self, i: int):
        """Get item from :class:`RSMRIDataset`.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple or Any
            ``(kspace, imspace, sensitivity_map, mask, initial_prediction, segmentation_labels, attrs, filename,
            slice_num)`` when no transform is set, otherwise whatever ``self.transform`` returns for those values.
            Inputs that are not available are filled with empty placeholder arrays (``mask`` is ``None`` when
            complex data is loaded without any mask).

        Raises
        ------
        ValueError
            If magnitude-only data is requested but the file has no image key, or if the sensitivity map has an
            unsupported number of dimensions.
        """
        fname, dataslice, metadata = self.examples[i]

        with h5py.File(fname, "r") as hf:
            if self.complex_data:
                kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)
                sensitivity_map = self._read_coil_sensitivity_maps(hf, fname, dataslice)
                mask = self._read_sampling_mask(hf, fname, dataslice)
                imspace = np.empty([])
            else:
                imspace = self._read_magnitude_image(hf, dataslice)
                kspace = np.empty([])
                sensitivity_map = np.array([])
                mask = np.empty([])

            segmentation_labels = self._read_segmentation_labels(hf, fname, dataslice)
            initial_prediction = self._read_initial_prediction(hf, fname, dataslice)

            attrs = dict(hf.attrs)

            # get noise level for current slice, if metadata["noise_levels"] is not empty
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0

            attrs.update(metadata)

        # Maps whose shape differs from k-space are transposed so that their last axis moves in front of the
        # two axes preceding it.
        if sensitivity_map.shape != kspace.shape and sensitivity_map.ndim > 1:
            if sensitivity_map.ndim == 3:
                sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1))
            elif sensitivity_map.ndim == 4:
                sensitivity_map = np.transpose(sensitivity_map, (0, 3, 1, 2))
            else:
                raise ValueError(
                    f"Sensitivity map has invalid dimensions {sensitivity_map.shape} compared to kspace {kspace.shape}"
                )

        attrs["log_image"] = bool(dataslice in self.indices_to_log)

        sample = (
            kspace,
            imspace,
            sensitivity_map,
            mask,
            initial_prediction,
            segmentation_labels,
            attrs,
            fname.name,
            dataslice,
        )
        return sample if self.transform is None else self.transform(*sample)
class SKMTEARSMRIDataset(RSMRIDataset):
    """Supports the SKM-TEA dataset for multitask accelerated MRI reconstruction and MRI segmentation.

    .. note::
        Extends :class:`atommic.collections.multitask.rs.data.mrirs_loader.RSMRIDataset`.
    """

    # Hard-coded field-of-view crop: 48 samples are dropped from each end of the first cropped axis and 40 from
    # each end of the second. The same crop is applied to k-space, the sensitivity maps and, in the frequency
    # domain, the segmentation labels.
    _FOV_CROP = (slice(48, -48), slice(40, -40))

    @staticmethod
    def _binary_class_mask(labels: np.ndarray, value: int) -> np.ndarray:
        """Return an array shaped and typed like ``labels`` that is 1 where ``labels == value`` and 0 elsewhere."""
        selected = np.zeros_like(labels)
        selected[labels == value] = 1
        return selected

    def __getitem__(self, i: int):  # noqa: MC0001
        """Get item from :class:`SKMTEARSMRIDataset`.

        ``self.dataset_format`` selects the echo(es) to load (``skm-tea-echo1``, ``skm-tea-echo2``,
        ``skm-tea-echo1+echo2`` or ``skm-tea-echo1+echo2-mc``); any other value falls back to the first echo
        with a warning. Adding ``custom_masking`` to the format, or leaving the format unset, returns an empty
        mask; otherwise the masks stored in the file are returned as a dictionary.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple or Any
            ``(kspace, imspace, sensitivity_map, mask, initial_prediction, segmentation_labels, attrs, filename,
            slice_num)`` when no transform is set, otherwise whatever ``self.transform`` returns for those values.
            ``imspace`` and ``initial_prediction`` are always empty placeholder arrays.
        """
        if not is_none(self.dataset_format):
            dataset_format = self.dataset_format.lower()  # type: ignore
            masking = "default"
            if "custom_masking" in dataset_format:
                masking = "custom"
                dataset_format = dataset_format.replace("custom_masking", "").strip("_")
        else:
            dataset_format = None
            masking = "custom"

        fname, dataslice, metadata = self.examples[i]

        with h5py.File(fname, "r") as hf:
            kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)

            # NOTE(review): the k-space and sensitivity-map indexing below looks written for a single slice;
            # consecutive_slices > 1 presumably still fails at the final transposes -- confirm.
            # The third axis of the loaded k-space is indexed as the echo axis.
            if dataset_format == "skm-tea-echo1":
                kspace = kspace[:, :, 0, :]
            elif dataset_format == "skm-tea-echo2":
                kspace = kspace[:, :, 1, :]
            elif dataset_format == "skm-tea-echo1+echo2":
                kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
            elif dataset_format == "skm-tea-echo1+echo2-mc":
                # keep both echoes by concatenating them along the last axis
                kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
            else:
                warnings.warn(
                    f"Dataset format {dataset_format} is either not supported or set to None. "
                    "Using by default only the first echo."
                )
                kspace = kspace[:, :, 0, :]

            kspace = kspace[self._FOV_CROP]

            sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
            sensitivity_map = sensitivity_map[..., 0]
            sensitivity_map = sensitivity_map[self._FOV_CROP]

            if masking == "custom":
                mask = np.array([])
            else:
                # key each stored mask by the last "_"-separated token of its name, without the file extension
                mask = {key.split("_")[-1].split(".")[0]: np.asarray(val) for key, val in hf["masks"].items()}

            # the segmentation is stored as <file stem>.nii.gz under segmentations_root
            label_volume = nib.load(
                Path(self.segmentations_root) / Path(str(fname.name.split(".")[0]) + ".nii.gz")  # type: ignore
            ).get_fdata()

            # get a slice
            label_slice = self.get_consecutive_slices({"seg": label_volume}, "seg", dataslice)

            # Label values: 1 Patellar Cartilage, 2 Femoral Cartilage, 3 Lateral Tibial Cartilage,
            # 4 Medial Tibial Cartilage, 5 Lateral Meniscus, 6 Medial Meniscus.
            patellar_cartilage = self._binary_class_mask(label_slice, 1)
            femoral_cartilage = self._binary_class_mask(label_slice, 2)
            lateral_tibial_cartilage = self._binary_class_mask(label_slice, 3)
            medial_tibial_cartilage = self._binary_class_mask(label_slice, 4)
            lateral_meniscus = self._binary_class_mask(label_slice, 5)
            medial_meniscus = self._binary_class_mask(label_slice, 6)

            # combine Lateral Tibial Cartilage and Medial Tibial Cartilage
            tibial_cartilage = lateral_tibial_cartilage + medial_tibial_cartilage
            # combine Lateral Meniscus and Medial Meniscus
            meniscus = lateral_meniscus + medial_meniscus

            # stack the four classes on axis 0 for a single slice, or on axis 1 for consecutive slices
            segmentation_labels_dim = 1 if self.consecutive_slices > 1 else 0
            segmentation_labels = np.stack(
                [patellar_cartilage, femoral_cartilage, tibial_cartilage, meniscus],
                axis=segmentation_labels_dim,
            )

            # TODO: This is hardcoded on the SKM-TEA side, how to generalize this?
            # We need to crop the segmentation labels in the frequency domain to reduce the FOV.
            segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels))
            # Crop the last two axes (the ones fft2 transformed). Ellipsis leaves any leading slice and class
            # axes untouched, so the class axis is no longer cropped when labels are stacked on axis 1.
            segmentation_labels = segmentation_labels[(Ellipsis, *self._FOV_CROP)]
            segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels)).real
            segmentation_labels = np.where(segmentation_labels > 0.5, 1.0, 0.0)  # Make sure the labels are binary.

            imspace = np.empty([])
            initial_prediction = np.empty([])

            attrs = dict(hf.attrs)

            # get noise level for current slice, if metadata["noise_levels"] is not empty
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0

            attrs.update(metadata)

        # move the last axis to the front
        kspace = np.transpose(kspace, (2, 0, 1))
        sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))

        attrs["log_image"] = bool(dataslice in self.indices_to_log)

        sample = (
            kspace,
            imspace,
            sensitivity_map,
            mask,
            initial_prediction,
            segmentation_labels,
            attrs,
            fname.name,
            dataslice,
        )
        return sample if self.transform is None else self.transform(*sample)