# Source code for atommic.collections.common.parts.transforms
# coding=utf-8
from __future__ import annotations
__author__ = "Dimitris Karkalousos"
import os
from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from atommic.collections.common.parts.coil_sensitivity_maps import EspiritCalibration
from atommic.collections.common.parts.fft import fft2, ifft2
from atommic.collections.common.parts.utils import add_coil_dim_if_singlecoil, apply_mask, center_crop
from atommic.collections.common.parts.utils import coil_combination_method as coil_combination_method_func
from atommic.collections.common.parts.utils import is_none, reshape_fortran, rss, to_tensor
from atommic.collections.motioncorrection.parts.motionsimulation import MotionSimulation
# Names exported by this module when it is star-imported.
__all__ = [
    "Composer",
    "Cropper",
    "EstimateCoilSensitivityMaps",
    "GeometricDecompositionCoilCompression",
    "Masker",
    "MRIDataTransforms",
    "N2R",
    "NoisePreWhitening",
    "Normalizer",
    "SNREstimator",
    "SSDU",
    "ZeroFillingPadding",
]
class Composer:
    """Composes multiple transforms together.

    Every transform is invoked as ``transform(data, apply_backward_transform, apply_forward_transform)`` and its
    return value is handed to the next transform. Entries for which ``is_none`` is true are skipped.

    Returns
    -------
    composed_data: torch.Tensor
        Composed data.

    Example
    -------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import Composer, Masker, Normalizer
    >>> data = torch.randn(1, 32, 320, 320, 2) + 1j * torch.randn(1, 32, 320, 320, 2)
    >>> print(torch.min(torch.abs(data)), torch.max(torch.abs(data)))
    tensor(1e-06) tensor(1.4142)
    >>> masker = Masker(mask_func="random", padding="reflection", seed=0)
    >>> normalizer = Normalizer(normalization_type="max")
    >>> composer = Composer([masker, normalizer])
    >>> composed_data = composer(data)
    >>> print(torch.min(torch.abs(composed_data)), torch.max(torch.abs(composed_data)))
    tensor(0.) tensor(1.)
    """

    def __init__(self, transforms: Union[List[Callable], Callable, None]):
        """Inits :class:`Composer`.

        Parameters
        ----------
        transforms: list, callable or None
            List of transforms to compose. A single callable is treated as a one-element list and ``None`` means
            "no transforms".
        """
        self.transforms = transforms

    def __call__(
        self,
        data: Union[torch.Tensor, List[torch.Tensor], None],
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> List[torch.Tensor] | torch.Tensor:
        """Calls :class:`Composer`.

        Parameters
        ----------
        data : torch.Tensor or list of torch.Tensor or None
            Input passed to the first transform.
        apply_backward_transform : bool
            Forwarded unchanged to every transform. Default is ``False``.
        apply_forward_transform : bool
            Forwarded unchanged to every transform. Default is ``False``.

        Returns
        -------
        The output of the last applied transform, or ``data`` itself when there is nothing to apply.
        """
        transforms = self.transforms
        # The signature allows ``None``; iterating over it would raise a TypeError.
        if transforms is None:
            return data  # type: ignore
        # The signature also allows a bare callable; wrap anything that cannot be iterated.
        if not hasattr(transforms, "__iter__"):
            transforms = [transforms]  # type: ignore
        for transform in transforms:  # type: ignore
            if not is_none(transform):
                data = transform(data, apply_backward_transform, apply_forward_transform)
        return data  # type: ignore

    def __repr__(self):
        """Representation of :class:`Composer`."""
        return f"Composed transforms: {self.transforms}"

    def __str__(self):
        """String representation of :class:`Composer`."""
        return self.__repr__()
class Cropper:
    """Center-crops data to a requested spatial size.

    The requested size is never allowed to exceed the current size of the data along the two spatial dimensions.

    Returns
    -------
    cropped_data : torch.Tensor
        Cropped data.

    Example
    -------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import Cropper
    >>> data = torch.randn(1, 15, 320, 320, 2)
    >>> cropping = Cropper(cropping_size=(256, 256), spatial_dims=(-2, -1))  # don't account for complex dim
    >>> cropped_data = cropping(data)
    >>> cropped_data.shape
    [1, 15, 256, 256, 2]
    """

    def __init__(
        self,
        cropping_size: Tuple,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
    ):
        """Inits :class:`Cropper`.

        Parameters
        ----------
        cropping_size : tuple
            Size of the cropped data.
        fft_centered : bool
            If True, the input is assumed to be centered in the frequency domain. Default is `False`.
        fft_normalization : str
            Normalization of the FFT. Default is `backward`.
        spatial_dims : tuple
            Spatial dimensions.
        """
        self.cropping_size = cropping_size
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

    def __call__(
        self,
        data: Union[torch.Tensor, List[torch.Tensor], None],
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> List[torch.Tensor] | torch.Tensor:
        """Calls :class:`Cropper`.

        A non-empty list is cropped element by element. A tensor is cropped only when it has more than one
        dimension and its mean is not exactly 1; anything else is returned untouched.

        Parameters
        ----------
        data : torch.Tensor
            Input data to crop.
        apply_backward_transform : bool
            Apply backward transform, i.e. Inverse Fast Fourier Transform. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform, i.e. Fast Fourier Transform. Default is ``False``.
        """
        if is_none(data):
            return data  # type: ignore
        if isinstance(data, list) and len(data) > 0:
            cropped = []
            for element in data:
                cropped.append(self.forward(element, apply_backward_transform, apply_forward_transform))
            return cropped
        if data.dim() > 1 and data.mean() != 1:  # type: ignore
            return self.forward(data, apply_backward_transform, apply_forward_transform)  # type: ignore
        return data  # type: ignore

    def __repr__(self):
        """Representation of :class:`Cropper`."""
        return f"Data will be cropped to size={self.cropping_size}."

    def __str__(self):
        """String representation of :class:`Cropper`."""
        return self.__repr__()

    def forward(
        self,
        data: torch.Tensor,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> torch.Tensor:
        """Forward pass of :class:`Cropper`.

        The data are optionally moved to the other domain, cropped there, and moved back, so the output lives in
        the same domain as the input.

        Parameters
        ----------
        data : torch.Tensor
            Input data to crop.
        apply_backward_transform : bool
            Apply backward transform, i.e. Inverse Fast Fourier Transform. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform, i.e. Fast Fourier Transform. Default is ``False``.
        """
        fft_options = {
            "centered": self.fft_centered,
            "normalization": self.fft_normalization,
            "spatial_dims": self.spatial_dims,
        }

        # Move to the domain in which the crop has to happen.
        if apply_backward_transform:
            data = ifft2(data, **fft_options)
        elif apply_forward_transform:
            data = fft2(data, **fft_options)

        # A trailing dimension of size 2 is treated as (real, imag); a trailing singleton is dropped temporarily.
        trailing = data.shape[-1]
        has_complex_dim = trailing == 2
        has_singleton_dim = trailing == 1
        if has_complex_dim:
            data = torch.view_as_complex(data)
        elif has_singleton_dim:
            data = data.squeeze(-1)

        # Never ask for more than what is available along each spatial dimension.
        available_h = data.shape[self.spatial_dims[0]]
        available_w = data.shape[self.spatial_dims[1]]
        target_h = min(int(self.cropping_size[0]), available_h)
        target_w = min(int(self.cropping_size[1]), available_w)

        data = center_crop(data, (int(target_h), int(target_w)))

        # Restore the trailing dimension removed above.
        if has_complex_dim:
            data = torch.view_as_real(data)
        elif has_singleton_dim:
            data = data.unsqueeze(-1)

        # Undo the domain change applied before cropping.
        if apply_backward_transform:
            data = fft2(data, **fft_options)
        elif apply_forward_transform:
            data = ifft2(data, **fft_options)

        return data
class EstimateCoilSensitivityMaps:
    r"""Data Transformer for training MRI reconstruction models.

    Estimates sensitivity maps given masked k-space data using one of three methods:

    * Unit: unit sensitivity map in case of single coil acquisition.
    * RSS-estimate: sensitivity maps estimated by using the root-sum-of-squares of the autocalibration-signal.
    * ESPIRIT: sensitivity maps estimated with the ESPIRIT method [Uecker2014]_.

    References
    ----------
    .. [Uecker2014] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an
        eigenvalue approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014
        Mar;71(3):990-1001. doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
    """

    def __init__(
        self,
        coil_sensitivity_maps_type: str = "espirit",
        gaussian_sigma: Optional[float] = None,
        espirit_threshold: float = 0.05,
        espirit_kernel_size: int = 6,
        espirit_crop: float = 0.95,
        espirit_max_iters: int = 30,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
        coil_dim: int = 1,
    ) -> None:
        """Inits :class:`EstimateCoilSensitivityMaps`.

        Parameters
        ----------
        coil_sensitivity_maps_type: str
            Type of sensitivity map to estimate. One of "unit", "rss", "espirit". Default is ``"espirit"``.
        gaussian_sigma: float, optional
            If set and non-zero, the ACS k-space is additionally weighted with a Gaussian window of this sigma
            before the ACS image is computed.
        espirit_threshold: float
            Threshold for the calibration matrix when the type is "espirit". Default: 0.05.
        espirit_kernel_size: int
            Kernel size for the calibration matrix when the type is "espirit". Also used as the side length of the
            central ACS region. Default: 6.
        espirit_crop: float
            Output eigenvalue cropping threshold when the type is "espirit". Default: 0.95.
        espirit_max_iters: int
            Power method iterations when the type is "espirit". Default: 30.
        fft_centered: bool
            Whether to center the FFT. Default is ``False``.
        fft_normalization: str
            Normalization to apply to the FFT. Default is ``"backward"``.
        spatial_dims: Sequence[int]
            Spatial dimensions of the input. Default is ``(-2, -1)``.
        coil_dim: int
            Dimension corresponding to coil. Default: 1.

        Raises
        ------
        ValueError
            If ``coil_sensitivity_maps_type`` is not one of "unit", "rss", "espirit".
        """
        super().__init__()
        self.coil_sensitivity_maps_type = coil_sensitivity_maps_type
        if self.coil_sensitivity_maps_type not in ["unit", "rss", "espirit"]:
            raise ValueError(
                f"Expected type of map to be either `unit`, `rss`, `espirit`. Got {self.coil_sensitivity_maps_type}."
            )
        self.gaussian_sigma = gaussian_sigma
        self.espirit_threshold = espirit_threshold
        self.espirit_kernel_size = espirit_kernel_size
        self.espirit_crop = espirit_crop
        self.espirit_max_iters = espirit_max_iters
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims
        self.coil_dim = coil_dim

        # The calibrator is only needed (and only built) for the ESPIRiT method.
        if self.coil_sensitivity_maps_type == "espirit":
            self.espirit_calibrator = EspiritCalibration(
                self.espirit_threshold,
                self.espirit_kernel_size,
                self.espirit_crop,
                self.espirit_max_iters,
                self.fft_centered,
                self.fft_normalization,
                self.spatial_dims,
            )

    def calculate_acs_mask(self, kspace: torch.Tensor) -> torch.Tensor:
        """Calculates the autocalibration (ACS) mask.

        The mask is one inside a central square of side ``espirit_kernel_size`` and zero elsewhere. Its spatial
        extent is read from ``kspace.shape[-3]`` and ``kspace.shape[-2]``.

        Parameters
        ----------
        kspace : torch.Tensor
            K-space.

        Returns
        -------
        acs_mask: torch.Tensor
            Autocalibration mask of shape ``(1, Nx, Ny, 1)``, created on the device of ``kspace``.
        """
        # size of k-space
        Nx = kspace.shape[-3]
        Ny = kspace.shape[-2]

        # create an empty mask on the same device as the k-space, so that it can be multiplied with it
        acs_mask = torch.zeros((Nx, Ny), device=kspace.device)

        # define the indices for the center region
        acs_start_x = int((Nx - self.espirit_kernel_size) / 2)
        acs_start_y = int((Ny - self.espirit_kernel_size) / 2)
        acs_end_x = int((Nx + self.espirit_kernel_size) / 2)
        acs_end_y = int((Ny + self.espirit_kernel_size) / 2)

        # set the center region to 1
        acs_mask[acs_start_x:acs_end_x, acs_start_y:acs_end_y] = 1

        # add leading and trailing singleton dimensions so the mask broadcasts against the k-space
        return acs_mask.unsqueeze(0).unsqueeze(-1)

    def estimate_acs_image(self, acs_mask: torch.Tensor, kspace: torch.Tensor, width_dim: int = -2) -> torch.Tensor:
        """Estimates the autocalibration (ACS) image by sampling the k-space using the ACS mask.

        Parameters
        ----------
        acs_mask : torch.Tensor
            Autocalibration mask.
        kspace : torch.Tensor
            K-space.
        width_dim: int
            Dimension corresponding to width. Default: -2.

        Returns
        -------
        acs_image: torch.Tensor
            Estimate of the ACS image.
        """
        if not self.gaussian_sigma:
            kspace_acs = kspace * acs_mask + 0.0  # + 0.0 removes the sign of zeros.
        else:
            width = kspace.size(width_dim)
            gaussian_mask = torch.linspace(-1, 1, width, dtype=kspace.dtype, device=kspace.device)
            gaussian_mask = torch.exp(-((gaussian_mask / self.gaussian_sigma) ** 2))
            # Shape of ones everywhere except along the width dimension, so the window broadcasts over kspace.
            gaussian_mask_shape = [1] * kspace.dim()
            gaussian_mask_shape[width_dim] = width
            gaussian_mask = gaussian_mask.reshape(gaussian_mask_shape)
            kspace_acs = kspace * acs_mask * gaussian_mask + 0.0

        # Get complex-valued data solution
        # Expected shape (batch, coil, height, width, complex=2) -- TODO confirm against callers
        acs_image = ifft2(kspace_acs, self.fft_centered, self.fft_normalization, self.spatial_dims)
        return acs_image

    def __call__(self, kspace: torch.Tensor) -> torch.Tensor:
        """Estimates sensitivity maps for the input sample."""
        return self.forward(kspace)

    def __repr__(self) -> str:
        """Representation of :class:`EstimateCoilSensitivityMaps`."""
        return f"Estimating coil sensitivity maps of {self.coil_sensitivity_maps_type} type."

    def __str__(self) -> str:
        """String representation of :class:`EstimateCoilSensitivityMaps`."""
        return self.__repr__()

    def forward(self, kspace: torch.Tensor) -> torch.Tensor:
        """Forward pass of :class:`EstimateCoilSensitivityMaps`.

        Parameters
        ----------
        kspace : torch.Tensor
            K-space with the complex (real, imag) channel as last dimension.

        Returns
        -------
        sensitivity_map : torch.Tensor
            Sensitivity maps, normalized so that the root-sum-of-squares over the coil and complex dimensions is
            one wherever it is non-zero. Positions with zero norm are set to zero.
        """
        if self.coil_sensitivity_maps_type == "unit":
            # Assumes complex channel is last: real part one, imaginary part zero.
            sensitivity_map = torch.zeros(kspace.shape, dtype=torch.float32, device=kspace.device)
            sensitivity_map[..., 0] = 1.0
        elif self.coil_sensitivity_maps_type == "rss":
            acs_mask = self.calculate_acs_mask(kspace)
            acs_image = self.estimate_acs_image(acs_mask, kspace)
            # Combine coils, then put the coil dimension back as a singleton for broadcasting.
            acs_image_rss = rss(acs_image, dim=self.coil_dim).unsqueeze(self.coil_dim)
            sensitivity_map = torch.where(
                acs_image_rss == 0,
                torch.tensor([0.0], dtype=acs_image.dtype).to(acs_image.device),
                acs_image / acs_image_rss,
            )
        else:
            acs_mask = self.calculate_acs_mask(kspace)
            sensitivity_map = self.espirit_calibrator(acs_mask, kspace)

        # keepdim=True keeps ``coil_dim`` pointing at the coil axis after the complex axis has been reduced, which
        # matters for negative ``coil_dim`` values; for non-negative values the result is unchanged.
        sensitivity_map_norm = torch.sqrt(
            (sensitivity_map**2).sum(-1, keepdim=True).sum(self.coil_dim, keepdim=True)
        )
        sensitivity_map = torch.where(
            sensitivity_map_norm == 0,
            torch.tensor([0.0], dtype=sensitivity_map.dtype).to(sensitivity_map.device),
            sensitivity_map / sensitivity_map_norm,
        )
        return sensitivity_map
class GeometricDecompositionCoilCompression:
    """Geometric Decomposition Coil Compression in PyTorch, as presented in [Zhang2013]_.

    References
    ----------
    .. [Zhang2013] Zhang, T., Pauly, J. M., Vasanawala, S. S., & Lustig, M. (2013). Coil compression for accelerated
        imaging with Cartesian sampling. Magnetic Resonance in Medicine, 69(2), 571–582.
        https://doi.org/10.1002/mrm.24267

    Returns
    -------
    torch.Tensor
        Coil compressed data.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import GeometricDecompositionCoilCompression
    >>> data = torch.randn([30, 100, 100], dtype=torch.complex64)
    >>> gdcc = GeometricDecompositionCoilCompression(virtual_coils=10, calib_lines=24, spatial_dims=[-2, -1])
    >>> gdcc(data).shape
    torch.Size([10, 100, 100, 2])
    """

    def __init__(
        self,
        virtual_coils: Optional[int] = None,
        calib_lines: Optional[int] = None,
        align_data: bool = True,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
    ):
        """Inits :class:`GeometricDecompositionCoilCompression`.

        Parameters
        ----------
        virtual_coils : int
            Number of final-"virtual" coils.
        calib_lines : int
            Calibration lines to sample data points.
        align_data : bool
            Align data to the first calibration line. Default is ``True``.
        fft_centered : bool
            Whether to center the fft. Default is ``False``.
        fft_normalization : str
            FFT normalization. Default is ``"backward"``.
        spatial_dims : Sequence[int]
            Dimensions to apply the FFT. Default is ``(-2, -1)``.
        """
        super().__init__()
        # TODO: account for multiple echo times
        self.virtual_coils = virtual_coils
        self.calib_lines = calib_lines
        self.align_data = align_data
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

    def __call__(
        self,
        data: Union[torch.Tensor, None],
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> torch.Tensor:
        """Calls :class:`GeometricDecompositionCoilCompression`.

        The compression is applied only when ``data`` is not none, has more than one dimension and its mean is
        not exactly 1; otherwise ``data`` is returned untouched.

        Parameters
        ----------
        data : torch.Tensor
            Input data to apply coil compression.
        apply_backward_transform : bool
            Apply backward transform. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform. Default is ``False``.
        """
        if not is_none(data) and data.dim() > 1 and data.mean() != 1:  # type: ignore
            return self.forward(data, apply_backward_transform, apply_forward_transform)  # type: ignore
        return data  # type: ignore

    def __repr__(self):
        """Representation of :class:`GeometricDecompositionCoilCompression`."""
        return f"Coil Compression is applied reducing coils to {self.virtual_coils}."

    def __str__(self):
        """String representation of :class:`GeometricDecompositionCoilCompression`."""
        # Call ``__repr__``; stringifying the bound method itself would yield "<bound method ...>".
        return self.__repr__()

    # pylint: disable=unused-argument
    def forward(
        self,
        data: torch.Tensor,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> torch.Tensor:
        """Forward pass of :class:`GeometricDecompositionCoilCompression`.

        Intermediate results are stored on the instance (``data``, ``init_data``, ``width``, ``coils``,
        ``unaligned_data``, ``aligned_data``), so an instance should not be shared between concurrent calls.

        Parameters
        ----------
        data : torch.Tensor
            Input data to apply coil compression. The first dimension is read as the coil dimension.
        apply_backward_transform : bool
            Unused. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform. Default is ``False``.

        Returns
        -------
        torch.Tensor
            Coil compressed data.

        Raises
        ------
        ValueError
            If ``virtual_coils`` is not set, or is larger than the number of coils in ``data``.
        """
        if not self.virtual_coils:
            raise ValueError("Number of virtual coils must be defined for geometric decomposition coil compression.")

        if apply_forward_transform:
            data = fft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        self.data = data
        # A trailing dimension of size 2 is treated as (real, imag).
        if self.data.shape[-1] == 2:
            self.data = torch.view_as_complex(self.data)

        curr_num_coils = self.data.shape[0]
        if curr_num_coils < self.virtual_coils:
            raise ValueError(
                f"Tried to compress from {curr_num_coils} to {self.virtual_coils} coils, please select less coils."
            )

        # Move the coil dimension last.
        self.data = self.data.permute(1, 2, 0)
        self.init_data: torch.Tensor = self.data
        self.fft_dim = [0, 1]

        _, self.width, self.coils = self.data.shape

        # TODO: figure out why this is happening for singlecoil data
        # For singlecoil data, use no calibration lines equal to the no of coils.
        if self.virtual_coils == 1:
            self.calib_lines = self.data.shape[-1]

        self.crop()
        self.calculate_gcc()
        if self.align_data:
            self.align_compressed_coils()
            rotated_compressed_data = self.rotate_and_compress(data_to_cc=self.aligned_data)
        else:
            rotated_compressed_data = self.rotate_and_compress(data_to_cc=self.unaligned_data)

        rotated_compressed_data = torch.flip(rotated_compressed_data, dims=[1])
        # Move the (virtual) coil dimension first again and expose (real, imag) as last dimension.
        rotated_compressed_data = torch.view_as_real(rotated_compressed_data.permute(2, 0, 1))
        if not apply_forward_transform:
            rotated_compressed_data = fft2(
                rotated_compressed_data,
                self.fft_centered,
                self.fft_normalization,
                self.spatial_dims,
            )
        return rotated_compressed_data.detach().clone()

    def crop(self):
        """Crop to the size of the calibration lines.

        Replaces ``self.data`` with its central region of (approximately) ``calib_lines`` x ``width`` x ``coils``
        samples, adds a singleton dimension before the coil dimension and swaps the first two dimensions.
        """
        s = torch.as_tensor([self.calib_lines, self.width, self.coils])
        idx = [
            torch.arange(
                abs(int(self.data.shape[n] // 2 + torch.ceil(-s[n] / 2))),
                abs(int(self.data.shape[n] // 2 + torch.ceil(s[n] / 2) + 1)),
            )
            for n in range(len(s))
        ]
        self.data = (
            self.data[idx[0][0] : idx[0][-1], idx[1][0] : idx[1][-1], idx[2][0] : idx[2][-1]]
            .unsqueeze(-2)
            .permute(1, 0, 2, 3)
        )

    def calculate_gcc(self):
        """Calculates Geometric Coil-Compression.

        For every position along the first dimension an SVD of a sliding window of the zero-padded calibration
        data is computed; the right singular vectors are stored in ``self.unaligned_data`` and truncated to
        ``virtual_coils`` columns.
        """
        # Odd sliding-window size derived from the number of virtual coils.
        ws = (self.virtual_coils // 2) * 2 + 1

        Nx, Ny, Nz, Nc = self.data.shape

        im = torch.view_as_complex(
            ifft2(torch.view_as_real(self.data), self.fft_centered, self.fft_normalization, spatial_dims=0)
        )

        s = torch.as_tensor([Nx + ws - 1, Ny, Nz, Nc])
        idx = [
            torch.arange(
                abs(int(im.shape[n] // 2 + torch.ceil((-s[n] / 2).clone().detach()))),
                abs(int(im.shape[n] // 2 + torch.ceil((s[n] / 2).clone().detach())) + 1),
            )
            for n in range(len(s))
        ]

        # Zero-pad along the first dimension so that a full window exists for every position.
        zpim = torch.zeros((Nx + ws - 1, Ny, Nz, Nc)).type(im.dtype)
        zpim[idx[0][0] : idx[0][-1], idx[1][0] : idx[1][-1], idx[2][0] : idx[2][-1], idx[3][0] : idx[3][-1]] = im

        self.unaligned_data = torch.zeros((Nc, min(Nc, ws * Ny * Nz), Nx)).type(im.dtype)
        for n in range(Nx):
            tmpc = reshape_fortran(zpim[n : n + ws, :, :, :], (ws * Ny * Nz, Nc))
            _, _, v = torch.svd(tmpc, some=False)
            self.unaligned_data[:, :, n] = v

        self.unaligned_data = self.unaligned_data[:, : self.virtual_coils, :]

    def align_compressed_coils(self):
        """Virtual Coil Alignment.

        Starting from the middle position, each compression matrix is rotated towards its already aligned
        neighbour, first towards lower and then towards higher indices. ``self.aligned_data`` refers to the same
        tensor as ``self.unaligned_data``, so the alignment is done in place.
        """
        self.aligned_data = self.unaligned_data

        _, sy, nc = self.aligned_data.shape
        ncc = sy
        n0 = nc // 2

        A00 = self.aligned_data[:, :ncc, n0 - 1]

        # Sweep from the middle towards the first position.
        A0 = A00
        for n in range(n0, 0, -1):
            A1 = self.aligned_data[:, :ncc, n - 1]
            C = torch.conj(A1).T @ A0
            u, _, v = torch.svd(C, some=False)
            P = v @ torch.conj(u).T
            self.aligned_data[:, :ncc, n - 1] = A1 @ torch.conj(P).T
            A0 = self.aligned_data[:, :ncc, n - 1]

        # Sweep from the middle towards the last position.
        A0 = A00
        for n in range(n0, nc):
            A1 = self.aligned_data[:, :ncc, n]
            C = torch.conj(A1).T @ A0
            u, _, v = torch.svd(C, some=False)
            P = v @ torch.conj(u).T
            self.aligned_data[:, :ncc, n] = A1 @ torch.conj(P).T
            A0 = self.aligned_data[:, :ncc, n]

    def rotate_and_compress(self, data_to_cc):
        """Uses compression matrices to project the data onto them -> rotate to the compressed space.

        Parameters
        ----------
        data_to_cc : torch.Tensor
            Compression matrices, one per position along the last dimension (``aligned_data`` or
            ``unaligned_data``).

        Returns
        -------
        torch.Tensor
            Compressed data with the virtual coils as last dimension.
        """
        _data = self.init_data.permute(1, 0, 2).unsqueeze(-2)
        _ncc = data_to_cc.shape[1]

        data_to_cc = data_to_cc.to(_data.device)

        Nx, Ny, Nz, Nc = _data.shape
        im = torch.view_as_complex(
            ifft2(torch.view_as_real(_data), self.fft_centered, self.fft_normalization, spatial_dims=0)
        )

        # Project every position onto its own compression matrix.
        ccdata = torch.zeros((Nx, Ny, Nz, _ncc)).type(_data.dtype).to(_data.device)
        for n in range(Nx):
            tmpc = im[n, :, :, :].squeeze().reshape(Ny * Nz, Nc)
            ccdata[n, :, :, :] = (tmpc @ data_to_cc[:, :, n]).reshape(Ny, Nz, _ncc).unsqueeze(0)

        ccdata = (
            torch.view_as_complex(
                fft2(torch.view_as_real(ccdata), self.fft_centered, self.fft_normalization, spatial_dims=0)
            )
            .permute(1, 0, 2, 3)
            .squeeze()
        )

        # Singlecoil
        if ccdata.dim() == 2:
            ccdata = ccdata.unsqueeze(-1)

        gcc = torch.zeros(ccdata.shape).type(ccdata.dtype)
        for n in range(ccdata.shape[-1]):
            gcc[:, :, n] = torch.view_as_complex(
                ifft2(
                    torch.view_as_real(ccdata[:, :, n]), self.fft_centered, self.fft_normalization, self.spatial_dims
                )
            )

        return gcc
class Masker:
    """Undersamples k-space data.

    Returns
    -------
    Tuple[List[torch.Tensor], List[torch.Tensor], List[float]]
        Masked data, mask, and acceleration factor. They are returned as a tuple of lists, where each list corresponds
        to a different acceleration factor. If one acceleration factor is provided, the lists will be of length 1.

    Example
    -------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import Masker
    >>> data = torch.randn(1, 15, 320, 320, 2)
    >>> mask = torch.ones(320, 320)
    >>> masker = Masker(mask_func=None, spatial_dims=(-2, -1), shift_mask=False, partial_fourier_percentage=0.0,
    ...     center_scale=0.02, dimensionality=2, remask=True)
    >>> masked_data = masker(data, mask, seed=None)
    >>> masked_data[0][0].shape  # masked data
    [1, 15, 320, 320, 2]
    >>> masked_data[1][0].shape  # mask
    [320, 320]
    >>> masked_data[2][0]  # acceleration factor
    10.0
    """

    def __init__(
        self,
        mask_func: Optional[Callable] = None,
        spatial_dims: Sequence[int] = (-2, -1),
        shift_mask: bool = False,
        partial_fourier_percentage: float = 0.0,
        center_scale: float = 0.02,
        dimensionality: int = 2,
        remask: bool = True,
        dataset_format: Optional[str] = None,
    ):
        """Inits :class:`Masker`.

        Parameters
        ----------
        mask_func : callable, optional
            Masker function. Default is `None`.
        spatial_dims : tuple
            Spatial dimensions. Default is `(-2, -1)`.
        shift_mask : bool
            Whether to shift the mask. Default is `False`.
        partial_fourier_percentage : float
            Whether to simulate half scan. Default is `0.0`, which means no half scan.
        center_scale : float
            Percentage of center to remain densely sampled. Default is `0.02`.
        dimensionality : int
            Dimensionality of the data. Default is `2`.
        remask: bool
            Whether to remask the data. If False, the mask will be generated only once. If True, the mask will be
            generated every time the transform is called. Default is `True`.
        dataset_format : str, optional
            The format of the dataset. Useful if loading precomputed masks. For example, ``'custom_dataset'`` or
            ``'public_dataset_name'``. Default is ``None``.
        """
        self.mask_func = mask_func
        self.spatial_dims = spatial_dims
        self.shift_mask = shift_mask
        self.partial_fourier_percentage = partial_fourier_percentage
        self.center_scale = center_scale
        self.dimensionality = dimensionality
        self.remask = remask
        self.dataset_format = dataset_format
        # Description of the masking strategy; refined on every call. Initialised here so that ``repr``/``str``
        # work on an instance that has not been called yet.
        self.__type__ = "Masker has not been called yet."
        # Accelerations of precomputed skm-tea masks; filled in ``__call__`` when that format is used.
        self.acc: List[Any] = []

    def __call__(
        self,
        data: torch.Tensor,
        mask: Union[List, torch.Tensor, np.ndarray] = None,
        padding: Optional[Tuple] = None,
        seed: Optional[int] = None,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> Tuple[
        List[float | torch.Tensor | Any],
        List[torch.Tensor | Any] | List[torch.Tensor | np.ndarray | None | Any],
        List[int | torch.Tensor | Any],
    ]:
        """Calls :class:`Masker`.

        Parameters
        ----------
        data : torch.Tensor
            Input k-space data to apply mask.
        mask : Union[List, torch.Tensor, np.ndarray], optional
            Mask to apply. Default is ``None``.
        padding : Optional[Tuple], optional
            Padding to apply. Default is ``None``.
        seed : Optional[int], optional
            Seed to apply. Default is ``None``.
        apply_backward_transform : bool
            Accepted for a uniform transform interface; not used by this transform. Default is ``False``.
        apply_forward_transform : bool
            Accepted for a uniform transform interface; not used by this transform. Default is ``False``.
        """
        if self.dataset_format is not None and "skm-tea" in self.dataset_format.lower():
            if not is_none(self.mask_func) and not isinstance(mask, np.ndarray):
                # if skm-tea dataset, then the mask is already computed and loaded
                accelerations = list(self.mask_func[0].accelerations)  # type: ignore
                self.acc = list(accelerations)
                # Precomputed masks are looked up by the string form of their acceleration.
                mask = [mask[str(acceleration)] for acceleration in accelerations]  # type: ignore
            else:
                mask = None

        # Check if mask is precomputed or not: empty lists and 0-dim arrays count as "no mask".
        if not is_none(mask):
            if isinstance(mask, list):
                if len(mask) == 0:
                    mask = None
            elif mask.ndim == 0:  # type: ignore
                mask = None

        if not is_none(mask) and isinstance(mask, list) and len(mask) > 0:
            self.__type__ = "Masks are precomputed and loaded."
        elif not is_none(mask) and not isinstance(mask, list) and mask.ndim != 0 and len(mask) > 0:  # type: ignore
            self.__type__ = "Mask is either precomputed and loaded or data are prospectively undersampled."
        elif isinstance(self.mask_func, list):
            self.__type__ = "Multiple accelerations are provided and masks are generated on the fly."
        else:
            self.__type__ = "A single acceleration is provided and mask is generated on the fly."

        return self.forward(data, mask, padding, seed)

    def __repr__(self) -> str:
        """Representation of :class:`Masker`."""
        return f"{self.__type__}"

    def __str__(self) -> str:
        """String representation of :class:`Masker`."""
        return self.__repr__()

    def forward(  # noqa: MC0001
        self,
        data: torch.Tensor,
        mask: Union[List, torch.Tensor, np.ndarray] = None,
        padding: Optional[Tuple] = None,
        seed: Optional[int] = None,
    ) -> Tuple[
        List[float | torch.Tensor | Any],
        List[torch.Tensor | Any] | List[torch.Tensor | np.ndarray | None | Any],
        List[int | torch.Tensor | Any],
    ]:
        """Forward pass of :class:`Masker`.

        One of five paths is taken, in this order: a list of precomputed masks, a single precomputed mask, a list
        of mask functions, a single mask function (``mask_func[0]``), or no masking at all.

        Parameters
        ----------
        data : torch.Tensor
            Input k-space data to apply mask.
        mask : Union[List, torch.Tensor, np.ndarray], optional
            Mask to apply. Default is ``None``.
        padding : Optional[Tuple], optional
            Padding to apply. Default is ``None``.
        seed : Optional[int], optional
            Seed to apply. Default is ``None``.

        Returns
        -------
        Tuple of three lists: masked data, masks and accelerations.
        """
        # With a trailing (real, imag) dimension the spatial dimensions sit one position further from the end.
        is_complex = data.shape[-1] == 2
        spatial_dims = tuple(x - 1 for x in self.spatial_dims) if is_complex else self.spatial_dims

        if not is_none(mask) and isinstance(mask, list) and len(mask) > 0:
            is_skm_tea = self.dataset_format is not None and "skm-tea" in self.dataset_format.lower()
            masked_data = []
            masks = []
            accelerations = []
            for i, m in enumerate(mask):
                # 2D masks matching the spatial size of the data get broadcastable leading/trailing dimensions.
                if list(m.shape) == [data.shape[spatial_dims[0]], data.shape[spatial_dims[1]]]:
                    if isinstance(m, np.ndarray):
                        m = torch.from_numpy(m)
                    m = m.unsqueeze(0).unsqueeze(-1)

                # ``padding`` itself may be None, so it has to be checked before it is indexed.
                if not is_none(padding) and not is_none(padding[0]) and padding[0] != 0:  # type: ignore
                    m[:, :, : padding[0]] = 0  # type: ignore
                    m[:, :, padding[1] :] = 0  # type: ignore

                if self.shift_mask:
                    m = torch.fft.fftshift(m, dim=(spatial_dims[0], spatial_dims[1]))

                m = m.to(torch.float32)

                masked_data.append(data * m + 0.0)
                masks.append(m)
                if is_skm_tea:
                    accelerations.append(float(self.acc[i]))
                else:
                    # Acceleration is estimated as total number of points over number of sampled points.
                    accelerations.append(np.round(m.squeeze(0).squeeze(-1).numpy().size / m.numpy().sum()))

        elif not is_none(mask) and not isinstance(mask, list) and mask.ndim != 0 and len(mask) > 0:  # type: ignore
            # NOTE(review): the reviewed copy of this file had lost its indentation; the unsqueeze below is assumed
            # to apply to numpy masks only -- confirm against the upstream source.
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)
                mask = mask.unsqueeze(0).unsqueeze(-1)

            if not is_none(padding) and padding[0] != 0:  # type: ignore
                mask[:, :, : padding[0]] = 0  # type: ignore
                mask[:, :, padding[1] :] = 0  # type: ignore

            if mask.shape[-3] != data.shape[-3] or mask.shape[-2] != data.shape[-2]:
                mask = center_crop(mask.squeeze(-1), (data.shape[-3], data.shape[-2])).unsqueeze(-1)

            if self.shift_mask:
                mask = torch.fft.fftshift(mask, dim=(spatial_dims[0], spatial_dims[1]))

            masked_data = [data * mask + 0.0]
            masks = [mask]
            accelerations = [np.round(mask.squeeze(0).squeeze(-1).numpy().size / mask.numpy().sum())]

        elif isinstance(self.mask_func, list):
            masked_data = []
            masks = []
            accelerations = []
            for m in self.mask_func:
                if self.dimensionality == 2:
                    _masked_data, _mask, _accelerations = apply_mask(
                        data,
                        m,
                        seed,
                        padding,
                        shift=self.shift_mask,
                        partial_fourier_percentage=self.partial_fourier_percentage,
                        center_scale=self.center_scale,
                    )
                elif self.dimensionality == 3:
                    _masked_data = []
                    _masks = []
                    _accelerations = []
                    j_mask = None
                    # Mask every entry along the first dimension separately, optionally reusing the previous mask.
                    for j in range(data.shape[0]):
                        j_masked_data, j_mask, j_acc = apply_mask(
                            data[j],
                            m,
                            seed,
                            padding,
                            shift=self.shift_mask,
                            partial_fourier_percentage=self.partial_fourier_percentage,
                            center_scale=self.center_scale,
                            existing_mask=j_mask if not self.remask else None,
                        )
                        _masked_data.append(j_masked_data)
                        _masks.append(j_mask)
                        _accelerations.append(j_acc)
                    _masked_data = torch.stack(_masked_data, dim=0)
                    _mask = torch.stack(_masks, dim=0)
                    _accelerations = torch.stack(_accelerations, dim=0)
                else:
                    raise ValueError(f"Unsupported data dimensionality {self.dimensionality}D.")
                masked_data.append(_masked_data)
                masks.append(_mask)
                accelerations.append(_accelerations)

        elif not is_none(self.mask_func):
            masked_data, masks, accelerations = apply_mask(  # type: ignore
                data,
                self.mask_func[0],  # type: ignore
                seed,
                padding,
                shift=self.shift_mask,
                partial_fourier_percentage=self.partial_fourier_percentage,
                center_scale=self.center_scale,
            )
            masked_data = [masked_data]
            masks = [masks]
            accelerations = [accelerations]  # type: ignore

        else:
            # Nothing to apply: return the data with empty placeholders for mask and acceleration.
            masked_data = [data]
            masks = [torch.empty([])]
            accelerations = [torch.empty([])]

        return masked_data, masks, accelerations
class N2R:
    """Generates Noise to Reconstruction (N2R) sampling masks, as presented in [Desai2022]_.

    References
    ----------
    .. [Desai2022] AD Desai, BM Ozturkler, CM Sandino, et al. Noise2Recon: Enabling Joint MRI Reconstruction and
        Denoising with Semi-Supervised and Self-Supervised Learning. ArXiv 2022. https://arxiv.org/abs/2110.00075

    Returns
    -------
    sampling_mask_noise : torch.Tensor
        Sampling mask with noise. The shape should be (1, nx, ny, 1).
    """

    def __init__(
        self,
        probability: float = 0.0,
        std_devs: Tuple[float, float] = (0.0, 0.0),
        rhos: Tuple[float, float] = (0.0, 0.0),
        use_mask: bool = True,
    ):
        """Inits :class:`N2R`.

        Parameters
        ----------
        probability : float, optional
            Probability of generating a noisy mask. Otherwise an all-ones mask is returned. Default is ``0.0``.
        std_devs : Tuple[float, float], optional
            Range from which the standard deviation of the Gaussian noise is drawn. If either bound is ``0.0`` a
            fixed standard deviation of ``1e-6`` is used instead. Default is ``(0.0, 0.0)``.
        rhos: Tuple[float, float], optional
            Range from which rho, the fraction of valid mask entries that is kept, is drawn. If either bound is
            ``0.0`` no entries are dropped. Default is ``(0.0, 0.0)``.
        use_mask : bool, optional
            Whether to use the input mask. If ``False``, an all-ones mask of the same shape is used.
            Default is ``True``.
        """
        self.probability = probability
        self.std_devs = std_devs
        self.rhos = rhos
        self.use_mask = use_mask

    def __call__(self, data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Calls :class:`N2R`.

        Parameters
        ----------
        data : torch.Tensor
            Input data. The shape should be (nc, nx, ny). Only ``data.shape[1]`` is read, to expand 1D masks.
        mask : torch.Tensor
            Input mask. The shape should be (nx, ny), optionally with a leading and/or trailing singleton dimension.

        Returns
        -------
        sampling_mask_noise : torch.Tensor
            Sampling mask with noise. The shape should be (1, nx, ny, 1).
        """
        mask = mask.squeeze(0).squeeze(-1)
        # if mask is 1D, repeat it for nx
        if mask.shape[0] == 1:
            mask = mask.repeat_interleave(data.shape[1], 0)
        return self.forward(mask)

    def __repr__(self):
        """Representation of :class:`N2R`."""
        return (
            f"N2R(probability={self.probability}, std_devs={self.std_devs}, rhos={self.rhos}, "
            f"use_mask={self.use_mask})"
        )

    def __str__(self):
        """String representation of :class:`N2R`."""
        return self.__repr__()

    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        """Forward pass of :class:`N2R`.

        The input mask is never modified in place.

        Parameters
        ----------
        mask : torch.Tensor
            Input mask. The shape should be (nx, ny).

        Returns
        -------
        sampling_mask_noise : torch.Tensor
            Sampling mask with noise. The shape should be (1, nx, ny, 1).
        """
        _rand = torch.rand(1).item()
        if _rand >= self.probability:
            return torch.ones_like(mask).unsqueeze(0).unsqueeze(-1)

        rhos = (
            self._rand_range(*self.rhos)
            if not is_none(self.rhos) and self.rhos[0] != 0.0 and self.rhos[1] != 0.0
            else None
        )

        if not self.use_mask:
            # keep the replacement mask on the device of the input mask
            mask = torch.ones(mask.shape, device=mask.device)

        std_devs = (
            self._rand_range(*self.std_devs)
            if not is_none(self.std_devs) and self.std_devs[0] != 0.0 and self.std_devs[1] != 0.0
            else 1e-6  # add a small number to avoid division by zero
        )

        # the generator is seeded from the same draw that decided whether noise is applied
        gen = torch.Generator(device=mask.device).manual_seed(int(_rand * 1e10))
        noise = std_devs * torch.randn(mask.shape + (2,), generator=gen, device=mask.device)
        # the trailing dimension always has size 2 (real, imaginary)
        noise = torch.view_as_complex(noise)

        if rhos is not None and rhos != 1:
            shape = mask.shape
            # work on a copy: writing through a view would silently alter the mask owned by the caller
            flat_mask = mask.clone().reshape(-1)
            # TODO: this doesn't work if the matrix is > 2**24 in size.
            num_valid = torch.sum(flat_mask)
            num_drop = int((1 - rhos) * num_valid)
            # torch.multinomial cannot draw zero samples, so skip it when nothing has to be dropped
            if num_drop > 0:
                weights = flat_mask / num_valid
                samples = torch.multinomial(weights, num_drop, replacement=False, generator=gen)
                flat_mask[samples] = 0
            mask = flat_mask.view(shape)

        noise = noise * mask
        return torch.abs(noise).to(mask).unsqueeze(0).unsqueeze(-1)

    @staticmethod
    def _rand_range(low, high, size: Optional[int] = None) -> float:
        """Uniform float random number between [low, high).

        Parameters
        ----------
        low : float
            Lower bound.
        high : float
            Upper bound. If it is smaller than ``low`` the two bounds are swapped.
        size : int, optional
            Number of samples. Default is ``None``, which draws a single sample. The result is converted with
            ``.item()``, so only a single sample can be returned.

        Returns
        -------
        val : float
            A uniformly sampled number in range [low, high), or ``low`` if both bounds are equal.
        """
        if size is None:
            size = 1
        if low > high:
            high, low = low, high
        if high - low == 0:
            return low
        return (low + (high - low) * torch.rand(size)).cpu().item()
class NoisePreWhitening:
    """Applies noise pre-whitening / coil decorrelation.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import NoisePreWhitening
    >>> data = torch.randn([30, 100, 100], dtype=torch.complex64)
    >>> data = torch.view_as_real(data)
    >>> data.mean()
    tensor(-0.0011)
    >>> noise_prewhitening = NoisePreWhitening(find_patch_size=True, scale_factor=1.0)
    >>> noise_prewhitening(data).mean()
    tensor(-0.0023)
    """

    def __init__(
        self,
        find_patch_size: bool = True,
        patch_size: List[int] = None,
        scale_factor: float = 1.0,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
    ):
        """Inits :class:`NoisePreWhitening`.

        Parameters
        ----------
        find_patch_size : bool
            Find optimal patch size (automatically) to calculate psi. If False, patch_size must be defined.
            Default is ``True``.
        patch_size : list of ints
            Define patch size to calculate psi, [x_start, x_end, y_start, y_end].
        scale_factor : float
            Applied on the noise covariance matrix. Used to adjust for effective noise bandwidth and difference in
            sampling rate between noise calibration and actual measurement.
            scale_factor = (T_acq_dwell/T_noise_dwell)*NoiseReceiverBandwidthRatio
            Default is ``1.0``.
        fft_centered : bool
            If True, the zero-frequency component is located at the center of the spectrum.
            Default is ``False``.
        fft_normalization : str
            Normalization mode. Options are ``"backward"``, ``"ortho"``, ``"forward"``.
            Default is ``"backward"``.
        spatial_dims : sequence of ints
            Spatial dimensions of the input data.
        """
        super().__init__()
        # TODO: account for multiple echo times
        self.find_patch_size = find_patch_size
        self.patch_size = patch_size
        self.scale_factor = scale_factor
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

    def __call__(
        self,
        data: torch.Tensor,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> torch.Tensor:
        """Calls :class:`NoisePreWhitening`.

        Parameters
        ----------
        data : torch.Tensor
            Input data to apply noise pre-whitening.
        apply_backward_transform : bool
            Apply backward transform. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform. Default is ``False``.
        """
        return self.forward(data, apply_backward_transform, apply_forward_transform)

    def __repr__(self):
        """Representation of :class:`NoisePreWhitening`."""
        return f"Noise pre-whitening is applied with patch size {self.patch_size}."

    def __str__(self):
        """String representation of :class:`NoisePreWhitening`."""
        # call the method; str(self.__repr__) would return the text of the bound-method object instead
        return self.__repr__()

    # pylint: disable=unused-argument
    def forward(
        self,
        data: torch.Tensor,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> torch.Tensor:
        """Forward pass of :class:`NoisePreWhitening`.

        Parameters
        ----------
        data : torch.Tensor
            Input data to apply noise pre-whitening. The first dimension is treated as the coil dimension.
        apply_backward_transform : bool
            Currently unused; kept for a call signature parallel to the other transforms.
        apply_forward_transform : bool
            If True, apply the forward FFT before noise pre-whitening and the inverse FFT afterwards.

        Returns
        -------
        torch.Tensor
            Noise pre-whitened data, detached and cloned.

        Raises
        ------
        ValueError
            If no patch can be determined, i.e. ``find_patch_size`` is False and no ``patch_size`` is set, or the
            automatic search does not find any patch.
        """
        if apply_forward_transform:
            data = fft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        if data.shape[-1] != 2:
            data = torch.view_as_real(data)

        if self.find_patch_size:
            patch = self.find_optimal_patch_size(data)
            if len(patch) != 4:
                # the search found no candidate, e.g. because the data is smaller than the smallest patch
                raise ValueError(
                    "Could not find a noise patch automatically for noise prewhitening. Please define a patch size "
                    "and set find_patch_size to False."
                )
            noise = data[:, patch[0] : patch[1], patch[2] : patch[3]]
        elif not is_none(self.patch_size):
            noise = data[
                :,
                self.patch_size[0] : self.patch_size[1],  # type: ignore
                self.patch_size[-2] : self.patch_size[-1],  # type: ignore
            ]
        else:
            raise ValueError(
                "No patch size has been defined, while find_patch_size is False for noise prewhitening. Please define "
                "a patch size or set find_patch_size to True."
            )

        # one row of noise samples per coil
        noise_int = torch.reshape(noise, (noise.shape[0], -1))
        deformation_matrix = (1 / (float(noise_int.shape[1]) - 1)) * torch.mm(noise_int, torch.conj(noise_int).t())
        # ensure that the matrix is positive definite; the identity must match the device and dtype of the matrix
        deformation_matrix = deformation_matrix + 1e-6 * torch.eye(
            deformation_matrix.shape[0], dtype=deformation_matrix.dtype, device=deformation_matrix.device
        )
        psi = torch.linalg.inv(torch.linalg.cholesky(deformation_matrix)) * sqrt(2) * sqrt(self.scale_factor)

        data = torch.reshape(torch.mm(psi, torch.reshape(data, (data.shape[0], -1))), data.shape)

        if apply_forward_transform:
            data = ifft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        return data.detach().clone()

    @staticmethod
    def find_optimal_patch_size(data: torch.Tensor, min_noise: float = 1e10) -> List[int]:
        """Find optimal patch size for noise pre-whitening.

        Square patches with side lengths 10 to 50 are slid over the last two dimensions with a stride of 10. The
        patch whose coil-combined magnitude has the lowest sample standard deviation is selected.

        Parameters
        ----------
        data : torch.Tensor
            Input data to find optimal patch size.
        min_noise : float
            Initial upper bound of the noise estimate. Only patches with a lower estimate are considered.
            Default is ``1e10``.

        Returns
        -------
        List[int]
            Optimal patch size, [x_start, x_end, y_start, y_end]. Empty if no patch has a noise estimate below
            ``min_noise``, e.g. when the data is smaller than the smallest patch.
        """
        if data.shape[-1] == 2:
            data = torch.view_as_complex(data)

        best_patch: List[int] = []
        for patch_length in [10, 20, 30, 40, 50]:
            for patch_start_x in range(0, data.shape[-2] - patch_length, 10):
                for patch_start_y in range(0, data.shape[-1] - patch_length, 10):
                    patch = torch.abs(
                        rss(
                            data[
                                :,
                                patch_start_x : patch_start_x + patch_length,
                                patch_start_y : patch_start_y + patch_length,
                            ],
                        )
                    )
                    # sample standard deviation of the patch
                    noise = torch.sqrt(torch.sum(torch.abs(patch - torch.mean(patch)) ** 2) / (patch.numel() - 1))
                    if noise < min_noise:
                        min_noise = noise
                        best_patch = [
                            patch_start_x,
                            patch_start_x + patch_length,
                            patch_start_y,
                            patch_start_y + patch_length,
                        ]
        return best_patch
class Normalizer:
    """Normalizes data given a normalization type.

    Returns
    -------
    normalized_data: torch.Tensor
        Normalized data to range according to the normalization type.
    attrs: dict
        Statistics (min, max, mean, std, var) of the magnitude of the data before normalization.

    Example
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import Normalizer
    >>> data = torch.randn(1, 32, 320, 320, 2) + 1j * torch.randn(1, 32, 320, 320, 2)
    >>> print(torch.min(torch.abs(data)), torch.max(torch.abs(data)))
    tensor(1e-06) tensor(1.4142)
    >>> normalizer = Normalizer(normalization_type="max")
    >>> normalized_data, attrs = normalizer(data)
    >>> print(torch.min(torch.abs(normalized_data)), torch.max(torch.abs(normalized_data)))
    tensor(0.) tensor(1.)
    """

    def __init__(
        self,
        normalization_type: Optional[str] = None,
        kspace_normalization: bool = False,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
    ):
        """Inits :class:`Normalizer`.

        Parameters
        ----------
        normalization_type: str, optional
            Normalization type. It can be one of the following:
            - "max": divide data by the maximum of its magnitude.
            - "minmax": subtract the minimum of the magnitude and divide by the (max - min) range of the magnitude.
            - "mean_std": subtract the mean of the magnitude and divide by the standard deviation of the magnitude
            of the shifted data.
            - "mean_var": subtract the mean of the magnitude and divide by the variance of the magnitude of the
            shifted data.
            - "grayscale": subtract the minimum of the magnitude, divide by the maximum of the magnitude of the
            shifted data and multiply by 255.
            - "fft" or None: do not normalize data. It can be useful to verify FFT normalization.
            Default is `None`.
        kspace_normalization: bool, optional
            Normalize in k-space. If True, a requested backward transform is not applied. Default is `False`.
        fft_centered: bool, optional
            If True, the FFT will be centered. Default is `False`. Should be set for complex data normalization.
        fft_normalization: str, optional
            FFT normalization type, passed on to the FFT functions. It can be one of "backward", "ortho" or
            "forward". Default is "backward".
        spatial_dims: tuple, optional
            Spatial dimensions. Default is `(-2, -1)`.
        """
        self.normalization_type = normalization_type
        self.kspace_normalization = kspace_normalization
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

    def __call__(
        self,
        data: Union[torch.Tensor, List[torch.Tensor], None],
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> Union[Tuple[Any, Optional[Dict]], List[Tuple[torch.Tensor, Dict]]]:
        """Calls :class:`Normalizer`.

        Parameters
        ----------
        data : torch.Tensor or list of torch.Tensor
            Input data to normalize.
        apply_backward_transform : bool
            Apply backward transform. Default is ``False``.
        apply_forward_transform : bool
            Apply forward transform. Default is ``False``.

        Returns
        -------
        tuple or list of tuples
            For a tensor, the tuple ``(normalized_data, attrs)``. For a list, one such tuple per element; an empty
            list yields an empty list. Data that is None, has at most one dimension, or has a mean of exactly 1 is
            returned unchanged as ``(data, None)``.
        """
        # lists are dispatched element-wise; handling them first keeps an empty list away from the tensor-only
        # calls below
        if isinstance(data, list):
            return [self.forward(d, apply_backward_transform, apply_forward_transform) for d in data]
        if not is_none(data) and data.dim() > 1 and data.mean() != 1:  # type: ignore
            return self.forward(data, apply_backward_transform, apply_forward_transform)
        return data, None

    def __repr__(self):
        """Representation of :class:`Normalizer`."""
        return f"Normalization type is set to {self.normalization_type}."

    def __str__(self):
        """String representation of :class:`Normalizer`."""
        return self.__repr__()

    def forward(
        self,
        data: torch.Tensor,
        apply_backward_transform: bool = False,
        apply_forward_transform: bool = False,
    ) -> Tuple[torch.Tensor, Dict]:
        """Forward pass of :class:`Normalizer`.

        Parameters
        ----------
        data : torch.Tensor
            Input data. A trailing dimension of size 2 is interpreted as (real, imaginary).
        apply_backward_transform : bool, optional
            If True, apply the inverse FFT before and the FFT after normalization. Ignored if
            ``kspace_normalization`` is set. Default is ``False``.
        apply_forward_transform : bool, optional
            If True, apply the FFT before and the inverse FFT after normalization. Only considered if the backward
            transform is not applied. Default is ``False``.

        Returns
        -------
        data : torch.Tensor
            Normalized data. Complex data are returned with a trailing (real, imaginary) dimension.
        attrs : dict
            Normalization attributes, computed on the magnitude of the data before normalization.

        Raises
        ------
        ValueError
            If the normalization type is not supported.
        """
        if self.kspace_normalization and apply_backward_transform:
            apply_backward_transform = False

        if apply_backward_transform:
            data = ifft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        elif apply_forward_transform:
            data = fft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        if data.shape[-1] == 2:
            data = torch.view_as_complex(data)

        magnitude = torch.abs(data)
        attrs = {
            "min": torch.min(magnitude),
            "max": torch.max(magnitude),
            "mean": torch.mean(magnitude),
            "std": torch.std(magnitude),
            "var": torch.var(magnitude),
        }

        if self.normalization_type == "max":
            data = data / attrs["max"]
        elif self.normalization_type == "minmax":
            data = (data - attrs["min"]) / (attrs["max"] - attrs["min"])
        elif self.normalization_type == "mean_std":
            data = data - attrs["mean"]
            # the scale is taken from the magnitude of the shifted data, not of the input
            data = data / torch.std(torch.abs(data))
        elif self.normalization_type == "mean_var":
            data = data - attrs["mean"]
            # the scale is taken from the magnitude of the shifted data, not of the input
            data = data / torch.var(torch.abs(data))
        elif self.normalization_type == "grayscale":
            data = data - attrs["min"]
            data = data / torch.max(torch.abs(data))
            data = data * 255
        elif is_none(self.normalization_type) or self.normalization_type == "fft":
            pass
        else:
            raise ValueError(f"Normalization type {self.normalization_type} is not supported.")

        if torch.is_complex(data):
            data = torch.view_as_real(data)

        # undo the transform that was applied before normalization
        if apply_backward_transform:
            data = fft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        elif apply_forward_transform:
            data = ifft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )

        return data, attrs
class SNREstimator:
    """Estimates Signal-to-Noise Ratio.

    Returns
    -------
    snr : float
        Estimated SNR.

    Example
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.transforms import SNREstimator
    >>> data = torch.randn(1, 32, 320, 320, 2) + 1j * torch.randn(1, 32, 320, 320, 2)
    >>> print(torch.min(torch.abs(data)), torch.max(torch.abs(data)))
    tensor(1e-06) tensor(1.4142)
    >>> snr_estimator = SNREstimator(patch_size=[0, 20, 0, 20])
    >>> snr_estimator(data)
    3.2
    """

    def __init__(
        self,
        patch_size: List[int],
        apply_ifft: bool = True,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
        coil_dim: int = 1,
        multicoil: bool = True,
    ):
        """Inits :class:`SNREstimator`.

        Parameters
        ----------
        patch_size : list of ints
            Patch used to estimate the noise, given as [x_start, x_end, y_start, y_end]. If empty, the estimator
            returns its input unchanged.
        apply_ifft: bool
            If True, the input is assumed to be in k-space and is transformed to image space first.
        fft_centered: bool
            If True, apply centered FFT.
        fft_normalization : str
            Type of FFT normalization.
        spatial_dims : tuple of ints
            Spatial dimensions.
        coil_dim : int
            Coil dimension.
        multicoil : bool
            If True, multicoil data. Else single coil data, for which a coil dimension is inserted.
        """
        super().__init__()
        self.patch_size = patch_size
        self.apply_ifft = apply_ifft
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims
        self.coil_dim = coil_dim
        self.multicoil = multicoil

    def __call__(self, data):
        """Calls :class:`SNREstimator`.

        Returns the input unchanged if no patch size is set. Otherwise returns the ratio of the signal estimate
        to the noise estimate as a float, or 0 if either estimate is NaN.
        """
        if not self.patch_size:
            return data

        complex_input = torch.is_complex(data)
        if complex_input and data.shape[-1] != 2:
            data = torch.view_as_real(data)

        if not self.multicoil:
            data = torch.unsqueeze(data, self.coil_dim)

        if self.apply_ifft:
            imspace = ifft2(
                data,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        else:
            imspace = data

        if complex_input:
            imspace = torch.view_as_complex(imspace)

        rss_eta = torch.abs(rss(imspace, dim=self.coil_dim)).detach().cpu().numpy()

        # TODO: numpy funcs need to be replaced from torch funcs
        # pylint: disable=import-outside-toplevel
        from skimage.filters import threshold_otsu
        from skimage.morphology import convex_hull_image

        # Otsu-thresholded foreground, filled to its convex hull
        foreground = np.where(rss_eta > threshold_otsu(rss_eta), 1, 0)
        hull = convex_hull_image(foreground)
        # NOTE(review): the signal estimate is the mean of the first-axis indices of the hull pixels, not of their
        # intensities - confirm that this is intended.
        signal = torch.mean(torch.from_numpy(np.nonzero(hull)[0]).to(dtype=torch.float32))

        x_start, x_end = self.patch_size[0], self.patch_size[1]
        y_start, y_end = self.patch_size[-2], self.patch_size[-1]
        kspace = fft2(
            imspace,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        kspace_patch = torch.abs(rss(kspace[:, x_start:x_end, y_start:y_end], dim=self.coil_dim))

        # sample standard deviation of the coil-combined k-space patch
        squared_deviation = torch.abs(kspace_patch - torch.mean(kspace_patch)) ** 2
        noise = torch.sqrt(torch.sum(squared_deviation) / (kspace_patch.numel() - 1))

        if torch.isnan(signal) or torch.isnan(noise):
            return 0
        return (signal / noise).item()
class SSDU:
    """Generates Self-Supervised Data Undersampling (SSDU) masks, as presented in [Yaman2020]_.

    References
    ----------
    .. [Yaman2020] Yaman, B, Hosseini, SAH, Moeller, S, Ellermann, J, Uğurbil, K, Akçakaya, M. Self-supervised
        learning of physics-guided reconstruction neural networks without fully sampled reference data. Magn Reson
        Med. 2020; 84: 3172–3191. https://doi.org/10.1002/mrm.28378

    Returns
    -------
    training_mask: torch.Tensor
        Training mask.
    loss_mask: torch.Tensor
        Loss mask.
    """

    def __init__(
        self,
        mask_type: str = "Gaussian",
        rho: float = 0.4,
        acs_block_size: Sequence[int] = (4, 4),
        gaussian_std_scaling_factor: float = 4.0,
        outer_kspace_fraction: float = 0.0,
        export_and_reuse_masks: bool = False,
    ):
        """Inits :class:`SSDU`.

        Parameters
        ----------
        mask_type: str, optional
            Mask type. It can be one of the following:
            - "Gaussian": Gaussian sampling.
            - "Uniform": Uniform sampling.
            Default is "Gaussian".
        rho: float, optional
            Split ratio for training and loss masks, i.e. the fraction of sampled points assigned to the loss mask.
            Default is ``0.4``.
        acs_block_size: Sequence[int], optional
            Size of the central block that is excluded from the loss mask, so that it stays fully-sampled in the
            training mask. Default is ``(4, 4)``.
        gaussian_std_scaling_factor: float, optional
            Scaling factor for standard deviation of the Gaussian noise. If Uniform is select this factor is ignored.
            Default is ``4.0``.
        outer_kspace_fraction: float, optional
            Fraction of the outer k-space region to be kept/unmasked. Default is ``0.0``.
        export_and_reuse_masks: bool, optional
            If ``True``, the generated masks are exported to the tmp directory and reused in the next call. This
            option is useful when the data is too large to be stored in memory. Default is ``False``.

        Raises
        ------
        ValueError
            If the mask type is not supported.
        """
        if mask_type not in ["Gaussian", "Uniform"]:
            raise ValueError(f"SSDU mask type {mask_type} is not supported.")
        self.mask_type = mask_type
        self.rho = rho
        self.acs_block_size = acs_block_size
        self.gaussian_std_scaling_factor = gaussian_std_scaling_factor
        self.outer_kspace_fraction = outer_kspace_fraction
        self.export_and_reuse_masks = export_and_reuse_masks

    def __call__(self, data: torch.Tensor, mask: torch.Tensor, fname: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calls :class:`SSDU`. The ``data`` argument is not used."""
        return self.forward(mask, fname)

    def __repr__(self):
        """Representation of :class:`SSDU`."""
        return f"SSDU type is set to {self.mask_type}."

    def __str__(self):
        """String representation of :class:`SSDU`."""
        return self.__repr__()

    def forward(self, mask: torch.Tensor, fname: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of :class:`SSDU`.

        Parameters
        ----------
        mask : torch.Tensor
            Mask tensor. The shape should be (nx, ny).
        fname : str
            File name to save the generated masks.

        Returns
        -------
        train_mask : torch.Tensor
            Training mask.
        loss_mask : torch.Tensor
            Loss mask.
        """
        if self.export_and_reuse_masks:
            # check if masks are already generated
            precomputed_masks = self.__exists__(fname, (mask.shape[0], mask.shape[1]))
            if precomputed_masks is not None:
                return precomputed_masks[0], precomputed_masks[1]

        # _mask marks the sampled points that are assigned to the loss mask
        if self.mask_type == "Gaussian":
            _mask = self.__gaussian_sampling__(mask).type(torch.float32)
        else:
            _mask = self.__uniform_sampling__(mask).type(torch.float32)

        train_mask = torch.where(mask == 1, 1 - _mask, mask)
        loss_mask = torch.where(mask == 1, _mask, mask)

        # TODO: should we add the acs region to ensure linearity in FFT?
        # train_mask = torch.where(self.__find_acs_region__(train_mask) == 1, 1, train_mask)
        # loss_mask = torch.where(self.__find_acs_region__(mask) == 1, 1, loss_mask)

        if self.outer_kspace_fraction > 0:
            train_mask = self.__apply_outer_kspace_unmask__(train_mask)
            loss_mask = self.__apply_outer_kspace_unmask__(loss_mask)

        if self.export_and_reuse_masks:
            # save masks
            self.__export__(torch.stack([train_mask, loss_mask], dim=0), fname)

        return train_mask, loss_mask

    @staticmethod
    def __find_acs_region__(mask: torch.Tensor) -> torch.Tensor:  # noqa: MC0001
        """Find the acs region.

        Starting from the center of the mask, the contiguous run of sampled points is followed in all four
        directions along the central row and column.

        Parameters
        ----------
        mask : torch.Tensor
            Sampling mask. The shape should be (nx, ny).

        Returns
        -------
        torch.Tensor
            ACS region.
        """
        center = (mask.shape[0] // 2, mask.shape[1] // 2)

        # find the size of the acs region, start from the center and go left to find contiguous 1s
        acs_region = torch.zeros_like(mask)
        for i in range(center[0], 0, -1):
            if mask[i, center[1]] == 1:
                acs_region[i, :] = 1
            else:
                break
        # go right
        for i in range(center[0], mask.shape[0]):
            if mask[i, center[1]] == 1:
                acs_region[i, :] = 1
            else:
                break
        # go up
        for i in range(center[1], 0, -1):
            if mask[center[0], i] == 1:
                acs_region[:, i] = 1
            else:
                break
        # go down; like the other three directions, follow the sampled (== 1) points
        for i in range(center[1], mask.shape[1]):
            if mask[center[0], i] == 1:
                acs_region[:, i] = 1
            else:
                break

        # keep only the acs region
        # take only the first row and stop when you find a 1
        left = 0
        for i in range(acs_region.shape[0]):
            if acs_region[i, 0] == 1:
                left = i
                break
        # take only the last row and stop when you find a 1
        right = 0
        for i in range(acs_region.shape[0] - 1, 0, -1):
            if acs_region[i, 0] == 1:
                right = i
                break
        # take only the first column and stop when you find a 1
        up = 0
        for i in range(acs_region.shape[1]):
            if acs_region[0, i] == 1:
                up = i
                break
        # take only the last column and stop when you find a 1
        down = 0
        for i in range(acs_region.shape[1] - 1, 0, -1):
            if acs_region[0, i] == 1:
                down = i
                break

        acs_region = torch.zeros_like(mask)
        # NOTE(review): right and down are the last marked indices, so these slices exclude the last row and
        # column of the detected region - confirm whether that is intended.
        acs_region[left:right, up:down] = 1

        # keep only the part of the acs region that is in the mask
        return acs_region * mask

    def __gaussian_sampling__(self, mask: torch.Tensor) -> torch.Tensor:
        """Applies Gaussian sampling.

        Points are drawn from a Gaussian centered on the mask until the requested number of distinct sampled
        points outside the central acs block has been selected. The selected points are returned as ones.
        """
        nrow, ncol = mask.shape[0], mask.shape[1]
        center_kx = nrow // 2
        center_ky = ncol // 2

        tmp_mask = mask.clone()
        tmp_mask[
            center_kx - self.acs_block_size[0] // 2 : center_kx + self.acs_block_size[0] // 2,
            center_ky - self.acs_block_size[1] // 2 : center_ky + self.acs_block_size[1] // 2,
        ] = 0

        _mask = torch.zeros_like(mask)
        total = int(torch.ceil(torch.sum(mask[:]) * self.rho))
        # the selection stops after total + 1 points; never ask for more points than there are candidates,
        # otherwise the loop cannot terminate
        num_points = min(total + 1, int(torch.sum(tmp_mask == 1)))

        count = 0
        while count < num_points:
            indx = int(np.round(np.random.normal(loc=center_kx, scale=(nrow - 1) / self.gaussian_std_scaling_factor)))
            indy = int(np.round(np.random.normal(loc=center_ky, scale=(ncol - 1) / self.gaussian_std_scaling_factor)))
            if 0 <= indx < nrow and 0 <= indy < ncol and tmp_mask[indx, indy] == 1 and _mask[indx, indy] != 1:
                _mask[indx, indy] = 1
                count = count + 1

        return _mask

    def __uniform_sampling__(self, mask: torch.Tensor) -> torch.Tensor:
        """Applies uniform sampling.

        A fraction ``rho`` of the sampled points outside the central acs block is drawn uniformly without
        replacement. As in the Gaussian sampling, the selected points are returned as ones.
        """
        nrow, ncol = mask.shape[0], mask.shape[1]
        center_kx = nrow // 2
        center_ky = ncol // 2

        tmp_mask = mask.clone()
        tmp_mask[
            center_kx - self.acs_block_size[0] // 2 : center_kx + self.acs_block_size[0] // 2,
            center_ky - self.acs_block_size[1] // 2 : center_ky + self.acs_block_size[1] // 2,
        ] = 0

        candidates = tmp_mask.reshape(-1)
        num_valid = torch.sum(candidates)
        num_samples = int(self.rho * num_valid)

        selected = torch.zeros_like(candidates)
        # torch.multinomial cannot draw zero samples
        if num_samples > 0:
            ind = torch.multinomial(candidates / num_valid, num_samples, replacement=False)
            selected[ind] = 1

        return selected.view(mask.shape)

    @staticmethod
    def __find_center_ind__(data: torch.Tensor, dims: tuple = (1, 2, 3)) -> torch.Tensor:
        """Calculates the center of the k-space.

        Parameters
        ----------
        data : torch.Tensor
            Input data. The norm is reduced over ``dims``, so with the default the data must have four dimensions.
        dims : tuple, optional
            Dimensions to calculate the norm. Default is ``(1, 2, 3)``.

        Returns
        -------
        center_ind : torch.Tensor
            Single-element tensor holding the index with the largest norm along the remaining dimension.
        """
        for dim in dims:
            data = torch.linalg.norm(data, dim=dim, keepdim=True)
        return torch.argsort(data.squeeze())[-1:]

    def __apply_outer_kspace_unmask__(self, mask: torch.Tensor) -> torch.Tensor:
        """Applies outer k-space (un)mask.

        The mask is modified in place.

        Parameters
        ----------
        mask : torch.Tensor
            Input mask. The shape should be (nx, ny).

        Returns
        -------
        mask : torch.Tensor
            Output mask. The shape should be (nx, ny).
        """
        mask_out = int(mask.shape[1] * self.outer_kspace_fraction)
        mask[:, 0:mask_out] = 1
        mask[:, mask.shape[1] - mask_out : mask.shape[1]] = 1
        return mask

    @staticmethod
    def __exists__(fname: str, shape: Tuple) -> Optional[torch.Tensor]:
        """Checks if the sampling mask exists.

        Parameters
        ----------
        fname : str
            Filename to save the sampling mask.
        shape : tuple
            Shape of the sampling mask.

        Returns
        -------
        masks : torch.Tensor or None
            The stored masks if a file exists whose spatial shape matches ``shape``, otherwise None.
        """
        if ".h5" in fname:
            fname = fname.replace(".h5", ".npy")
        else:
            fname = fname + ".npy"

        # set path to the tmp directory of the home directory
        path = os.path.join(os.path.expanduser("~"), "tmp", fname)
        if not os.path.exists(path):
            return None

        masks = np.load(path)
        spatial_dims = (1, 2) if masks.ndim == 3 else (2, 3)
        if (masks.shape[spatial_dims[0]], masks.shape[spatial_dims[1]]) == shape:
            return torch.from_numpy(masks)
        return None

    @staticmethod
    def __export__(mask: torch.Tensor, fname: str) -> None:
        """Exports the sampling mask to a numpy file.

        Parameters
        ----------
        mask : torch.Tensor
            Sampling mask(s) to store.
        fname : str
            Filename to save the sampling mask.
        """
        if ".h5" in fname:
            fname = fname.replace(".h5", ".npy")
        else:
            fname = fname + ".npy"

        # set path to the tmp directory of the home directory
        path = os.path.join(os.path.expanduser("~"), "tmp", fname)
        # the target directory is not guaranteed to exist
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(path, mask.cpu().numpy())
[docs]class ZeroFillingPadding:
"""Zero-Filling padding in k-space -> changes the Field-of-View (FoV) in image space.
Returns
-------
zero_filled_data : torch.Tensor
Zero filled data.
spatial_dims : tuple
Spatial dimensions.
Example
-------
>>> import torch
>>> from atommic.collections.common.parts.transforms import ZeroFillingPadding
>>> data = torch.randn(1, 15, 320, 320, 2)
>>> zero_filling = ZeroFillingPadding(zero_filling_size=(400, 400), spatial_dims=(-2, -1))
>>> zero_filled_data = zero_filling(data)
>>> zero_filled_data.shape
[1, 15, 400, 400, 2]
"""
[docs] def __init__(
self,
zero_filling_size: Tuple,
fft_centered: bool = False,
fft_normalization: str = "backward",
spatial_dims: Sequence[int] = (-2, -1),
):
"""Inits :class:`ZeroFillingPadding`.
Parameters
----------
zero_filling_size : tuple
Size of the zero filled data.
fft_centered : bool, optional
If True, the FFT will be centered. Default is ``False``. Should be set for complex data normalization.
fft_normalization : str, optional
FFT normalization type. It can be one of the following:
spatial_dims : tuple, optional
Spatial dimensions. Default is ``(-2, -1)``.
"""
self.zero_filling_size = zero_filling_size
self.fft_centered = fft_centered
self.fft_normalization = fft_normalization
self.spatial_dims = spatial_dims
def __call__(
self,
data: Union[torch.Tensor, None],
apply_backward_transform: bool = False,
apply_forward_transform: bool = False,
) -> torch.Tensor:
"""Calls :class:`ZeroFillingPadding`.
Parameters
----------
data : torch.Tensor
Input data to crop.
apply_backward_transform : bool
Apply backward transform, i.e. Inverse Fast Fourier Transform. Default is ``False``.
apply_forward_transform : bool
Apply forward transform, i.e. Fast Fourier Transform. Default is ``False``.
"""
if not is_none(data) and data.dim() > 1 and data.mean() != 1: # type: ignore
return self.forward(data, apply_backward_transform, apply_forward_transform)
return data
def __repr__(self) -> str:
"""Representation of :class:`ZeroFillingPadding`."""
return f"Zero-Filling will be applied to data with size {self.zero_filling_size}."
def __str__(self) -> str:
    """Return the same text as :meth:`__repr__`."""
    text = self.__repr__()
    return text
def forward(
    self,
    data: torch.Tensor,
    apply_backward_transform: bool = False,
    apply_forward_transform: bool = False,
) -> torch.Tensor:
    """Zero-pad the two trailing axes of ``data`` up to ``self.zero_filling_size``.

    When a transform flag is set, the data are transformed before padding and the opposite transform is
    applied afterwards, so the result is returned in the domain it arrived in.

    Parameters
    ----------
    data : torch.Tensor
        Input data to zero-fill. A trailing dimension of size 2 is treated as stacked real/imaginary parts.
    apply_backward_transform : bool
        Apply backward transform, i.e. Inverse Fast Fourier Transform, before padding and the Fast Fourier
        Transform after it. Default is ``False``.
    apply_forward_transform : bool
        Apply forward transform, i.e. Fast Fourier Transform, before padding and the Inverse Fast Fourier
        Transform after it. Only used when ``apply_backward_transform`` is ``False``. Default is ``False``.

    Returns
    -------
    torch.Tensor
        The zero-filled data. An axis that is already at least as large as requested is left untouched.
    """
    fft_settings = {
        "centered": self.fft_centered,
        "normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }

    if apply_backward_transform:
        data = ifft2(data, **fft_settings)
    elif apply_forward_transform:
        data = fft2(data, **fft_settings)

    is_complex = data.shape[-1] == 2
    if is_complex:
        data = torch.view_as_complex(data)

    # Zero-filling can only enlarge an axis, so a target smaller than the current size adds no padding
    # (the previous ``abs`` made the data grow in that case).
    missing_rows = max(int(self.zero_filling_size[0]) - int(data.shape[self.spatial_dims[0]]), 0)
    missing_cols = max(int(self.zero_filling_size[1]) - int(data.shape[self.spatial_dims[1]]), 0)

    # Split each deficit over both sides; an odd remainder goes to the bottom/right so the padded axis
    # matches the requested size exactly instead of ending one sample short.
    padding_top = missing_rows // 2
    padding_bottom = missing_rows - padding_top
    padding_left = missing_cols // 2
    padding_right = missing_cols - padding_left

    # ``pad`` lists the last dimension first: (left, right) for the last axis, (top, bottom) for the
    # second-to-last axis.
    data = torch.nn.functional.pad(
        data, pad=(padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=0
    )

    if is_complex:
        data = torch.view_as_real(data)

    # Undo the transform applied before padding.
    if apply_backward_transform:
        data = fft2(data, **fft_settings)
    elif apply_forward_transform:
        data = ifft2(data, **fft_settings)

    return data
[docs]class MRIDataTransforms:
"""Generic class to apply transforms for MRI data."""
def __init__(
    self,
    dataset_format: str = None,
    apply_prewhitening: bool = False,
    find_patch_size: bool = True,
    prewhitening_scale_factor: float = 1.0,
    prewhitening_patch_start: int = 10,
    prewhitening_patch_length: int = 30,
    apply_gcc: bool = False,
    gcc_virtual_coils: int = 10,
    gcc_calib_lines: int = 24,
    gcc_align_data: bool = True,
    apply_random_motion: bool = False,
    random_motion_type: str = "gaussian",
    random_motion_percentage: Sequence[int] = (10, 20),
    random_motion_angle: int = 10,
    random_motion_translation: int = 10,
    random_motion_center_percentage: float = 0.02,
    random_motion_num_segments: int = 8,
    random_motion_random_num_segments: bool = True,
    random_motion_non_uniform: bool = False,
    estimate_coil_sensitivity_maps: bool = False,
    coil_sensitivity_maps_type: str = "ESPIRiT",
    coil_sensitivity_maps_gaussian_sigma: float = 0.0,
    coil_sensitivity_maps_espirit_threshold: float = 0.05,
    coil_sensitivity_maps_espirit_kernel_size: int = 6,
    coil_sensitivity_maps_espirit_crop: float = 0.95,
    coil_sensitivity_maps_espirit_max_iters: int = 30,
    coil_combination_method: str = "SENSE",
    dimensionality: int = 2,
    mask_func: Optional[Callable] = None,
    shift_mask: bool = False,
    mask_center_scale: float = 0.02,
    partial_fourier_percentage: float = 0.0,
    remask: bool = False,
    ssdu: bool = False,
    ssdu_mask_type: str = "Gaussian",
    ssdu_rho: float = 0.4,
    ssdu_acs_block_size: Sequence[int] = (4, 4),
    ssdu_gaussian_std_scaling_factor: float = 4.0,
    ssdu_outer_kspace_fraction: float = 0.0,
    ssdu_export_and_reuse_masks: bool = False,
    n2r: bool = False,
    n2r_supervised_rate: float = 0.0,
    n2r_probability: float = 0.0,
    n2r_std_devs: Optional[Tuple[float, float]] = None,
    n2r_rhos: Optional[Tuple[float, float]] = None,
    n2r_use_mask: bool = False,
    unsupervised_masked_target: bool = False,
    crop_size: Optional[Tuple[int, int]] = None,
    kspace_crop: bool = False,
    crop_before_masking: bool = True,
    kspace_zero_filling_size: Optional[Tuple] = None,
    normalize_inputs: bool = True,
    normalization_type: str = "max",
    kspace_normalization: bool = False,
    fft_centered: bool = False,
    fft_normalization: str = "backward",
    spatial_dims: Sequence[int] = None,
    coil_dim: int = 0,
    consecutive_slices: int = 1,  # pylint: disable=unused-argument
    use_seed: bool = True,
):
    """Inits :class:`MRIDataTransforms`.

    Every optional transform is built only when its switch is enabled and is ``None`` otherwise. The
    prewhitening, cropping, random-motion and normalization transforms are finally wrapped in a
    :class:`Composer`; coil compression and k-space zero-filling are grouped in ``coils_shape_transforms``.

    Parameters
    ----------
    dataset_format : str, optional
        The format of the dataset. For example, ``'custom_dataset'`` or ``'public_dataset_name'``.
        Default is ``None``.
    apply_prewhitening : bool, optional
        Apply prewhitening. If ``True`` then the prewhitening arguments are used. Default is ``False``.
    find_patch_size : bool, optional
        Find optimal patch size (automatically) to calculate psi. If False, patch_size must be defined.
        Default is ``True``.
    prewhitening_scale_factor : float, optional
        Prewhitening scale factor. Default is ``1.0``.
    prewhitening_patch_start : int, optional
        Prewhitening patch start. Default is ``10``.
    prewhitening_patch_length : int, optional
        Prewhitening patch length; the patch ends at ``prewhitening_patch_start + prewhitening_patch_length``.
        Default is ``30``.
    apply_gcc : bool, optional
        Apply Geometric Decomposition Coil Compression. If ``True`` then the GCC arguments are used.
        Default is ``False``.
    gcc_virtual_coils : int, optional
        GCC virtual coils. Default is ``10``.
    gcc_calib_lines : int, optional
        GCC calibration lines. Default is ``24``.
    gcc_align_data : bool, optional
        GCC align data. Default is ``True``.
    apply_random_motion : bool, optional
        Simulate random motion in k-space. Default is ``False``.
    random_motion_type : str, optional
        Random motion type. It can be one of the following: ``piecewise_transient``, ``piecewise_constant``,
        ``gaussian``. Default is ``gaussian``.
    random_motion_percentage : Sequence[int], optional
        Random motion percentage. For example, 10%-20% motion can be defined as ``(10, 20)``.
        Default is ``(10, 20)``.
    random_motion_angle : int, optional
        Random motion angle. Default is ``10``.
    random_motion_translation : int, optional
        Random motion translation. Default is ``10``.
    random_motion_center_percentage : float, optional
        Random motion center percentage. Default is ``0.02``.
    random_motion_num_segments : int, optional
        Random motion number of segments to partition the k-space. Default is ``8``.
    random_motion_random_num_segments : bool, optional
        Whether to randomly generate the number of segments. Default is ``True``.
    random_motion_non_uniform : bool, optional
        Random motion non-uniform sampling. Default is ``False``.
    estimate_coil_sensitivity_maps : bool, optional
        Automatically estimate coil sensitivity maps with the method selected by
        ``coil_sensitivity_maps_type``. If ``True`` then the coil sensitivity maps arguments are used.
        Default is ``False``.
    coil_sensitivity_maps_type : str, optional
        Coil sensitivity maps type. It can be one of the following: ``ESPIRiT``, ``RSS`` or ``UNit``.
        The value is lower-cased before being passed on. Default is ``ESPIRiT``.
    coil_sensitivity_maps_gaussian_sigma : float, optional
        Coil sensitivity maps Gaussian sigma. Default is ``0.0``.
    coil_sensitivity_maps_espirit_threshold : float, optional
        Coil sensitivity maps ESPIRiT threshold. Default is ``0.05``.
    coil_sensitivity_maps_espirit_kernel_size : int, optional
        Coil sensitivity maps ESPIRiT kernel size. Default is ``6``.
    coil_sensitivity_maps_espirit_crop : float, optional
        Coil sensitivity maps ESPIRiT crop. Default is ``0.95``.
    coil_sensitivity_maps_espirit_max_iters : int, optional
        Coil sensitivity maps ESPIRiT max iterations. Default is ``30``.
    coil_combination_method : str, optional
        Coil combination method. Default is ``"SENSE"``.
    dimensionality : int, optional
        Dimensionality. Default is ``2``.
    mask_func : Optional[Callable], optional
        Mask function to retrospectively undersample the k-space. Default is ``None``.
    shift_mask : bool, optional
        Whether to shift the mask. This needs to be set alongside with the ``fft_centered`` argument.
        Default is ``False``.
    mask_center_scale : float, optional
        Center scale of the mask. This defines how much densely sampled will be the center of k-space.
        Default is ``0.02``.
    partial_fourier_percentage : float, optional
        Partial Fourier percentage used to simulate a half scan. Default is ``0.0``.
    remask : bool, optional
        Use the same mask. Default is ``False``.
    ssdu : bool, optional
        Whether to apply Self-Supervised Data Undersampling (SSDU) masks. Default is ``False``.
    ssdu_mask_type : str, optional
        Mask type, ``"Gaussian"`` or ``"Uniform"`` sampling. Default is ``"Gaussian"``.
    ssdu_rho : float, optional
        Split ratio for training and loss masks. Default is ``0.4``.
    ssdu_acs_block_size : Sequence[int], optional
        Keeps a small acs region fully-sampled for training masks, if there is no acs region.
        Default is ``(4, 4)``.
    ssdu_gaussian_std_scaling_factor : float, optional
        Scaling factor for the standard deviation of the Gaussian sampling. Ignored for Uniform sampling.
        Default is ``4.0``.
    ssdu_outer_kspace_fraction : float, optional
        Fraction of the outer k-space to be kept/unmasked. Default is ``0.0``.
    ssdu_export_and_reuse_masks : bool, optional
        Whether to export and reuse the masks. Default is ``False``.
    n2r : bool, optional
        Whether to apply Noise to Reconstruction (N2R) masks. Default is ``False``.
    n2r_supervised_rate : float, optional
        A float between 0 and 1. This controls what fraction of the subjects should be loaded for Noise to
        Reconstruction (N2R) supervised loss, if N2R is enabled. Default is ``0.0``.
    n2r_probability : float, optional
        Probability of applying N2R. Default is ``0.0``.
    n2r_std_devs : Optional[Tuple[float, float]], optional
        Standard deviations for the noise. Default is ``None``.
    n2r_rhos : Optional[Tuple[float, float]], optional
        Rho values for the noise. Default is ``None``.
    n2r_use_mask : bool, optional
        Whether to use a mask for N2R. Default is ``False``.
    unsupervised_masked_target : bool, optional
        Whether to use the masked initial estimation for unsupervised learning. Default is ``False``.
    crop_size : Optional[Tuple[int, int]], optional
        Center crop size. Default is ``None``, in which case no cropper is created.
    kspace_crop : bool, optional
        Whether to crop in k-space. Default is ``False``.
    crop_before_masking : bool, optional
        Whether to crop before masking. Default is ``True``.
    kspace_zero_filling_size : Optional[Tuple], optional
        Target size for zero filling in k-space. Default is ``None``, in which case no zero-filling
        transform is created.
    normalize_inputs : bool, optional
        Whether to normalize the inputs. Default is ``True``.
    normalization_type : str, optional
        Normalization type. Can be ``max`` or ``mean`` or ``minmax``. Default is ``max``.
    kspace_normalization : bool, optional
        Whether to normalize the k-space. Default is ``False``.
    fft_centered : bool, optional
        Whether to center the FFT. Default is ``False``.
    fft_normalization : str, optional
        FFT normalization. Default is ``"backward"``.
    spatial_dims : Sequence[int], optional
        Spatial dimensions. Default is ``None``, which is replaced by ``[-2, -1]``.
    coil_dim : int, optional
        Coil dimension before batching. Default is ``0``. When ``dimensionality`` is ``2`` the stored value
        is ``coil_dim - 1``.
    consecutive_slices : int, optional
        Consecutive slices. Not used by this constructor. Default is ``1``.
    use_seed : bool, optional
        Whether to use seed. Default is ``True``.
    """
    super().__init__()

    # Plain settings that the other methods read directly.
    self.dataset_format = dataset_format
    self.coil_combination_method = coil_combination_method
    self.kspace_crop = kspace_crop
    self.crop_before_masking = crop_before_masking
    self.fft_centered = fft_centered
    self.fft_normalization = fft_normalization
    if spatial_dims is None:
        spatial_dims = [-2, -1]
    self.spatial_dims = spatial_dims
    if dimensionality == 2:
        self.coil_dim = coil_dim - 1
    else:
        self.coil_dim = coil_dim

    # Keyword arguments shared by every FFT-aware transform built below.
    fft_settings = {
        "fft_centered": self.fft_centered,
        "fft_normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }

    # Noise prewhitening over a square patch that starts at ``prewhitening_patch_start``.
    prewhitening = None
    if apply_prewhitening:
        prewhitening = NoisePreWhitening(
            find_patch_size=find_patch_size,
            patch_size=[
                prewhitening_patch_start,
                prewhitening_patch_length + prewhitening_patch_start,
                prewhitening_patch_start,
                prewhitening_patch_length + prewhitening_patch_start,
            ],
            scale_factor=prewhitening_scale_factor,
            **fft_settings,
        )

    # Geometric decomposition coil compression.
    self.gcc = None
    if apply_gcc:
        self.gcc = GeometricDecompositionCoilCompression(
            virtual_coils=gcc_virtual_coils,
            calib_lines=gcc_calib_lines,
            align_data=gcc_align_data,
            **fft_settings,
        )

    # Random motion simulation.
    random_motion = None
    if apply_random_motion:
        random_motion = MotionSimulation(
            motion_type=random_motion_type,
            angle=random_motion_angle,
            translation=random_motion_translation,
            center_percentage=random_motion_center_percentage,
            motion_percentage=random_motion_percentage,
            num_segments=random_motion_num_segments,
            random_num_segments=random_motion_random_num_segments,
            non_uniform=random_motion_non_uniform,
            spatial_dims=self.spatial_dims,
        )

    # Coil sensitivity maps estimation.
    self.coil_sensitivity_maps_estimator = None
    if estimate_coil_sensitivity_maps:
        self.coil_sensitivity_maps_estimator = EstimateCoilSensitivityMaps(
            coil_sensitivity_maps_type=coil_sensitivity_maps_type.lower(),
            gaussian_sigma=coil_sensitivity_maps_gaussian_sigma,
            espirit_threshold=coil_sensitivity_maps_espirit_threshold,
            espirit_kernel_size=coil_sensitivity_maps_espirit_kernel_size,
            espirit_crop=coil_sensitivity_maps_espirit_crop,
            espirit_max_iters=coil_sensitivity_maps_espirit_max_iters,
            coil_dim=self.coil_dim,
            **fft_settings,
        )

    # Zero-filling in k-space.
    self.kspace_zero_filling = None
    if not is_none(kspace_zero_filling_size):
        self.kspace_zero_filling = ZeroFillingPadding(
            zero_filling_size=kspace_zero_filling_size,  # type: ignore
            **fft_settings,
        )

    # The masker is always created.
    self.shift_mask = shift_mask
    self.masking = Masker(
        mask_func=mask_func,
        spatial_dims=self.spatial_dims,
        shift_mask=shift_mask,
        partial_fourier_percentage=partial_fourier_percentage,
        center_scale=mask_center_scale,
        dimensionality=dimensionality,
        remask=remask,
        dataset_format=self.dataset_format,
    )

    # Self-supervised data undersampling.
    self.ssdu = ssdu
    self.ssdu_masking = None
    if self.ssdu:
        self.ssdu_masking = SSDU(
            mask_type=ssdu_mask_type,
            rho=ssdu_rho,
            acs_block_size=ssdu_acs_block_size,
            gaussian_std_scaling_factor=ssdu_gaussian_std_scaling_factor,
            outer_kspace_fraction=ssdu_outer_kspace_fraction,
            export_and_reuse_masks=ssdu_export_and_reuse_masks,
        )

    # Noise to Reconstruction.
    self.n2r = n2r
    self.n2r_supervised_rate = n2r_supervised_rate
    self.n2r_masking = None
    if self.n2r:
        self.n2r_masking = N2R(
            probability=n2r_probability,
            std_devs=n2r_std_devs,  # type: ignore
            rhos=n2r_rhos,  # type: ignore
            use_mask=n2r_use_mask,
        )

    self.unsupervised_masked_target = unsupervised_masked_target

    # Center cropping.
    cropping = None
    if not is_none(crop_size):
        cropping = Cropper(
            cropping_size=crop_size,  # type: ignore
            **fft_settings,
        )

    # Normalization.
    self.normalization_type = normalization_type
    normalization = None
    if normalize_inputs:
        normalization = Normalizer(
            normalization_type=self.normalization_type,
            kspace_normalization=kspace_normalization,
            **fft_settings,
        )

    # Wrap the transforms in Composers; a Composer may hold ``None`` entries for disabled transforms.
    self.prewhitening = Composer([prewhitening])  # type: ignore
    self.coils_shape_transforms = Composer(
        [
            self.gcc,  # type: ignore
            self.kspace_zero_filling,  # type: ignore
        ]
    )
    self.cropping = Composer([cropping])  # type: ignore
    self.random_motion = Composer([random_motion])  # type: ignore
    self.normalization = Composer([normalization])  # type: ignore

    self.use_seed = use_seed
def __call__(
self,
kspace: np.ndarray,
sensitivity_map: np.ndarray,
mask: np.ndarray,
initial_prediction: np.ndarray,
target: np.ndarray,
attrs: Dict,
fname: str,
slice_idx: int,
) -> Tuple[
Union[torch.Tensor, List[torch.Tensor]],
Union[List[torch.Tensor], torch.Tensor],
torch.Tensor,
Union[List[torch.Tensor], torch.Tensor],
Union[List[torch.Tensor], torch.Tensor],
# NOTE(review): ``torch.tensor`` is the factory function; ``torch.Tensor`` is presumably the intended type.
torch.tensor,
str,
int,
Union[List[Union[float, torch.Tensor, Any]]],
Dict,
]:
"""Calls :class:`MRIDataTransforms`.
Processes the k-space, the coil sensitivity maps, the initial prediction and the target, in that order.
Parameters
----------
kspace : np.ndarray
The fully-sampled kspace, if exists. Otherwise, the subsampled kspace.
sensitivity_map : np.ndarray
The coil sensitivity map.
mask : np.ndarray
The subsampling mask, if exists, meaning that the data are either prospectively undersampled or the mask is
stored and loaded.
initial_prediction : np.ndarray
The initial prediction, if exists. Otherwise, it will be estimated with the chosen coil combination method.
target : np.ndarray
The target, if exists. Otherwise, it will be estimated with the chosen coil combination method.
attrs : Dict
The attributes, if stored in the data. Updated in place with the parsed normalization variables, ``fname``
and ``slice_idx``.
fname : str
The file name.
slice_idx : int
The slice index.
Returns
-------
tuple
``(kspace, masked_kspace, sensitivity_map, mask, prediction, target, fname, slice_idx, acc, attrs)``.
``prediction`` is a two-element list ``[prediction, noise_prediction]`` in the N2R branch.
"""
# Mask (and optionally normalize) the k-space; ``acc`` is the value returned by the masker.
kspace, masked_kspace, mask, kspace_pre_normalization_vars, acc = self.__process_kspace__( # type: ignore
kspace, mask, attrs, fname
)
sensitivity_map, sensitivity_pre_normalization_vars = self.__process_coil_sensitivities_map__(
sensitivity_map, kspace
)
# N2R: the first masked k-space entry gives the prediction, the second one (if usable) the noise prediction.
# NOTE(review): ``len(masked_kspace)`` is also defined for a plain tensor (size of its first axis) — confirm
# that this branch is only meant for the list produced by the N2R path.
if self.n2r and len(masked_kspace) > 1: # type: ignore
prediction, prediction_pre_normalization_vars = self.__initialize_prediction__(
initial_prediction, masked_kspace[0], sensitivity_map # type: ignore
)
if isinstance(masked_kspace, list) and not masked_kspace[1][0].dim() < 2: # type: ignore
noise_prediction, noise_prediction_pre_normalization_vars = self.__initialize_prediction__(
None, masked_kspace[1], sensitivity_map # type: ignore
)
else:
# No usable noise-masked k-space: fall back to an empty tensor.
noise_prediction = torch.tensor([])
noise_prediction_pre_normalization_vars = None
prediction = [prediction, noise_prediction]
else:
prediction, prediction_pre_normalization_vars = self.__initialize_prediction__(
initial_prediction, masked_kspace, sensitivity_map # type: ignore
)
noise_prediction_pre_normalization_vars = None
# The target is either the (masked) prediction itself or is initialized from the processed k-space.
if self.unsupervised_masked_target:
target, target_pre_normalization_vars = prediction, prediction_pre_normalization_vars
else:
# With SSDU the stored target is ignored (``None`` is passed instead).
target, target_pre_normalization_vars = self.__initialize_prediction__(
None if self.ssdu else target, kspace, sensitivity_map
)
attrs.update(
self.__parse_normalization_vars__(
kspace_pre_normalization_vars, # type: ignore
sensitivity_pre_normalization_vars,
prediction_pre_normalization_vars,
noise_prediction_pre_normalization_vars,
target_pre_normalization_vars,
)
)
attrs["fname"] = fname
attrs["slice_idx"] = slice_idx
return (
kspace,
masked_kspace, # type: ignore
sensitivity_map,
mask,
prediction,
target,
fname,
slice_idx,
acc, # type: ignore
attrs,
)
def __repr__(self) -> str:
    """Return a one-line summary of the main configured transforms."""
    fragments = [
        f"Preprocessing transforms initialized for {self.__class__.__name__}: ",
        f"prewhitening = {self.prewhitening}, ",
        f"masking = {self.masking}, ",
        f"SSDU masking = {self.ssdu_masking}, ",
        f"kspace zero-filling = {self.kspace_zero_filling}, ",
        f"cropping = {self.cropping}, ",
        f"normalization = {self.normalization}, ",
    ]
    return "".join(fragments)
def __str__(self) -> str:
    """Return the same text as :meth:`__repr__`."""
    description = self.__repr__()
    return description
def __process_kspace__( # noqa: MC0001
self, kspace: np.ndarray, mask: Union[np.ndarray, None], attrs: Dict, fname: str
# NOTE(review): the annotation below lists four items, but five values are returned
# (kspace, masked_kspace, mask, pre_normalization_vars, acc).
) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor], Union[List[torch.Tensor], torch.Tensor], int]:
"""Apply the preprocessing transforms to the kspace.
The steps are, in order: tensor conversion and coil-dimension handling, coil-shape transforms, prewhitening,
optional cropping, (random motion and) masking, normalization or statistics collection, and the optional SSDU
and N2R mask handling.
Parameters
----------
kspace : np.ndarray
The kspace.
mask : np.ndarray or None
The mask, if None, the mask is generated.
attrs : Dict
The attributes, if stored in the file. ``padding_left``/``padding_right`` are read when present (0 otherwise)
and ``n2r_supervised`` is read when N2R is enabled.
fname : str
The file name. Its characters seed the masker when ``use_seed`` is set.
Returns
-------
tuple
The transformed (fully-sampled) kspace, the masked kspace, the mask, a dict with the pre-normalization
variables (keys ``kspace_pre_normalization_vars``, ``masked_kspace_pre_normalization_vars`` and
``noise_masked_kspace_pre_normalization_vars``) and the acceleration factor.
"""
kspace = to_tensor(kspace)
kspace = add_coil_dim_if_singlecoil(kspace, dim=self.coil_dim)
# ``coils_shape_transforms`` composes coil compression and k-space zero-filling.
kspace = self.coils_shape_transforms(kspace, apply_backward_transform=True)
kspace = self.prewhitening(kspace) # type: ignore
# Crop either before masking (here) or after masking (below), never both.
if self.crop_before_masking:
kspace = self.cropping(kspace, apply_backward_transform=not self.kspace_crop) # type: ignore
# Random motion is applied to the k-space that gets masked; ``kspace`` itself is left as is.
masked_kspace, mask, acc = self.masking(
self.random_motion(kspace), # type: ignore
mask,
(
attrs["padding_left"] if "padding_left" in attrs else 0,
attrs["padding_right"] if "padding_right" in attrs else 0,
),
tuple(map(ord, fname)) if self.use_seed else None, # type: ignore
)
if not self.crop_before_masking:
kspace = self.cropping(kspace, apply_backward_transform=not self.kspace_crop) # type: ignore
masked_kspace = self.cropping(masked_kspace, apply_backward_transform=not self.kspace_crop) # type: ignore
mask = self.cropping(mask) # type: ignore
# Keep references to the data before normalization; the N2R branch below works on these.
init_kspace = kspace
init_masked_kspace = masked_kspace
init_mask = mask
# Fully-sampled k-space: normalize it, or only collect magnitude statistics when normalization is off.
# NOTE(review): whether normalization is enabled is decided from the Composer's repr via ``is_none``, which
# is defined outside this chunk — confirm its semantics.
if isinstance(kspace, list):
kspaces = []
pre_normalization_vars = []
for i in range(len(kspace)): # pylint: disable=consider-using-enumerate
if not is_none(self.normalization.__repr__()):
_kspace, _pre_normalization_vars = self.normalization( # type: ignore
kspace[i], apply_backward_transform=True
)
else:
_kspace = kspace[i]
# A trailing dimension of size 2 is treated as stacked real/imaginary parts.
is_complex = _kspace.shape[-1] == 2
if is_complex:
_kspace = torch.view_as_complex(_kspace)
_pre_normalization_vars = {
"min": torch.min(torch.abs(_kspace)),
"max": torch.max(torch.abs(_kspace)),
"mean": torch.mean(torch.abs(_kspace)),
"std": torch.std(torch.abs(_kspace)),
"var": torch.var(torch.abs(_kspace)),
}
if is_complex:
_kspace = torch.view_as_real(_kspace)
kspaces.append(_kspace)
pre_normalization_vars.append(_pre_normalization_vars)
kspace = kspaces
else:
if not is_none(self.normalization.__repr__()):
kspace, pre_normalization_vars = self.normalization( # type: ignore
kspace, apply_backward_transform=True
)
else:
is_complex = kspace.shape[-1] == 2
if is_complex:
kspace = torch.view_as_complex(kspace)
pre_normalization_vars = { # type: ignore
"min": torch.min(torch.abs(kspace)),
"max": torch.max(torch.abs(kspace)),
"mean": torch.mean(torch.abs(kspace)),
"std": torch.std(torch.abs(kspace)),
"var": torch.var(torch.abs(kspace)),
}
if is_complex:
kspace = torch.view_as_real(kspace)
# Masked k-space: same normalize-or-collect-statistics handling as above.
if isinstance(masked_kspace, list):
masked_kspaces = []
masked_pre_normalization_vars = []
for i in range(len(masked_kspace)): # pylint: disable=consider-using-enumerate
if not is_none(self.normalization.__repr__()):
_masked_kspace, _masked_pre_normalization_vars = self.normalization( # type: ignore
masked_kspace[i], apply_backward_transform=True
)
else:
_masked_kspace = masked_kspace[i]
is_complex = _masked_kspace.shape[-1] == 2
if is_complex:
_masked_kspace = torch.view_as_complex(_masked_kspace)
_masked_pre_normalization_vars = {
"min": torch.min(torch.abs(_masked_kspace)),
"max": torch.max(torch.abs(_masked_kspace)),
"mean": torch.mean(torch.abs(_masked_kspace)),
"std": torch.std(torch.abs(_masked_kspace)),
"var": torch.var(torch.abs(_masked_kspace)),
}
if is_complex:
_masked_kspace = torch.view_as_real(_masked_kspace)
masked_kspaces.append(_masked_kspace)
masked_pre_normalization_vars.append(_masked_pre_normalization_vars)
masked_kspace = masked_kspaces
else:
if not is_none(self.normalization.__repr__()):
masked_kspace, masked_pre_normalization_vars = self.normalization(
masked_kspace, apply_backward_transform=True
)
else:
is_complex = masked_kspace.shape[-1] == 2
if is_complex:
masked_kspace = torch.view_as_complex(masked_kspace)
masked_pre_normalization_vars = {
"min": torch.min(torch.abs(masked_kspace)),
"max": torch.max(torch.abs(masked_kspace)),
"mean": torch.mean(torch.abs(masked_kspace)),
"std": torch.std(torch.abs(masked_kspace)),
"var": torch.var(torch.abs(masked_kspace)),
}
if is_complex:
masked_kspace = torch.view_as_real(masked_kspace)
# SSDU splits the sampling mask into train/loss masks and applies them to the (normalized) data.
if self.ssdu:
kspace, masked_kspace, mask = self.__self_supervised_data_undersampling__( # type: ignore
kspace, masked_kspace, mask, fname
)
n2r_pre_normalization_vars = None
# N2R is applied to unsupervised samples, or always when it is combined with SSDU.
if self.n2r and (not attrs["n2r_supervised"] or self.ssdu):
# The N2R masks are derived from the pre-normalization data saved above.
n2r_masked_kspace, n2r_mask = self.__noise_to_reconstruction__(init_kspace, init_masked_kspace, init_mask)
if self.ssdu:
# Replace the SSDU train mask and masked k-space with the initial ones, reshaping the initial mask to the
# train mask's number of dimensions first.
if isinstance(mask, list):
for i in range(len(mask)): # pylint: disable=consider-using-enumerate
if init_mask[i].dim() != mask[i][0].dim(): # type: ignore
# find dimensions == 1 in mask[i][0] and add them to init_mask
unitary_dims = [j for j in range(mask[i][0].dim()) if mask[i][0].shape[j] == 1]
# unsqueeze init_mask to the index of the unitary dimensions
for j in unitary_dims:
init_mask[i] = init_mask[i].unsqueeze(j) # type: ignore
masked_kspace[i] = init_masked_kspace[i]
mask[i][0] = init_mask[i]
else:
if init_mask.dim() != mask[0].dim(): # type: ignore
# find dimensions == 1 in mask[0] and add them to init_mask
unitary_dims = [j for j in range(mask[0].dim()) if mask[0].shape[j] == 1]
# unsqueeze init_mask to the index of the unitary dimensions
for j in unitary_dims:
init_mask = init_mask.unsqueeze(j) # type: ignore
masked_kspace = init_masked_kspace
mask[0] = init_mask
# NOTE(review): this check uses a substring test on the repr, while the checks above use ``is_none`` —
# confirm both agree on when normalization is disabled.
if "None" not in self.normalization.__repr__():
if isinstance(masked_kspace, list):
masked_kspaces = []
masked_pre_normalization_vars = []
for i in range(len(masked_kspace)): # pylint: disable=consider-using-enumerate
_masked_kspace, _masked_pre_normalization_vars = self.normalization( # type: ignore
masked_kspace[i], apply_backward_transform=True
)
masked_kspaces.append(_masked_kspace)
masked_pre_normalization_vars.append(_masked_pre_normalization_vars)
masked_kspace = masked_kspaces
else:
masked_kspace, masked_pre_normalization_vars = self.normalization( # type: ignore
masked_kspace, apply_backward_transform=True
)
if isinstance(n2r_masked_kspace, list):
n2r_masked_kspaces = []
n2r_pre_normalization_vars = []
for i in range(len(n2r_masked_kspace)): # pylint: disable=consider-using-enumerate
_n2r_masked_kspace, _n2r_pre_normalization_vars = self.normalization( # type: ignore
n2r_masked_kspace[i], apply_backward_transform=True
)
n2r_masked_kspaces.append(_n2r_masked_kspace)
n2r_pre_normalization_vars.append(_n2r_pre_normalization_vars)
n2r_masked_kspace = n2r_masked_kspaces
else:
n2r_masked_kspace, n2r_pre_normalization_vars = self.normalization( # type: ignore
n2r_masked_kspace, apply_backward_transform=True
)
else:
masked_pre_normalization_vars = None # type: ignore
n2r_pre_normalization_vars = None # type: ignore
# With N2R the masked k-space and the mask become two-element lists: [regular, N2R].
masked_kspace = [masked_kspace, n2r_masked_kspace]
mask = [mask, n2r_mask]
# For ``grayscale`` normalization the mask(s) are passed through the normalizer as well.
if self.normalization_type == "grayscale":
if isinstance(mask, list):
masks = []
for i in range(len(mask)): # pylint: disable=consider-using-enumerate
_mask, _ = self.normalization(mask[i], apply_backward_transform=False) # type: ignore
masks.append(_mask)
mask = masks
else:
mask, _ = self.normalization(mask, apply_backward_transform=False) # type: ignore
# Bundle the statistics of the three k-space variants under fixed keys.
pre_normalization_vars = { # type: ignore
"kspace_pre_normalization_vars": pre_normalization_vars,
"masked_kspace_pre_normalization_vars": masked_pre_normalization_vars,
"noise_masked_kspace_pre_normalization_vars": n2r_pre_normalization_vars,
}
return kspace, masked_kspace, mask, pre_normalization_vars, acc # type: ignore
def __noise_to_reconstruction__(
    self,
    kspace: torch.Tensor,
    masked_kspace: torch.Tensor,
    mask: Union[List, torch.Tensor],
) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
    """Derive the Noise to Reconstruction (N2R) mask(s) and apply them to the undersampled kspace.

    Parameters
    ----------
    kspace : torch.Tensor
        The fully-sampled kspace, passed to ``self.n2r_masking`` together with each mask.
    masked_kspace : torch.Tensor
        The undersampled kspace. It is indexed per mask when ``mask`` is a list.
    mask : Union[List, torch.Tensor]
        The undersampling mask, or a list of undersampling masks.

    Returns
    -------
    n2r_masked_kspace : Union[List, torch.Tensor]
        The undersampled kspace multiplied by the N2R mask(s).
    n2r_mask : Union[List, torch.Tensor]
        The N2R mask(s). Both outputs are lists when ``mask`` is a list.
    """
    if not isinstance(mask, list):
        single_mask = self.n2r_masking(kspace, mask)  # type: ignore # pylint: disable=not-callable
        # ``+ 0.0`` is kept from the original; presumably it turns ``-0.0`` entries into ``0.0``.
        return masked_kspace * single_mask + 0.0, single_mask

    per_mask_kspaces = []
    per_mask_masks = []
    for position, current_mask in enumerate(mask):
        derived_mask = self.n2r_masking(kspace, current_mask)  # type: ignore # pylint: disable=not-callable
        per_mask_masks.append(derived_mask)
        per_mask_kspaces.append(masked_kspace[position] * derived_mask + 0.0)
    return per_mask_kspaces, per_mask_masks
def __self_supervised_data_undersampling__( # noqa: MC0001
self,
kspace: torch.Tensor,
masked_kspace: Union[List, torch.Tensor],
mask: Union[List, torch.Tensor],
fname: str,
) -> Tuple[
List[float | Any] | float | Any,
List[float | Any] | float | Any,
List[List[torch.Tensor | Any]] | List[torch.Tensor | Any],
]:
"""Self-supervised data undersampling.
Splits every sampling mask into a train mask and a loss mask with ``self.ssdu_masking`` and applies the loss
mask to the kspace and the train mask to the undersampled kspace.
Parameters
----------
kspace : torch.Tensor
The fully-sampled kspace.
masked_kspace : Union[List, torch.Tensor]
The undersampled kspace.
mask : Union[List, torch.Tensor]
The undersampling mask. When a list is given, its entries are overwritten in place while processing.
fname : str
The filename of the current sample.
Returns
-------
kspace : torch.Tensor
The kspace with the loss mask applied. A list with one entry per mask when ``mask`` is a list.
masked_kspace : torch.Tensor
The kspace with the train mask applied. A list with one entry per mask when ``mask`` is a list.
mask : list, [torch.Tensor, torch.Tensor]
The train and loss masks. A list of ``[train_mask, loss_mask]`` pairs when ``mask`` is a list.
"""
if isinstance(mask, list):
kspaces = []
masked_kspaces = []
masks = []
for i in range(len(mask)): # pylint: disable=consider-using-enumerate
# A mask with a single non-unit dimension is treated as 1D.
is_1d = mask[i].squeeze().dim() == 1
if self.shift_mask:
mask[i] = torch.fft.fftshift(mask[i].squeeze(-1), dim=(-2, -1)).unsqueeze(-1)
mask[i] = mask[i].squeeze()
if is_1d:
# Repeat the 1D mask ``kspace.shape[1]`` times along a new leading axis to make it 2D.
mask[i] = mask[i].unsqueeze(0).repeat_interleave(kspace.shape[1], dim=0)
train_mask, loss_mask = self.ssdu_masking( # type: ignore # pylint: disable=not-callable
kspace, mask[i], fname
)
if self.shift_mask:
train_mask = torch.fft.fftshift(train_mask, dim=(0, 1))
loss_mask = torch.fft.fftshift(loss_mask, dim=(0, 1))
if is_1d:
# Add a leading and a trailing singleton axis.
train_mask = train_mask.unsqueeze(0).unsqueeze(-1)
loss_mask = loss_mask.unsqueeze(0).unsqueeze(-1)
else:
# find unitary dims in mask
# NOTE(review): ``mask[i]`` was squeezed above, so this list looks like it is always empty — confirm.
# The comprehension's own ``i`` does not leak; ``mask[i]`` is evaluated with the loop index.
dims = [i for i, x in enumerate(mask[i].shape) if x == 1]
# unsqueeze to broadcast
for d in dims:
train_mask = train_mask.unsqueeze(d)
loss_mask = loss_mask.unsqueeze(d)
if train_mask.dim() != kspace.dim():
# find dims != to any train_mask dim
dims = [i for i, x in enumerate(kspace.shape) if x not in train_mask.shape]
# unsqueeze to broadcast
for d in dims:
train_mask = train_mask.unsqueeze(d)
loss_mask = loss_mask.unsqueeze(d)
# ``+ 0.0`` is kept as is; presumably it turns ``-0.0`` entries into ``0.0``.
kspaces.append(kspace * loss_mask + 0.0)
masked_kspaces.append(masked_kspace[i] * train_mask + 0.0)
masks.append([train_mask, loss_mask])
kspace = kspaces
masked_kspace = masked_kspaces
mask = masks
else:
# Single-mask case: same steps as in the loop above, without indexing.
is_1d = mask.squeeze().dim() == 1
if self.shift_mask:
mask = torch.fft.fftshift(mask.squeeze(-1), dim=(-2, -1)).unsqueeze(-1)
mask = mask.squeeze()
if is_1d:
mask = mask.unsqueeze(0).repeat_interleave(kspace.shape[1], dim=0)
train_mask, loss_mask = self.ssdu_masking( # type: ignore # pylint: disable=not-callable
kspace, mask, fname
)
if self.shift_mask:
train_mask = torch.fft.fftshift(train_mask, dim=(0, 1))
loss_mask = torch.fft.fftshift(loss_mask, dim=(0, 1))
if is_1d:
train_mask = train_mask.unsqueeze(0).unsqueeze(-1)
loss_mask = loss_mask.unsqueeze(0).unsqueeze(-1)
else:
# find unitary dims in mask
# NOTE(review): ``mask`` was squeezed above, so this list looks like it is always empty — confirm.
dims = [i for i, x in enumerate(mask.shape) if x == 1]
# unsqueeze to broadcast
for d in dims:
train_mask = train_mask.unsqueeze(d)
loss_mask = loss_mask.unsqueeze(d)
if train_mask.dim() != kspace.dim():
# find dims != to any train_mask dim
dims = [i for i, x in enumerate(kspace.shape) if x not in train_mask.shape]
# unsqueeze to broadcast
for d in dims:
train_mask = train_mask.unsqueeze(d)
loss_mask = loss_mask.unsqueeze(d)
kspace = kspace * loss_mask + 0.0
masked_kspace = masked_kspace * train_mask + 0.0
mask = [train_mask, loss_mask]
return kspace, masked_kspace, mask
def __process_coil_sensitivities_map__(
    self, sensitivity_map: Optional[np.ndarray], kspace: torch.Tensor
) -> Tuple[torch.Tensor, Dict]:
    """Preprocesses the coil sensitivities map.

    The map is obtained from one of three sources, in order of precedence:

    1. ``self.coil_sensitivity_maps_estimator`` applied to ``kspace``, when an estimator is set.
    2. The provided ``sensitivity_map``, when it is not None and not empty. It is converted with ``to_tensor``
       and passed through ``self.coils_shape_transforms`` and ``self.cropping``.
    3. A tensor of ones shaped like ``kspace`` (or like ``kspace[0]`` when ``kspace`` is a list).

    Parameters
    ----------
    sensitivity_map : np.ndarray or None
        The coil sensitivities map. Ignored when a coil sensitivity maps estimator is set.
    kspace : torch.Tensor
        The kspace. A list of tensors is also accepted for the ones-initialized fallback.

    Returns
    -------
    Tuple[torch.Tensor, Dict]
        The preprocessed coil sensitivities map and the normalization variables. When no normalization is
        applied, the variables are the ``min``, ``max``, ``mean``, ``std`` and ``var`` of the map's magnitude.
    """
    # This condition is necessary in case of auto estimation of sense maps.
    if self.coil_sensitivity_maps_estimator is not None:
        sensitivity_map = self.coil_sensitivity_maps_estimator(kspace)
    elif sensitivity_map is not None and sensitivity_map.size != 0:
        sensitivity_map = to_tensor(sensitivity_map)
        sensitivity_map = self.coils_shape_transforms(sensitivity_map, apply_forward_transform=True)
        sensitivity_map = self.cropping(sensitivity_map, apply_forward_transform=self.kspace_crop)  # type: ignore
    else:
        # If no sensitivity map is provided, either the data is singlecoil or the sense net is used.
        # Initialize the sensitivity map to 1 to assure for the singlecoil case.
        reference = kspace[0] if isinstance(kspace, list) else kspace
        sensitivity_map = torch.ones_like(reference)

    if not is_none(self.normalization.__repr__()):
        sensitivity_map, pre_normalization_vars = self.normalization(  # type: ignore
            sensitivity_map, apply_forward_transform=self.kspace_crop
        )
        return sensitivity_map, pre_normalization_vars

    # No normalization: only report magnitude statistics. A trailing dimension of size 2 is treated as
    # stacked (real, imag); the complex view is used for the statistics only, so the map itself is
    # returned in its original real-valued layout.
    if sensitivity_map.shape[-1] == 2:
        magnitude = torch.abs(torch.view_as_complex(sensitivity_map))
    else:
        magnitude = torch.abs(sensitivity_map)
    pre_normalization_vars = {
        "min": torch.min(magnitude),
        "max": torch.max(magnitude),
        "mean": torch.mean(magnitude),
        "std": torch.std(magnitude),
        "var": torch.var(magnitude),
    }
    return sensitivity_map, pre_normalization_vars
def __initialize_prediction__(
    self,
    prediction: Union[np.ndarray, None],
    kspace: Union[torch.Tensor, List[torch.Tensor]],
    sensitivity_map: torch.Tensor,
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Union[Dict, List[Dict]]]:
    """Predicts a coil-combined image.

    Three cases are handled:

    * ``kspace`` is a list: one coil-combined image is computed per element, regardless of ``prediction``.
    * ``prediction`` is None (as judged by ``is_none``) or has fewer than two dimensions: a single
      coil-combined image is computed from ``kspace``.
    * otherwise: the provided ``prediction`` is used (converted with ``to_tensor`` if it is a NumPy array).

    In every case the image is cropped with ``self.cropping`` and then either normalized with
    ``self.normalization`` or, when no normalization is configured, summarized by its magnitude statistics.

    Parameters
    ----------
    prediction : np.ndarray
        The initial estimation, if None, the prediction is initialized.
    kspace : torch.Tensor or list of torch.Tensor
        The kspace.
    sensitivity_map : torch.Tensor
        The sensitivity map.

    Returns
    -------
    Tuple[Union[List[torch.Tensor], torch.Tensor], Union[Dict, List[Dict]]]
        The initialized prediction, either a list of coil-combined images or a single coil-combined image, and
        the pre-normalization variables (min, max, mean, std, var), as one dictionary per returned image.
    """

    def _coil_combine(y: torch.Tensor) -> torch.Tensor:
        """Transform ``y`` to image space and combine the coils with the configured method."""
        return coil_combination_method_func(
            ifft2(y, self.fft_centered, self.fft_normalization, self.spatial_dims),
            sensitivity_map,
            method=self.coil_combination_method,
            dim=self.coil_dim,
        )

    def _crop_and_normalize(image: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Crop ``image``, then normalize it or, without normalization, collect its magnitude statistics."""
        image = self.cropping(image, apply_forward_transform=self.kspace_crop)  # type: ignore
        if not is_none(self.normalization.__repr__()):
            image, stats = self.normalization(image, apply_forward_transform=self.kspace_crop)  # type: ignore
            return image, stats
        # A trailing dimension of size 2 is treated as stacked (real, imag).
        if image.shape[-1] == 2:
            image = torch.view_as_complex(image)
        magnitude = torch.abs(image)
        stats = {
            "min": torch.min(magnitude),
            "max": torch.max(magnitude),
            "mean": torch.mean(magnitude),
            "std": torch.std(magnitude),
            "var": torch.var(magnitude),
        }
        return image, stats

    def _needs_real_view(image: torch.Tensor) -> bool:
        """Tell whether ``image`` is complex-valued and must be returned as a real view."""
        return image.shape[-1] != 2 and torch.is_complex(image)

    if isinstance(kspace, list):
        predictions: List[torch.Tensor] = []
        pre_normalization_vars_list: List[Dict] = []
        for y in kspace:
            pred, stats = _crop_and_normalize(_coil_combine(y))
            predictions.append(pred)
            pre_normalization_vars_list.append(stats)
        # The first image decides for the whole list; an empty list is returned unchanged instead of
        # failing on the ``[0]`` lookup.
        if predictions and _needs_real_view(predictions[0]):
            predictions = [torch.view_as_real(x) for x in predictions]
        return predictions, pre_normalization_vars_list

    if is_none(prediction) or prediction.ndim < 2:  # type: ignore
        prediction = _coil_combine(kspace)
    elif isinstance(prediction, np.ndarray):
        prediction = to_tensor(prediction)

    prediction, pre_normalization_vars = _crop_and_normalize(prediction)  # type: ignore
    if _needs_real_view(prediction):
        prediction = torch.view_as_real(prediction)
    return prediction, pre_normalization_vars  # type: ignore
def __parse_normalization_vars__(  # noqa: MC0001
    self, kspace_vars, sensitivity_vars, prediction_vars, noise_prediction_vars, target_vars
) -> Dict:
    """
    Parses the normalization variables and returns a unified dictionary.

    Every source is either a dictionary with the keys ``min``, ``max``, ``mean``, ``std`` and ``var``, a list of
    such dictionaries, or None. A dictionary contributes the keys ``"<name>_<stat>"``; a list contributes the
    keys ``"<name>_<stat>_<i>"`` for each list index ``i``. A source that is None, an empty list, or a list
    whose first entry is None contributes nothing.

    Parameters
    ----------
    kspace_vars : Dict
        The kspace normalization variables. Must contain the keys ``masked_kspace_pre_normalization_vars``,
        ``noise_masked_kspace_pre_normalization_vars`` and ``kspace_pre_normalization_vars``.
    sensitivity_vars : Dict
        The sensitivity map normalization variables.
    prediction_vars : Dict
        The prediction normalization variables.
    noise_prediction_vars : Union[Dict, None]
        The noise prediction normalization variables.
    target_vars : Dict
        The target normalization variables.

    Returns
    -------
    Dict
        The normalization variables.
    """
    stat_names = ("min", "max", "mean", "std", "var")

    def _flatten(name: str, stats) -> Dict:
        """Flatten one source (dictionary, list of dictionaries or None) into prefixed keys."""
        if stats is None:
            return {}
        if isinstance(stats, list):
            # As for all list sources, only the first entry is inspected to decide whether the list is set.
            if not stats or stats[0] is None:
                return {}
            return {f"{name}_{stat}_{i}": entry[stat] for i, entry in enumerate(stats) for stat in stat_names}
        return {f"{name}_{stat}": stats[stat] for stat in stat_names}

    # The order below fixes the insertion order of the returned dictionary.
    sources = (
        ("masked_kspace", kspace_vars["masked_kspace_pre_normalization_vars"]),
        ("noise_masked_kspace", kspace_vars["noise_masked_kspace_pre_normalization_vars"]),
        ("kspace", kspace_vars["kspace_pre_normalization_vars"]),
        ("sensitivity_maps", sensitivity_vars),
        ("prediction", prediction_vars),
        ("noise_prediction", noise_prediction_vars),
        ("target", target_vars),
    )

    normalization_vars: Dict = {}
    for name, stats in sources:
        normalization_vars.update(_flatten(name, stats))
    return normalization_vars