# Source code for atommic.collections.common.data.mri_loader

# coding=utf-8
__author__ = "Dimitris Karkalousos"

import json
import logging
import os
import random
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from xml.etree.ElementTree import Element  # nosec B405 - type annotation only; parsing uses defusedxml

import h5py
import numpy as np
import yaml  # type: ignore
from defusedxml.ElementTree import fromstring
from torch.utils.data import Dataset

from atommic.collections.common.parts import utils


def et_query(root: Element, qlist: Sequence[str], namespace: str = "http://www.ismrm.org/ISMRMRD") -> str:
    """Return the text of the first element matching a nested tag path in an XML tree.

    The tags in ``qlist`` are joined into a namespaced query of the form ``.//ns:tag1//ns:tag2//...`` and
    evaluated with :meth:`xml.etree.ElementTree.Element.find`, so each tag may appear at any depth below the
    previous one.

    Parameters
    ----------
    root : xml.etree.ElementTree.Element
        The root element of the parsed XML tree.
    qlist : Sequence[str]
        Tag names to descend through, outermost first.
    namespace : str, optional
        The XML namespace the tags belong to. Default is the ISMRMRD namespace.

    Returns
    -------
    str
        The text of the matched element. ``"0"`` is returned both when no element matches and when the matched
        element has no text (e.g. ``<z/>``), so callers converting the result with ``int(...)`` get ``0`` instead
        of an error.
    """
    prefix = "ismrmrd_namespace"
    query = "." + "".join(f"//{prefix}:{tag}" for tag in qlist)
    element = root.find(query, {prefix: namespace})
    # An empty element has ``text is None``; treat it like a missing element rather than returning "None".
    if element is None or element.text is None:
        return "0"
    return element.text


