# coding=utf-8
__author__ = "Dimitris Karkalousos"
# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
import json
import logging
import os
import random
import re
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import h5py
import numpy as np
import yaml # type: ignore
from defusedxml.ElementTree import fromstring
from torch.utils.data import Dataset
from atommic.collections.common.data.mri_loader import MRIDataset, et_query
from atommic.collections.common.parts.utils import is_none
class ReconstructionMRIDataset(MRIDataset):
    """A dataset class for accelerated MRI reconstruction.

    Examples
    --------
    >>> from atommic.collections.reconstruction.data.mri_reconstruction_loader import ReconstructionMRIDataset
    >>> dataset = ReconstructionMRIDataset(root='data/train', sample_rate=0.1)
    >>> print(len(dataset))
    100
    >>> kspace, coil_sensitivities, mask, initial_prediction, target, attrs, filename, slice_num = dataset[0]
    >>> print(kspace.shape)
    (30, 640, 368)

    .. note::
        Extends :class:`atommic.collections.common.data.mri_loader.MRIDataset`.
    """

    @staticmethod
    def _select_recons_key(keys) -> str:
        """Choose the key holding the reconstruction target among the keys of an HDF5 file.

        Parameters
        ----------
        keys : Iterable[str]
            Key names of the HDF5 file.

        Returns
        -------
        str
            ``"reconstruction_rss"`` if present, else ``"reconstruction_sense"`` if present, else the first key that
            starts with ``"reconstruction_"``, else ``"target"``.
        """
        # Inspect the real key names instead of regex-matching the repr of ``hf.keys()``, which captured trailing
        # characters such as ``"']>"`` and produced non-existent keys for anything other than rss/sense.
        candidates = [key for key in keys if key.startswith("reconstruction_")]
        for preferred in ("reconstruction_rss", "reconstruction_sense"):
            if preferred in candidates:
                return preferred
        return candidates[0] if candidates else "target"

    def _read_initial_prediction(self, source, dataslice: int) -> np.ndarray:
        """Read the initial prediction for ``dataslice`` from an open HDF5 file.

        ``"reconstruction"`` is looked up first, then ``"initial_prediction"``. The result is squeezed and cast to
        ``complex64``. Returns ``np.empty([])`` when neither key exists in ``source``.
        """
        for key in ("reconstruction", "initial_prediction"):
            if key in source:
                return self.get_consecutive_slices(source, key, dataslice).squeeze().astype(np.complex64)
        return np.empty([])

    def __getitem__(self, i: int):  # noqa: MC0001
        """Get item from :class:`ReconstructionMRIDataset`.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple
            ``(kspace, sensitivity_map, mask, prediction, target, attrs, fname.name, dataslice)`` when no
            ``transform`` is set, otherwise whatever ``self.transform`` returns for these arguments.
        """
        fname, dataslice, metadata = self.examples[i]
        with h5py.File(fname, "r") as hf:
            # Optional statistics stored at file level; copied into ``attrs`` below when present.
            min_val = hf["min"][()] if "min" in hf else None
            max_val = hf["max"][()] if "max" in hf else None
            mean_val = hf["mean"][()] if "mean" in hf else None
            std_val = hf["std"][()] if "std" in hf else None

            kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)

            # Coil sensitivity maps: stored in the same file, or in a separate ``coil_sensitivity_maps_root``.
            sensitivity_map = np.array([])
            if "sensitivity_map" in hf:
                sensitivity_map = self.get_consecutive_slices(hf, "sensitivity_map", dataslice).astype(np.complex64)
            elif "maps" in hf:
                sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
            elif self.coil_sensitivity_maps_root is not None and self.coil_sensitivity_maps_root != "None":
                coil_sensitivity_maps_root = self.coil_sensitivity_maps_root
                split_dir = str(fname).split("/")
                # Append path components of ``fname`` to the maps root until
                # ``<root>/<component>/<parent dir of fname>/<file name>`` exists.
                for j in range(len(split_dir)):
                    coil_sensitivity_maps_root = Path(f"{self.coil_sensitivity_maps_root}/{split_dir[-j]}/")
                    if os.path.exists(coil_sensitivity_maps_root / Path(split_dir[-2]) / fname.name):
                        break
                with h5py.File(Path(coil_sensitivity_maps_root) / Path(split_dir[-2]) / fname.name, "r") as sf:
                    if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())):
                        sensitivity_map = (
                            self.get_consecutive_slices(sf, "sensitivity_map", dataslice)
                            .squeeze()
                            .astype(np.complex64)
                        )

            mask = None
            if "mask" in hf:
                mask = np.asarray(self.get_consecutive_slices(hf, "mask", dataslice))
                if mask.ndim == 3:
                    mask = mask[dataslice]
            elif self.mask_root is not None and self.mask_root != "None":
                with h5py.File(Path(self.mask_root) / fname.name, "r") as mf:
                    mask = np.asarray(self.get_consecutive_slices(mf, "mask", dataslice))

            if not is_none(self.initial_predictions_root):
                # The keys must be looked up in the predictions file itself (previously ``hf`` was checked).
                with h5py.File(Path(self.initial_predictions_root) / fname.name, "r") as ipf:  # type: ignore
                    prediction = self._read_initial_prediction(ipf, dataslice)
            else:
                prediction = self._read_initial_prediction(hf, dataslice)

            if self.complex_target:
                target = None
            else:
                self.recons_key = self._select_recons_key(hf.keys())
                target = self.get_consecutive_slices(hf, self.recons_key, dataslice) if self.recons_key in hf else None

            attrs = dict(hf.attrs)
            # Noise level for the current slice, if ``metadata["noise_levels"]`` is not empty.
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0
            attrs.update(metadata)

        # Move the trailing axis of the maps to the coil position when shapes disagree with k-space.
        if sensitivity_map.shape != kspace.shape and sensitivity_map.ndim > 1:
            if sensitivity_map.ndim == 3:
                sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1))
            elif sensitivity_map.ndim == 4:
                sensitivity_map = np.transpose(sensitivity_map, (0, 3, 1, 2))
            else:
                raise ValueError(
                    f"Sensitivity map has invalid dimensions {sensitivity_map.shape} compared to kspace {kspace.shape}"
                )

        attrs["log_image"] = bool(dataslice in self.indices_to_log)
        if min_val is not None:
            attrs["min"] = min_val
        if max_val is not None:
            attrs["max"] = max_val
        if mean_val is not None:
            attrs["mean"] = mean_val
        if std_val is not None:
            attrs["std"] = std_val

        return (
            (
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
            if self.transform is None
            else self.transform(
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
        )
class CC359ReconstructionMRIDataset(Dataset):
    """Supports the CC359 dataset for accelerated MRI reconstruction.

    .. note::
        Similar to :class:`atommic.collections.common.data.mri_loader.MRIDataset`. It does not extend it because we
        need to override the ``__init__`` and ``__getitem__`` methods.
    """

    def __init__(  # noqa: MC0001
        self,
        root: Union[str, Path, os.PathLike],
        coil_sensitivity_maps_root: Union[str, Path, os.PathLike] = None,
        mask_root: Union[str, Path, os.PathLike] = None,
        noise_root: Union[str, Path, os.PathLike] = None,
        initial_predictions_root: Union[str, Path, os.PathLike] = None,
        dataset_format: str = None,
        sample_rate: Optional[float] = None,
        volume_sample_rate: Optional[float] = None,
        use_dataset_cache: bool = False,
        dataset_cache_file: Union[str, Path, os.PathLike] = None,
        num_cols: Optional[Tuple[int]] = None,
        consecutive_slices: int = 1,
        data_saved_per_slice: bool = False,
        n2r_supervised_rate: Optional[float] = 0.0,
        complex_target: bool = False,
        log_images_rate: Optional[float] = 1.0,
        transform: Optional[Callable] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Inits :class:`CC359ReconstructionMRIDataset`.

        Parameters
        ----------
        root : Union[str, Path, os.PathLike]
            Path to the dataset directory, or to a ``.json`` file listing the dataset files.
        coil_sensitivity_maps_root : Union[str, Path, os.PathLike], optional
            Path to the coil sensitivities maps dataset, if stored separately.
        mask_root : Union[str, Path, os.PathLike], optional
            Path to stored masks, if stored separately.
        noise_root : Union[str, Path, os.PathLike], optional
            Path to stored noise, if stored separately (JSON lines: one ``{"fname": ..., "noise": ...}`` per line).
        initial_predictions_root : Union[str, Path, os.PathLike], optional
            Path to the dataset containing the initial predictions. If provided, the initial predictions will be used
            as the input of the reconstruction network. Default is ``None``.
        dataset_format : str, optional
            The format of the dataset. For example, ``'custom_dataset'`` or ``'public_dataset_name'``.
        sample_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the slices should be loaded. When creating
            subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes)
            but not both.
        volume_sample_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the volumes should be loaded. When creating
            subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes)
            but not both.
        use_dataset_cache : bool, optional
            Whether to cache dataset metadata. This is very useful for large datasets.
        dataset_cache_file : Union[str, Path, os.PathLike, None], optional
            A YAML file in which to cache dataset information for faster load times. It is required for the cache to
            be written: if ``use_dataset_cache`` is ``True`` but no file is given, a warning is logged and nothing is
            saved.
        num_cols : Optional[Tuple[int]], optional
            If provided, only slices with the desired number of columns will be considered.
        consecutive_slices : int, optional
            An int (>0) that determine the amount of consecutive slices of the file to be loaded at the same time.
            Default is ``1``, loading single slices.
        data_saved_per_slice : bool, optional
            Whether the data is saved per slice or per volume.
        n2r_supervised_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the subjects should be loaded for Noise to
            Reconstruction (N2R) supervised loss, if N2R is enabled. Default is ``0.0``.
        complex_target : bool, optional
            Whether to use a complex target or not. Default is ``False``.
        log_images_rate : Optional[float], optional
            A float between 0 and 1. This controls what fraction of the slices should be logged as images. Default is
            ``1.0``.
        transform : Optional[Callable], optional
            A callable that preprocesses the raw data into appropriate form. It is called with ``kspace``,
            ``coil sensitivity maps``, ``mask``, ``initial prediction``, ``target``, ``attributes``, ``filename``,
            and ``slice number``. ``target`` may be null for test data. Default is ``None``.
        **kwargs
            Additional keyword arguments.

        Raises
        ------
        ValueError
            If both ``sample_rate`` and ``volume_sample_rate`` are set, or ``consecutive_slices`` < 1.
        """
        super().__init__()
        self.coil_sensitivity_maps_root = coil_sensitivity_maps_root
        self.mask_root = mask_root

        # Noise levels are stored as JSON lines; blank lines are skipped instead of crashing ``json.loads``.
        if str(noise_root).endswith(".json"):
            with open(noise_root, "r") as f:  # type: ignore # pylint: disable=unspecified-encoding
                noise_root = [json.loads(line) for line in f if line.strip()]  # type: ignore
        else:
            noise_root = None

        self.initial_predictions_root = initial_predictions_root
        self.dataset_format = dataset_format

        # set default sampling mode if none given
        if not is_none(sample_rate) and not is_none(volume_sample_rate):
            raise ValueError(
                f"Both sample_rate {sample_rate} and volume_sample_rate {volume_sample_rate} are set. "
                "Please set only one of them."
            )
        if sample_rate is None or sample_rate == "None":
            sample_rate = 1.0
        if volume_sample_rate is None or volume_sample_rate == "None":
            volume_sample_rate = 1.0

        # The cache is keyed by the string form of ``root`` so it survives a YAML round trip.
        cache_key = str(root)
        self.dataset_cache_file = None if is_none(dataset_cache_file) else Path(dataset_cache_file)  # type: ignore
        if self.dataset_cache_file is not None and self.dataset_cache_file.exists() and use_dataset_cache:
            with open(self.dataset_cache_file, "rb") as f:
                # ``safe_load`` returns ``None`` for an empty file.
                dataset_cache = yaml.safe_load(f) or {}
        else:
            dataset_cache = {}

        if consecutive_slices < 1:
            raise ValueError(f"Consecutive slices {consecutive_slices} is out of range, must be > 0.")
        self.consecutive_slices = consecutive_slices
        self.complex_target = complex_target
        self.transform = transform
        self.data_saved_per_slice = data_saved_per_slice
        self.recons_key = "reconstruction"

        self.examples = []
        # Check if our dataset is in the cache. If yes, use that metadata, if not, then regenerate the metadata.
        if dataset_cache.get(cache_key) is None or not use_dataset_cache:
            if str(root).endswith(".json"):
                with open(root, "r") as f:  # pylint: disable=unspecified-encoding
                    files = [Path(example) for example in json.load(f)]
            else:
                files = list(Path(root).iterdir())

            n2r_supervised_files = set()
            if not is_none(n2r_supervised_rate) and n2r_supervised_rate != 0.0:
                # randomly select a subset of files for N2R supervised loss based on n2r_supervised_rate
                n2r_supervised_files = set(
                    random.sample(files, int(np.round(n2r_supervised_rate * len(files))))  # type: ignore
                )
                # Logged once, not once per file.
                logging.info("%s files are selected for N2R supervised loss.", sorted(n2r_supervised_files))

            for fname in sorted(files):
                metadata, num_slices = self._retrieve_metadata(fname)
                metadata["noise_levels"] = (
                    self.__parse_noise__(noise_root, fname) if noise_root is not None else []  # type: ignore
                )
                metadata["n2r_supervised"] = fname in n2r_supervised_files
                if not is_none(num_slices) and not is_none(consecutive_slices):
                    num_slices = num_slices - (consecutive_slices - 1)
                # CC359 specific: drop the edge slices, keeping only indices strictly between 50 and num_slices - 50.
                self.examples += [
                    (fname, slice_ind, metadata) for slice_ind in range(num_slices) if 50 < slice_ind < num_slices - 50
                ]

            if dataset_cache.get(cache_key) is None and use_dataset_cache:
                if self.dataset_cache_file is None:
                    logging.warning(
                        "use_dataset_cache is True but no dataset_cache_file was given; the dataset cache is not saved."
                    )
                else:
                    dataset_cache[cache_key] = self._examples_to_cache(self.examples)
                    logging.info("Saving dataset cache to %s.", self.dataset_cache_file)
                    # Text mode + safe_dump: ``yaml.dump`` into a binary file raised, and python-object tags
                    # (e.g. for ``Path``) could not be read back by ``yaml.safe_load``.
                    with open(self.dataset_cache_file, "w", encoding="utf-8") as f:
                        yaml.safe_dump(dataset_cache, f)
        else:
            logging.info("Using dataset cache from %s.", self.dataset_cache_file)
            self.examples = self._examples_from_cache(dataset_cache[cache_key])

        # subsample if desired
        if sample_rate < 1.0:  # sample by slice
            random.shuffle(self.examples)
            num_examples = round(len(self.examples) * sample_rate)
            self.examples = self.examples[:num_examples]
        elif volume_sample_rate < 1.0:  # sample by volume
            vol_names = sorted({f[0].stem for f in self.examples})
            random.shuffle(vol_names)
            num_volumes = round(len(vol_names) * volume_sample_rate)
            sampled_vols = set(vol_names[:num_volumes])
            self.examples = [example for example in self.examples if example[0].stem in sampled_vols]

        if num_cols and not is_none(num_cols):
            self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]

        self.indices_to_log = np.random.choice(
            len(self.examples), int(log_images_rate * len(self.examples)), replace=False  # type: ignore
        )

    @staticmethod
    def _examples_to_cache(examples: List[Tuple]) -> List[List]:
        """Convert ``(fname, slice_ind, metadata)`` examples into YAML-safe ``[str, int, dict]`` entries."""
        return [[str(fname), int(slice_ind), metadata] for fname, slice_ind, metadata in examples]

    @staticmethod
    def _examples_from_cache(entries: List) -> List[Tuple]:
        """Rebuild ``(Path, int, dict)`` example tuples from entries produced by ``_examples_to_cache``."""
        return [(Path(fname), int(slice_ind), metadata) for fname, slice_ind, metadata in entries]

    def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
        """Retrieve metadata from a given file.

        Parameters
        ----------
        fname : Union[str, Path]
            Path to file.

        Returns
        -------
        Tuple[Dict, int]
            Metadata dictionary and number of slices in the file.

        Raises
        ------
        ValueError
            If the file contains none of ``kspace``, ``reconstruction`` or ``target``.

        Examples
        --------
        >>> metadata, num_slices = dataset._retrieve_metadata("file.h5")
        >>> metadata
        {'padding_left': 0, 'padding_right': 0, 'encoding_size': (0, 0, 0), 'recon_size': (0, 0, 0), 'num_slices': 1}
        >>> num_slices
        1
        """
        with h5py.File(fname, "r") as hf:
            if "ismrmrd_header" in hf:
                et_root = fromstring(hf["ismrmrd_header"][()])

                enc = ["encoding", "encodedSpace", "matrixSize"]
                enc_size = (
                    int(et_query(et_root, enc + ["x"])),
                    int(et_query(et_root, enc + ["y"])),
                    int(et_query(et_root, enc + ["z"])),
                )
                rec = ["encoding", "reconSpace", "matrixSize"]
                recon_size = (
                    int(et_query(et_root, rec + ["x"])),
                    int(et_query(et_root, rec + ["y"])),
                    int(et_query(et_root, rec + ["z"])),
                )

                params = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
                enc_limits_center = int(et_query(et_root, params + ["center"]))
                enc_limits_max = int(et_query(et_root, params + ["maximum"])) + 1

                padding_left = enc_size[1] // 2 - enc_limits_center
                padding_right = padding_left + enc_limits_max
            else:
                padding_left = 0
                padding_right = 0
                enc_size = (0, 0, 0)
                recon_size = (0, 0, 0)

            if "kspace" in hf:
                shape = hf["kspace"].shape
            elif "reconstruction" in hf:
                shape = hf["reconstruction"].shape
            elif "target" in hf:
                shape = hf["target"].shape
            else:
                raise ValueError(f"{fname} does not contain kspace, reconstruction, or target data.")

        num_slices = 1 if self.data_saved_per_slice else shape[0]
        metadata = {
            "padding_left": padding_left,
            "padding_right": padding_right,
            "encoding_size": enc_size,
            "recon_size": recon_size,
            "num_slices": num_slices,
        }

        return metadata, num_slices

    @staticmethod
    def __parse_noise__(noise: List[Dict], fname: Path) -> List:
        """Collect the noise values recorded for a given file.

        Parameters
        ----------
        noise : List[Dict]
            Parsed noise records, each with ``"fname"`` and ``"noise"`` entries.
        fname : Path
            File whose noise values are wanted; matched on ``fname.name``.

        Returns
        -------
        List
            The ``"noise"`` values of all records whose ``"fname"`` equals ``fname.name``, in record order.
        """
        return [record["noise"] for record in noise if record["fname"] == fname.name]  # type: ignore

    @staticmethod
    def _select_recons_key(keys) -> str:
        """Choose the key holding the reconstruction target among the keys of an HDF5 file.

        Returns ``"reconstruction_rss"`` if present, else ``"reconstruction_sense"`` if present, else the first key
        that starts with ``"reconstruction_"``, else ``"target"``.
        """
        # Inspect the real key names instead of regex-matching the repr of ``hf.keys()``, which captured trailing
        # characters such as ``"']>"`` and produced non-existent keys for anything other than rss/sense.
        candidates = [key for key in keys if key.startswith("reconstruction_")]
        for preferred in ("reconstruction_rss", "reconstruction_sense"):
            if preferred in candidates:
                return preferred
        return candidates[0] if candidates else "target"

    @staticmethod
    def _pad_with_zero_slices(x: np.ndarray, pad_start: int, pad_end: int) -> np.ndarray:
        """Add ``pad_start`` / ``pad_end`` all-zero slices (same dtype as ``x``) before / after ``x`` on axis 0."""
        parts = []
        if pad_start > 0:
            parts.append(np.zeros((pad_start, *x.shape[1:]), dtype=x.dtype))
        parts.append(x)
        if pad_end > 0:
            parts.append(np.zeros((pad_end, *x.shape[1:]), dtype=x.dtype))
        return np.concatenate(parts, axis=0) if len(parts) > 1 else x

    def get_consecutive_slices(self, data: Dict, key: str, dataslice: int) -> np.ndarray:
        """Get consecutive slices from a given data dictionary.

        Parameters
        ----------
        data : dict
            Data to extract slices from (an open HDF5 file or a dict of arrays).
        key : str
            Key to extract slices from.
        dataslice : int
            Slice to index.

        Returns
        -------
        np.ndarray
            Array of consecutive slices centred on ``dataslice`` and zero-padded at the volume edges. If
            ``self.consecutive_slices`` is 1, the single slice ``data[key][dataslice]`` is returned (2-D data is
            returned whole).

        Examples
        --------
        >>> data = {"kspace": np.random.rand(10, 640, 368)}
        >>> dataset.consecutive_slices, dataset.data_saved_per_slice = 1, False
        >>> dataset.get_consecutive_slices(data, "kspace", 1).shape
        (640, 368)
        >>> dataset.consecutive_slices = 5
        >>> dataset.get_consecutive_slices(data, "kspace", 1).shape
        (5, 640, 368)
        """
        x = data[key]

        if self.data_saved_per_slice:
            x = np.expand_dims(x, axis=0)

        if self.consecutive_slices == 1:
            if x.shape[0] == 1:
                return x[0]
            if x.ndim != 2:
                return x[dataslice]
            # ``np.asarray`` reads an h5py dataset into memory instead of returning the dataset handle.
            return np.asarray(x)

        num_slices = x.shape[0]

        # Requested window covers the whole stack: return all slices, zero-padded symmetrically.
        if self.consecutive_slices >= num_slices:
            x = np.asarray(x)
            pad_start = (self.consecutive_slices - num_slices) // 2
            pad_end = self.consecutive_slices - num_slices - pad_start
            return self._pad_with_zero_slices(x, pad_start, pad_end)

        # NOTE(review): for an even ``consecutive_slices`` this window holds consecutive_slices + 1 slices; kept
        # unchanged for compatibility with existing models.
        half_slices = self.consecutive_slices // 2
        start_slice = dataslice - half_slices
        end_slice = dataslice + half_slices + 1  # exclusive

        pad_start = max(0, -start_slice)
        pad_end = max(0, end_slice - num_slices)
        extracted_slices = np.asarray(x[max(0, start_slice) : min(end_slice, num_slices)])
        return self._pad_with_zero_slices(extracted_slices, pad_start, pad_end)

    def _read_initial_prediction(self, source, dataslice: int) -> np.ndarray:
        """Read the initial prediction for ``dataslice`` from an open HDF5 file.

        ``"reconstruction"`` is looked up first, then ``"initial_prediction"``. The result is squeezed and cast to
        ``complex64``. Returns ``np.empty([])`` when neither key exists in ``source``.
        """
        for key in ("reconstruction", "initial_prediction"):
            if key in source:
                return self.get_consecutive_slices(source, key, dataslice).squeeze().astype(np.complex64)
        return np.empty([])

    def __len__(self):
        """Length of :class:`CC359ReconstructionMRIDataset`."""
        return len(self.examples)

    def __getitem__(self, i: int):  # noqa: MC0001
        """Get item from :class:`CC359ReconstructionMRIDataset`.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple
            ``(kspace, sensitivity_map, mask, prediction, target, attrs, fname.name, dataslice)`` when no
            ``transform`` is set, otherwise whatever ``self.transform`` returns for these arguments.
        """
        fname, dataslice, metadata = self.examples[i]
        with h5py.File(fname, "r") as hf:
            kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)
            # Real and imaginary parts are interleaved along the last axis (even = real, odd = imaginary); combine
            # them and move that axis to the front (presumably the coil axis - confirm against the data).
            kspace = np.transpose(kspace[..., ::2] + 1j * kspace[..., 1::2], (2, 0, 1))

            # Coil sensitivity maps: stored in the same file, or in a separate ``coil_sensitivity_maps_root``.
            sensitivity_map = np.array([])
            if "sensitivity_map" in hf:
                sensitivity_map = self.get_consecutive_slices(hf, "sensitivity_map", dataslice).astype(np.complex64)
            elif "maps" in hf:
                sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
            elif self.coil_sensitivity_maps_root is not None and self.coil_sensitivity_maps_root != "None":
                coil_sensitivity_maps_root = self.coil_sensitivity_maps_root
                split_dir = str(fname).split("/")
                # Append path components of ``fname`` to the maps root until
                # ``<root>/<component>/<parent dir of fname>/<file name>`` exists.
                for j in range(len(split_dir)):
                    coil_sensitivity_maps_root = Path(f"{self.coil_sensitivity_maps_root}/{split_dir[-j]}/")
                    if os.path.exists(coil_sensitivity_maps_root / Path(split_dir[-2]) / fname.name):
                        break
                with h5py.File(Path(coil_sensitivity_maps_root) / Path(split_dir[-2]) / fname.name, "r") as sf:
                    if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())):
                        sensitivity_map = (
                            self.get_consecutive_slices(sf, "sensitivity_map", dataslice)
                            .squeeze()
                            .astype(np.complex64)
                        )

            # One mask per dataset in the separate mask file, in key order.
            mask = None
            if self.mask_root is not None and self.mask_root != "None":
                with h5py.File(Path(self.mask_root) / fname.name, "r") as mf:
                    mask = [np.asarray(self.get_consecutive_slices(mf, key, dataslice)) for key in mf.keys()]

            if not is_none(self.initial_predictions_root):
                # The keys must be looked up in the predictions file itself (previously ``hf`` was checked).
                with h5py.File(Path(self.initial_predictions_root) / fname.name, "r") as ipf:  # type: ignore
                    prediction = self._read_initial_prediction(ipf, dataslice)
            else:
                prediction = self._read_initial_prediction(hf, dataslice)

            if self.complex_target:
                target = None
            else:
                self.recons_key = self._select_recons_key(hf.keys())
                target = self.get_consecutive_slices(hf, self.recons_key, dataslice) if self.recons_key in hf else None

            attrs = dict(hf.attrs)
            # Noise level for the current slice, if ``metadata["noise_levels"]`` is not empty.
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0
            attrs.update(metadata)

        # Move the trailing axis of the maps to the coil position when shapes disagree with k-space.
        if sensitivity_map.shape != kspace.shape and sensitivity_map.ndim > 1:
            if sensitivity_map.ndim == 3:
                sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1))
            elif sensitivity_map.ndim == 4:
                sensitivity_map = np.transpose(sensitivity_map, (0, 3, 1, 2))
            else:
                raise ValueError(
                    f"Sensitivity map has invalid dimensions {sensitivity_map.shape} compared to kspace {kspace.shape}"
                )

        attrs["log_image"] = bool(dataslice in self.indices_to_log)

        return (
            (
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
            if self.transform is None
            else self.transform(
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
        )
class SKMTEAReconstructionMRIDataset(MRIDataset):
    """Supports the SKM-TEA dataset for accelerated MRI reconstruction.

    ``dataset_format`` selects the echo handling (``"skm-tea-echo1"``, ``"skm-tea-echo2"``, ``"skm-tea-echo1+echo2"``
    or ``"skm-tea-echo1+echo2-mc"``, optionally combined with ``"custom_masking"``).

    .. note::
        Extends :class:`atommic.collections.common.data.mri_loader.MRIDataset`.
    """

    def _read_initial_prediction(self, source, dataslice: int) -> np.ndarray:
        """Read the initial prediction for ``dataslice`` from an open HDF5 file.

        ``"reconstruction"`` is looked up first, then ``"initial_prediction"``. The result is squeezed and cast to
        ``complex64``. Returns ``np.empty([])`` when neither key exists in ``source``.
        """
        for key in ("reconstruction", "initial_prediction"):
            if key in source:
                return self.get_consecutive_slices(source, key, dataslice).squeeze().astype(np.complex64)
        return np.empty([])

    def __getitem__(self, i: int):  # noqa: MC0001
        """Get item from :class:`SKMTEAReconstructionMRIDataset`.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple
            ``(kspace, sensitivity_map, mask, prediction, target, attrs, fname.name, dataslice)`` when no
            ``transform`` is set, otherwise whatever ``self.transform`` returns for these arguments. ``mask`` is an
            empty array with custom masking, otherwise a dict of the masks stored in the file.
        """
        if not is_none(self.dataset_format):
            dataset_format = self.dataset_format.lower()  # type: ignore
            masking = "default"
            if "custom_masking" in dataset_format:
                masking = "custom"
                dataset_format = dataset_format.replace("custom_masking", "").strip("_")
        else:
            dataset_format = None
            masking = "custom"

        fname, dataslice, metadata = self.examples[i]
        with h5py.File(fname, "r") as hf:
            kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)

            # The third axis is indexed as the echo dimension (assumption from the format names - confirm).
            if dataset_format == "skm-tea-echo1":
                kspace = kspace[:, :, 0, :]
            elif dataset_format == "skm-tea-echo2":
                kspace = kspace[:, :, 1, :]
            elif dataset_format == "skm-tea-echo1+echo2":
                kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
            elif dataset_format == "skm-tea-echo1+echo2-mc":
                kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
            else:
                warnings.warn(
                    f"Dataset format {dataset_format} is either not supported or set to None. "
                    "Using by default only the first echo."
                )
                kspace = kspace[:, :, 0, :]

            # Fixed crop: drop 48 samples from each end of axis 0 and 40 from each end of axis 1.
            kspace = kspace[48:-48, 40:-40]

            sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
            # Keep only the first entry of the last axis (presumably the first set of maps - confirm).
            sensitivity_map = sensitivity_map[..., 0]
            sensitivity_map = sensitivity_map[48:-48, 40:-40]

            if masking == "custom":
                mask = np.array([])
            else:
                # Key each mask by the part of its dataset name after the last "_" and before the first ".".
                mask = {key.split("_")[-1].split(".")[0]: np.asarray(val) for key, val in hf["masks"].items()}

            if not is_none(self.initial_predictions_root):
                # The keys must be looked up in the predictions file itself (previously ``hf`` was checked).
                with h5py.File(Path(self.initial_predictions_root) / fname.name, "r") as ipf:  # type: ignore
                    prediction = self._read_initial_prediction(ipf, dataslice)
            else:
                prediction = self._read_initial_prediction(hf, dataslice)

            if self.complex_target:
                target = None
            else:
                self.recons_key = "target"
                target = self.get_consecutive_slices(hf, self.recons_key, dataslice) if self.recons_key in hf else None

            attrs = dict(hf.attrs)
            # Noise level for the current slice, if ``metadata["noise_levels"]`` is not empty.
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0
            attrs.update(metadata)

        # Move the last axis to the front for both k-space and maps.
        kspace = np.transpose(kspace, (2, 0, 1))
        sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))

        attrs["log_image"] = bool(dataslice in self.indices_to_log)

        return (
            (
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
            if self.transform is None
            else self.transform(
                kspace,
                sensitivity_map,
                mask,
                prediction,
                target,
                attrs,
                fname.name,
                dataslice,
            )
        )
class StanfordKneesReconstructionMRIDataset(MRIDataset):
    """Supports the Stanford Knees 2019 dataset for accelerated MRI reconstruction.

    Only k-space and, when available, coil sensitivity maps are read. The mask and initial prediction are returned as
    ``None`` and the target as an empty array.

    .. note::
        Extends :class:`atommic.collections.common.data.mri_loader.MRIDataset`.
    """

    def __getitem__(self, i: int):
        """Load the example at position ``i``.

        Parameters
        ----------
        i : int
            Index into ``self.examples``.

        Returns
        -------
        tuple
            ``(kspace, sensitivity_map, None, None, np.array([]), attrs, fname.name, dataslice)`` without a
            ``transform``, otherwise the result of calling ``self.transform`` with those values.
        """
        fname, dataslice, metadata = self.examples[i]

        with h5py.File(fname, "r") as hf:
            kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)
            attrs = dict(hf.attrs)

            sensitivity_map = np.array([])
            # Maps stored next to the k-space take precedence over an external maps directory.
            stored_key = next((name for name in ("sensitivity_map", "maps") if name in hf), None)
            maps_root_given = self.coil_sensitivity_maps_root is not None and self.coil_sensitivity_maps_root != "None"
            if stored_key is not None:
                sensitivity_map = self.get_consecutive_slices(hf, stored_key, dataslice).astype(np.complex64)
            elif maps_root_given:
                path_parts = str(fname).split("/")
                parent_name = path_parts[-2]
                maps_dir = self.coil_sensitivity_maps_root
                # Try each path component of ``fname`` under the maps root until the maps file is found.
                for offset in range(len(path_parts)):
                    maps_dir = Path(f"{self.coil_sensitivity_maps_root}/{path_parts[-offset]}/")
                    if os.path.exists(maps_dir / Path(parent_name) / fname.name):
                        break
                maps_file = Path(maps_dir) / Path(parent_name) / fname.name
                with h5py.File(maps_file, "r") as sf:
                    first_key = next(iter(sf.keys()))
                    if "sensitivity_map" in sf or "sensitivity_map" in first_key:
                        raw_maps = self.get_consecutive_slices(sf, "sensitivity_map", dataslice)
                        sensitivity_map = raw_maps.squeeze().astype(np.complex64)

            # Per-slice noise level when a non-empty list is available, 1.0 otherwise.
            if "noise_levels" in metadata and len(metadata["noise_levels"]) > 0:
                metadata["noise"] = metadata["noise_levels"][dataslice]
            else:
                metadata["noise"] = 1.0
            attrs.update(metadata)
            attrs["log_image"] = bool(dataslice in self.indices_to_log)

        sample = (kspace, sensitivity_map, None, None, np.array([]), attrs, fname.name, dataslice)
        if self.transform is None:
            return sample
        return self.transform(*sample)