class MRIDataset(Dataset):
    """A generic class for loading an MRI dataset for any task.

    Each entry of ``self.examples`` is a ``(fname, slice_index, metadata)`` tuple; subclasses implement
    :meth:`__getitem__`.

    .. note::
        Extends :class:`torch.utils.data.Dataset`.
    """

    def __init__(  # noqa: MC0001
        self,
        root: Union[str, Path, os.PathLike],
        coil_sensitivity_maps_root: Union[str, Path, os.PathLike] = None,
        mask_root: Union[str, Path, os.PathLike] = None,
        noise_root: Union[str, Path, os.PathLike] = None,
        initial_predictions_root: Union[str, Path, os.PathLike] = None,
        dataset_format: str = None,
        sample_rate: Optional[float] = None,
        volume_sample_rate: Optional[float] = None,
        use_dataset_cache: bool = False,
        dataset_cache_file: Union[str, Path, os.PathLike] = None,
        num_cols: Optional[Tuple[int]] = None,
        consecutive_slices: int = 1,
        data_saved_per_slice: bool = False,
        n2r_supervised_rate: Optional[float] = 0.0,
        complex_target: bool = False,
        log_images_rate: Optional[float] = 1.0,
        transform: Optional[Callable] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Inits :class:`MRIDataset`.

        Parameters
        ----------
        root : Union[str, Path, os.PathLike]
            Path to the dataset: either a directory whose entries are the data files, or a ``.json`` file holding
            a list of file paths.
        coil_sensitivity_maps_root : Union[str, Path, os.PathLike], optional
            Path to the coil sensitivities maps dataset, if stored separately.
        mask_root : Union[str, Path, os.PathLike], optional
            Path to stored masks, if stored separately.
        noise_root : Union[str, Path, os.PathLike], optional
            Path to stored noise levels in JSON-lines format (one JSON object per line, each with ``"fname"`` and
            ``"noise"`` keys). Ignored unless the path ends in ``.json``.
        initial_predictions_root : Union[str, Path, os.PathLike], optional
            Path to the dataset containing the initial predictions. If provided, the initial predictions will be
            used as the input of the reconstruction network. Default is ``None``.
        dataset_format : str, optional
            The format of the dataset. For example, ``'custom_dataset'`` or ``'public_dataset_name'``.
        sample_rate : Optional[float], optional
            A float between 0 and 1 controlling what fraction of the slices is loaded. Set either
            ``sample_rate`` (sample by slices) or ``volume_sample_rate`` (sample by volumes), not both.
        volume_sample_rate : Optional[float], optional
            A float between 0 and 1 controlling what fraction of the volumes is loaded. Set either
            ``sample_rate`` (sample by slices) or ``volume_sample_rate`` (sample by volumes), not both.
        use_dataset_cache : bool, optional
            Whether to read/write dataset metadata from/to ``dataset_cache_file``. This is very useful for large
            datasets.
        dataset_cache_file : Union[str, Path, os.PathLike, None], optional
            YAML file in which dataset metadata is cached for faster load times. If not provided, no cache is
            read or written, even when ``use_dataset_cache`` is ``True``.
        num_cols : Optional[Tuple[int]], optional
            If provided, only slices whose ``encoding_size[1]`` is one of these values are kept.
        consecutive_slices : int, optional
            An int (>0) that determines the number of consecutive slices of the file loaded at the same time.
            Default is ``1``, loading single slices.
        data_saved_per_slice : bool, optional
            Whether the data is saved per slice or per volume.
        n2r_supervised_rate : Optional[float], optional
            A float between 0 and 1 controlling what fraction of the files is flagged for Noise to
            Reconstruction (N2R) supervised loss. ``None`` is treated as ``0.0``. Default is ``0.0``.
        complex_target : bool, optional
            Whether to use a complex target or not. Default is ``False``.
        log_images_rate : Optional[float], optional
            A float between 0 and 1 controlling what fraction of the slices is logged as images. ``None`` is
            treated as ``1.0``. Default is ``1.0``.
        transform : Optional[Callable], optional
            A callable that preprocesses the raw data into the appropriate form. The transform function should
            take ``kspace``, ``coil sensitivity maps``, ``quantitative maps``, ``mask``, ``initial prediction``,
            ``target``, ``attributes``, ``filename``, and ``slice number`` as inputs. ``target`` may be null for
            test data. Default is ``None``.
        **kwargs
            Additional keyword arguments (ignored).

        Raises
        ------
        ValueError
            If both ``sample_rate`` and ``volume_sample_rate`` are set, or if ``consecutive_slices < 1``.
        """
        super().__init__()
        self.coil_sensitivity_maps_root = coil_sensitivity_maps_root
        self.mask_root = mask_root
        noise = self._load_noise(noise_root)
        self.initial_predictions_root = initial_predictions_root
        self.dataset_format = dataset_format

        # Slice-level and volume-level subsampling are mutually exclusive.
        if not utils.is_none(sample_rate) and not utils.is_none(volume_sample_rate):
            raise ValueError(
                f"Both sample_rate {sample_rate} and volume_sample_rate {volume_sample_rate} are set. "
                "Please set only one of them."
            )
        # Config files may pass the literal string "None"; both forms mean "keep everything".
        if sample_rate is None or sample_rate == "None":
            sample_rate = 1.0
        if volume_sample_rate is None or volume_sample_rate == "None":
            volume_sample_rate = 1.0

        self.dataset_cache_file = (
            None if utils.is_none(dataset_cache_file) else Path(dataset_cache_file)  # type: ignore
        )
        dataset_cache = self._load_dataset_cache() if use_dataset_cache else {}

        if consecutive_slices < 1:
            raise ValueError(f"Consecutive slices {consecutive_slices} is out of range, must be > 0.")
        self.consecutive_slices = consecutive_slices
        self.complex_target = complex_target
        self.transform = transform
        self.data_saved_per_slice = data_saved_per_slice
        self.recons_key = "reconstruction"
        self.examples = []

        # Key the cache by the string form of ``root`` so str and Path roots hit the same entry and the key
        # survives a YAML round trip.
        cache_key = str(root)
        if not use_dataset_cache or dataset_cache.get(cache_key) is None:
            self.examples = self._collect_examples(root, noise, n2r_supervised_rate)
            if use_dataset_cache:
                self._save_dataset_cache(dataset_cache, cache_key)
        else:
            logging.info("Using dataset cache from %s.", self.dataset_cache_file)
            self.examples = self._examples_from_cache(dataset_cache[cache_key])

        self._subsample_examples(sample_rate, volume_sample_rate)  # type: ignore

        if num_cols and not utils.is_none(num_cols):
            self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]

        if utils.is_none(log_images_rate):
            log_images_rate = 1.0
        # NOTE(review): this draws from slice indices (example[1]), not from positions in ``self.examples``, so
        # the same slice index can be drawn for several files — confirm this matches how subclasses use it.
        self.indices_to_log = np.random.choice(
            [example[1] for example in self.examples],
            int(log_images_rate * len(self.examples)),  # type: ignore
            replace=False,
        )

    @staticmethod
    def _load_noise(noise_root: Union[str, Path, os.PathLike, None]) -> Optional[List[Dict]]:
        """Parse the JSON-lines noise file, or return ``None`` if ``noise_root`` is not a ``.json`` path."""
        if not str(noise_root).endswith(".json"):
            return None
        with open(noise_root, "r", encoding="utf-8") as f:  # type: ignore
            # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            return [json.loads(line) for line in f if line.strip()]

    def _load_dataset_cache(self) -> Dict:
        """Read the YAML dataset cache, returning an empty dict if there is no file or it is empty."""
        if self.dataset_cache_file is None or not self.dataset_cache_file.exists():
            return {}
        with open(self.dataset_cache_file, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _save_dataset_cache(self, dataset_cache: Dict, cache_key: str) -> None:
        """Store ``self.examples`` under ``cache_key`` and write the whole cache as YAML."""
        if self.dataset_cache_file is None:
            logging.warning("use_dataset_cache is set but no dataset_cache_file was given; the cache is not saved.")
            return
        dataset_cache[cache_key] = self._to_yaml_safe(self.examples)
        logging.info("Saving dataset cache to %s.", self.dataset_cache_file)
        # yaml.safe_dump emits text and only plain types, so the file can be read back with yaml.safe_load.
        with open(self.dataset_cache_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(dataset_cache, f)

    @staticmethod
    def _to_yaml_safe(value):
        """Recursively convert paths, tuples and NumPy values into plain types ``yaml.safe_dump`` accepts."""
        if isinstance(value, os.PathLike):
            return os.fspath(value)
        if isinstance(value, dict):
            return {key: MRIDataset._to_yaml_safe(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [MRIDataset._to_yaml_safe(item) for item in value]
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    @staticmethod
    def _examples_from_cache(cached: List) -> List[Tuple[Path, int, Dict]]:
        """Rebuild ``(Path, slice_index, metadata)`` tuples from their YAML-safe cached form."""
        return [(Path(fname), int(slice_ind), metadata) for fname, slice_ind, metadata in cached]

    def _collect_examples(
        self,
        root: Union[str, Path, os.PathLike],
        noise: Optional[List[Dict]],
        n2r_supervised_rate: Optional[float],
    ) -> List[Tuple[Path, int, Dict]]:
        """List the dataset files and build one example per slice, with noise and N2R metadata attached."""
        if str(root).endswith(".json"):
            with open(root, "r", encoding="utf-8") as f:
                files = [Path(example) for example in json.load(f)]
        else:
            files = list(Path(root).iterdir())

        n2r_supervised_files = set()
        if not utils.is_none(n2r_supervised_rate) and n2r_supervised_rate != 0.0:
            # Randomly select a subset of files for N2R supervised loss based on n2r_supervised_rate.
            selected = random.sample(files, int(np.round(n2r_supervised_rate * len(files))))  # type: ignore
            logging.info("%s files are selected for N2R supervised loss.", selected)
            n2r_supervised_files = set(selected)

        examples = []
        for fname in sorted(files):
            metadata, num_slices = self._retrieve_metadata(fname)
            metadata["noise_levels"] = self.__parse_noise__(noise, fname) if noise is not None else []
            metadata["n2r_supervised"] = fname in n2r_supervised_files
            examples += [(fname, slice_ind, metadata) for slice_ind in range(num_slices)]
        return examples

    def _subsample_examples(self, sample_rate: float, volume_sample_rate: float) -> None:
        """Randomly keep a fraction of ``self.examples``, either by slice or by volume (file stem)."""
        if sample_rate < 1.0:  # sample by slice
            random.shuffle(self.examples)
            num_examples = round(len(self.examples) * sample_rate)
            self.examples = self.examples[:num_examples]
        elif volume_sample_rate < 1.0:  # sample by volume
            vol_names = sorted({example[0].stem for example in self.examples})
            random.shuffle(vol_names)
            num_volumes = round(len(vol_names) * volume_sample_rate)
            sampled_vols = set(vol_names[:num_volumes])
            self.examples = [example for example in self.examples if example[0].stem in sampled_vols]

    def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
        """Retrieve metadata from a given file.

        Parameters
        ----------
        fname : Union[str, Path]
            Path to an HDF5 file.

        Returns
        -------
        Tuple[Dict, int]
            Metadata dictionary (``padding_left``, ``padding_right``, ``encoding_size``, ``recon_size``,
            ``num_slices``) and the number of slices in the file.

        Raises
        ------
        ValueError
            If the file contains none of ``kspace``, ``reconstruction`` or ``target``.

        Notes
        -----
        If the file has no ``ismrmrd_header``, both paddings are ``0`` and both sizes are ``(0, 0, 0)``. The
        number of slices is ``1`` when ``self.data_saved_per_slice`` is set, otherwise the first dimension of the
        first available dataset among ``kspace``, ``reconstruction`` and ``target``.
        """
        with h5py.File(fname, "r") as hf:
            if "ismrmrd_header" in hf:
                et_root = fromstring(hf["ismrmrd_header"][()])

                enc = ["encoding", "encodedSpace", "matrixSize"]
                enc_size = (
                    int(et_query(et_root, enc + ["x"])),
                    int(et_query(et_root, enc + ["y"])),
                    int(et_query(et_root, enc + ["z"])),
                )
                rec = ["encoding", "reconSpace", "matrixSize"]
                recon_size = (
                    int(et_query(et_root, rec + ["x"])),
                    int(et_query(et_root, rec + ["y"])),
                    int(et_query(et_root, rec + ["z"])),
                )

                params = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
                enc_limits_center = int(et_query(et_root, params + ["center"]))
                enc_limits_max = int(et_query(et_root, params + ["maximum"])) + 1

                padding_left = enc_size[1] // 2 - enc_limits_center
                padding_right = padding_left + enc_limits_max
            else:
                padding_left = 0
                padding_right = 0
                enc_size = (0, 0, 0)
                recon_size = (0, 0, 0)

            if "kspace" in hf:
                shape = hf["kspace"].shape
            elif "reconstruction" in hf:
                shape = hf["reconstruction"].shape
            elif "target" in hf:
                shape = hf["target"].shape
            else:
                raise ValueError(f"{fname} does not contain kspace, reconstruction, or target data.")

        num_slices = 1 if self.data_saved_per_slice else shape[0]
        metadata = {
            "padding_left": padding_left,
            "padding_right": padding_right,
            "encoding_size": enc_size,
            "recon_size": recon_size,
            "num_slices": num_slices,
        }
        return metadata, num_slices

    @staticmethod
    def __parse_noise__(noise: List[Dict], fname: Path) -> List[str]:
        """Collect the noise values recorded for a given file.

        Parameters
        ----------
        noise : List[Dict]
            Parsed noise entries, each with ``"fname"`` and ``"noise"`` keys.
        fname : Path
            File whose noise values to collect; matched on ``fname.name``.

        Returns
        -------
        List[str]
            The ``"noise"`` values of all entries whose ``"fname"`` equals ``fname.name``, in order.
        """
        return [entry["noise"] for entry in noise if entry["fname"] == fname.name]  # type: ignore

    @staticmethod
    def _zero_pad_slices(x, pad_start: int, pad_end: int):
        """Concatenate ``pad_start``/``pad_end`` all-zero slices before/after ``x`` along axis 0."""
        # Zeros share x's dtype so padding does not upcast (e.g. complex64 -> complex128).
        if pad_start > 0:
            x = np.concatenate((np.zeros((pad_start, *x.shape[1:]), dtype=x.dtype), x), axis=0)
        if pad_end > 0:
            x = np.concatenate((x, np.zeros((pad_end, *x.shape[1:]), dtype=x.dtype)), axis=0)
        return x

    def get_consecutive_slices(self, data: Dict, key: str, dataslice: int) -> np.ndarray:
        """Get consecutive slices from a given data dictionary.

        Parameters
        ----------
        data : dict
            Data to extract slices from.
        key : str
            Key to extract slices from.
        dataslice : int
            Slice to index.

        Returns
        -------
        np.ndarray
            If ``self.consecutive_slices`` is 1, the single slice ``data[key][dataslice]``. Otherwise, an array of
            shape ``(self.consecutive_slices, *data[key].shape[1:])`` roughly centred on ``dataslice``. Slices
            that fall outside the volume are zero-filled, so the output always has exactly
            ``self.consecutive_slices`` slices.

        Examples
        --------
        >>> dataset = MRIDataset.__new__(MRIDataset)
        >>> dataset.data_saved_per_slice = False
        >>> data = {"kspace": np.random.rand(10, 640, 368)}
        >>> dataset.consecutive_slices = 1
        >>> dataset.get_consecutive_slices(data, "kspace", 1).shape
        (640, 368)
        >>> dataset.consecutive_slices = 5
        >>> dataset.get_consecutive_slices(data, "kspace", 0).shape
        (5, 640, 368)
        """
        x = data[key]
        if self.data_saved_per_slice:
            # Add a leading slice axis so a per-slice file looks like a one-slice volume.
            x = np.expand_dims(x, axis=0)

        if self.consecutive_slices == 1:
            if x.shape[0] == 1:
                return x[0]
            if x.ndim != 2:
                return x[dataslice]
            return x

        num_slices = x.shape[0]

        # If the volume has no more slices than requested, return all of it, zero-padded evenly on both sides.
        if self.consecutive_slices >= num_slices:
            pad_start = (self.consecutive_slices - num_slices) // 2
            pad_end = self.consecutive_slices - num_slices - pad_start
            return self._zero_pad_slices(x, pad_start, pad_end)

        # Window of exactly ``consecutive_slices`` slices; for even sizes the extra slice falls before dataslice.
        start_slice = dataslice - self.consecutive_slices // 2
        end_slice = start_slice + self.consecutive_slices
        pad_start = max(0, -start_slice)
        pad_end = max(0, end_slice - num_slices)
        extracted_slices = x[max(0, start_slice) : min(end_slice, num_slices)]
        return self._zero_pad_slices(extracted_slices, pad_start, pad_end)

    def __len__(self):
        """Length of :class:`MRIDataset`."""
        return len(self.examples)

    def __getitem__(self, i: int):
        """Get item from :class:`MRIDataset`."""
        raise NotImplementedError