# coding=utf-8
__author__ = "Dimitris Karkalousos"
# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
import torch
# Public API of this module: the names exported by ``from <this module> import *``.
# NOTE(review): "rss" and "rss_complex" are exported (and used by coil_combination_method) but their definitions are
# not visible in this view of the file — confirm they are defined in the module.
__all__ = [
"add_coil_dim_if_singlecoil",
"apply_mask",
"batched_mask_center",
"center_crop",
"center_crop_to_smallest",
"check_stacked_complex",
"check_one_hot",
"coil_combination_method",
"complex_abs",
"complex_abs_sq",
"complex_center_crop",
"complex_conj",
"complex_mul",
"crop_to_acs",
"expand_op",
"is_none",
"mask_center",
"normalize_inplace",
"parse_list_and_keep_last",
"reshape_fortran",
"rnn_weights_init",
"rss",
"rss_complex",
"save_predictions",
"sense",
"to_tensor",
"unnormalize",
"zero_nan_inf",
]
def add_coil_dim_if_singlecoil(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Add a dummy coil dimension if the data looks like single-coil data.

    Tensors with 4 or more dimensions are assumed to already carry a coil dimension and are returned unchanged; any
    tensor with fewer dimensions gets a new size-1 axis inserted at ``dim``.

    Parameters
    ----------
    x : torch.Tensor
        The input data.
    dim : int
        The position at which to insert the coil dimension. Default is ``0``.

    Returns
    -------
    torch.Tensor
        ``x`` itself if it has at least 4 dimensions, otherwise ``x`` with a size-1 dimension inserted at ``dim``.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import add_coil_dim_if_singlecoil
    >>> data = torch.rand(10, 10)
    >>> data.shape
    torch.Size([10, 10])
    >>> add_coil_dim_if_singlecoil(data).shape
    torch.Size([1, 10, 10])
    >>> add_coil_dim_if_singlecoil(data, dim=-1).shape
    torch.Size([10, 10, 1])
    """
    if x.dim() >= 4:
        return x
    return torch.unsqueeze(x, dim=dim)
def apply_mask(
    x: torch.Tensor,
    mask_func: Callable,
    seed: Optional[Union[int, Tuple[int, ...]]] = None,
    padding: Optional[Sequence[int]] = None,
    shift: bool = False,
    partial_fourier_percentage: Optional[float] = 0.0,
    center_scale: Optional[float] = 0.02,
    existing_mask: Optional[torch.Tensor] = None,
) -> Tuple[Any, Any, Union[int, float]]:
    """
    Retrospectively accelerate/subsample k-space data by applying a mask to the input data.

    Parameters
    ----------
    x : torch.Tensor
        The input k-space data. This should have at least 3 dimensions, where dimensions -3 and -2 are the spatial
        dimensions, and the final dimension has size 2 (for complex values).
    mask_func : Callable
        Called as ``mask_func(shape, seed, partial_fourier_percentage=..., scale=...)`` and expected to return a
        ``(mask, acceleration)`` pair. ``shape`` is ``x.shape`` as a numpy array with every dimension before the last
        three set to 1. Ignored when ``existing_mask`` is given.
    seed : Optional[Union[int, Tuple[int, ...]]], optional
        Seed for the random number generator. Default is ``None``.
    padding : Optional[Sequence[int]], optional
        ``(left, right)`` padding along dimension -2: mask entries before ``padding[0]`` and from ``padding[1]``
        onwards are set to zero. Default is ``None``.
    shift : bool, optional
        Toggle to ``fftshift`` the mask over dimensions 1 and 2. Applicable on 2D data. Default is ``False``.
    partial_fourier_percentage : Optional[float], optional
        Percentage of kspace to be dropped. Forwarded to ``mask_func``. Default is ``0.0``.
    center_scale : Optional[float], optional
        Scale of the center of the mask, forwarded to ``mask_func`` as ``scale``. Applicable on Gaussian masks.
        Default is ``0.02``.
    existing_mask : Optional[torch.Tensor], optional
        When given, use this mask instead of generating a new one. It is never modified in place. Default is
        ``None``.

    Returns
    -------
    Tuple[Any, Any, Union[int, float]]
        The masked k-space data, the mask, and the acceleration factor. For an ``existing_mask`` the acceleration is
        ``mask.numel() / mask.sum()`` as a Python float; otherwise it is whatever ``mask_func`` returned.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import apply_mask
    >>> kspace = torch.ones(1, 4, 4, 2)
    >>> mask = torch.zeros(1, 4, 4, 1)
    >>> mask[:, :, ::2] = 1
    >>> masked_kspace, subsampling_mask, acceleration = apply_mask(kspace, None, existing_mask=mask)
    >>> masked_kspace.shape
    torch.Size([1, 4, 4, 2])
    >>> acceleration
    2.0
    """
    shape = np.array(x.shape)
    # Collapse every leading dimension to 1 so that the generated mask broadcasts over them.
    shape[:-3] = 1
    if existing_mask is None:
        mask, acc = mask_func(shape, seed, partial_fourier_percentage=partial_fourier_percentage, scale=center_scale)
    else:
        mask = existing_mask
        # ``Tensor.size`` is a method in torch (unlike ``ndarray.size``), so count the elements with ``numel()``.
        acc = (mask.numel() / mask.sum()).item()
    mask = mask.to(x.device)
    if padding is not None and (padding[0] > 0 or padding[1] > 0):
        if mask is existing_mask:
            # ``.to`` returns the same tensor when it is already on the right device; copy it so the in-place
            # zeroing below does not silently modify the caller's mask.
            mask = mask.clone()
        mask[..., : padding[0], :] = 0
        mask[..., padding[1] :, :] = 0  # padding value inclusive on right of zeros
    if shift:
        mask = torch.fft.fftshift(mask, dim=(1, 2))
    masked_x = x * mask + 0.0  # the + 0.0 removes the sign of the zeros
    return masked_x, mask, acc
def batched_mask_center(
    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor, mask_type: str = "2D"
) -> torch.Tensor:
    """
    Initializes a mask with the center filled in. Can operate with different masks for each batch element.

    The output is zero everywhere except inside the window ``[mask_from, mask_to)``, where it is copied from ``x``.
    The window runs along dimension 3 for ``"1D"`` and along dimension 2 for ``"2D"``, the same axes used by
    ``mask_center``.

    Parameters
    ----------
    x : torch.Tensor
        The input image or batch of images, with the batch along dimension 0.
    mask_from : torch.Tensor
        1-D tensor of window starts, either of length 1 (shared by the whole batch) or of length ``x.shape[0]``.
    mask_to : torch.Tensor
        1-D tensor of window ends (exclusive), with the same shape as ``mask_from``.
    mask_type : str, optional
        Type of mask to apply. Can be either ``1D`` or ``2D``. Default is ``2D``.

    Returns
    -------
    torch.Tensor
        The masked image or batch of images with filled center.

    Raises
    ------
    ValueError
        If ``mask_from`` and ``mask_to`` differ in shape, are not 1-D, do not have length 1 or the batch size, or
        if ``mask_type`` is unknown.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import batched_mask_center
    >>> import torch
    >>> data = torch.randn(2, 32, 320, 320)
    >>> batched_mask_center(data, torch.tensor([140]), torch.tensor([180])).shape
    torch.Size([2, 32, 320, 320])
    >>> batched_mask_center(data, torch.tensor([140, 150]), torch.tensor([180, 170])).shape
    torch.Size([2, 32, 320, 320])
    """
    if mask_from.shape != mask_to.shape:
        raise ValueError("mask_from and mask_to must match shapes.")
    if mask_from.ndim != 1:
        raise ValueError("mask_from and mask_to must have 1 dimension.")
    # The shapes already match, so checking mask_from is enough. Length 1 means "broadcast over the batch".
    if mask_from.shape[0] not in (1, x.shape[0]):
        raise ValueError("mask_from and mask_to must have batch_size length.")

    # Number of whole axes between the batch axis and the filled axis ("1D" -> dim 3, "2D" -> dim 2).
    if mask_type == "1D":
        n_full_axes = 2
    elif mask_type == "2D":
        n_full_axes = 1
    else:
        raise ValueError(f"Unknown mask type {mask_type}")

    if mask_from.shape[0] == 1:
        # One window shared by every batch element: fill the whole batch at once.
        windows = [(slice(None), int(mask_from[0]), int(mask_to[0]))]
    else:
        windows = [(i, int(start), int(end)) for i, (start, end) in enumerate(zip(mask_from, mask_to))]

    mask = torch.zeros_like(x)
    for batch_index, start, end in windows:
        index = (batch_index,) + (slice(None),) * n_full_axes + (slice(start, end),)
        mask[index] = x[index]
    return mask
def center_crop(x: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input complex image or batch of complex images or real image or batch of real images
    without a complex dimension.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor (or numpy array) to be center cropped. It should have at least 2 dimensions, and the
        cropping is applied along the last two dimensions.
    shape : Tuple[int, int]
        The output shape. Each entry must be positive and no larger than the corresponding dimension of ``x``.

    Returns
    -------
    torch.Tensor
        The center cropped image or batch of images (a view/slice of ``x``). When a dimension cannot be split evenly,
        the extra element is dropped on the trailing side.

    Raises
    ------
    ValueError
        If ``shape`` is not positive or larger than the last two dimensions of ``x``.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import center_crop
    >>> import torch
    >>> data = torch.arange(12).reshape(1, 3, 4)
    >>> center_crop(data, (1, 2))
    tensor([[[5, 6]]])
    >>> center_crop(data, (1, 2)).shape
    torch.Size([1, 1, 2])
    """
    if not (0 < shape[0] <= x.shape[-2] and 0 < shape[1] <= x.shape[-1]):
        raise ValueError("Invalid shapes.")
    # Both operands are non-negative integers, so floor division equals truncation; no tensors are needed here.
    w_from = (x.shape[-2] - shape[0]) // 2
    h_from = (x.shape[-1] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]
    return x[..., w_from:w_to, h_from:h_to]
def center_crop_to_smallest(
    x: Union[torch.Tensor, np.ndarray], y: Union[torch.Tensor, np.ndarray]
) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]:
    """
    Center crop both inputs to the smaller of their last two dimensions.

    The minimum is taken independently over dim=-2 and dim=-1, so if ``x`` is smaller along one of them and ``y``
    along the other, both outputs end up with a mixture of the two sizes.

    Parameters
    ----------
    x : torch.Tensor or np.ndarray
        The first image.
    y : torch.Tensor or np.ndarray
        The second image.

    Returns
    -------
    Tuple[torch.Tensor or np.ndarray, torch.Tensor or np.ndarray]
        ``x`` and ``y``, each center cropped (with ``center_crop``) to the common minimum size.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import center_crop_to_smallest
    >>> import torch
    >>> a = torch.zeros(2, 4, 6)
    >>> b = torch.zeros(2, 5, 3)
    >>> a_cropped, b_cropped = center_crop_to_smallest(a, b)
    >>> a_cropped.shape, b_cropped.shape
    (torch.Size([2, 4, 3]), torch.Size([2, 4, 3]))
    """
    target_shape = (
        min(x.shape[-2], y.shape[-2]),
        min(x.shape[-1], y.shape[-1]),
    )
    cropped_x = center_crop(x, target_shape)
    cropped_y = center_crop(y, target_shape)
    return cropped_x, cropped_y
def check_stacked_complex(x: torch.Tensor) -> torch.Tensor:
    """
    Check if tensor is stacked complex (real & imaginary parts stacked along last dim) and convert it to a combined
    complex tensor.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to check. A real tensor whose last dimension has size 2 is treated as stacked complex.

    Returns
    -------
    torch.Tensor
        A native complex tensor for stacked input; any other tensor (including tensors that are already complex) is
        returned unchanged.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import check_stacked_complex
    >>> import torch
    >>> data = torch.view_as_real(torch.tensor([1+1j, 2+2j, 3+3j]))
    >>> data.shape
    torch.Size([3, 2])
    >>> check_stacked_complex(data)
    tensor([1.+1.j, 2.+2.j, 3.+3.j])
    >>> data = torch.tensor([1+1j, 2+2j, 3+3j])
    >>> check_stacked_complex(data) is data
    True
    """
    # Native complex tensors (even with a last dim of 2) and 0-dim tensors are not stacked complex.
    if x.ndim > 0 and x.shape[-1] == 2 and not torch.is_complex(x):
        # view_as_complex requires unit stride on the last dim; contiguous() is a no-op for contiguous input.
        return torch.view_as_complex(x.contiguous())
    return x
def check_one_hot(x: torch.Tensor) -> bool:
    """
    Return whether a tensor is one-hot encoded along some dimension.

    A tensor qualifies when every entry is 0 or 1 and there is at least one dimension along which every slice sums
    to exactly 1. Tensors with fewer than two dimensions get a leading size-1 dimension first.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to check.

    Returns
    -------
    bool
        ``True`` if ``x`` is one-hot along at least one dimension, ``False`` otherwise.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import check_one_hot
    >>> import torch
    >>> check_one_hot(torch.Tensor([[1, 0], [0, 1]]))
    True
    >>> check_one_hot(torch.Tensor([[1, 1], [0, 1]]))
    False
    >>> check_one_hot(torch.Tensor([[1, 2], [0, 0]]))
    False
    >>> check_one_hot(torch.Tensor([[1, 1], [0, 0]]))
    True
    """
    if x.ndimension() < 2:
        x = x.unsqueeze(0)
    # Being binary does not depend on which axis is the class axis, so test it once.
    if not torch.all((x == 0) | (x == 1)):
        return False
    for class_dim in range(x.ndimension()):
        # Summing along the candidate class axis must give exactly one "1" per position.
        if torch.all(x.sum(dim=class_dim) == 1):
            return True
    return False
def coil_combination_method(
    x: torch.Tensor, sensitivity_maps: torch.Tensor, method: str = "SENSE", dim: int = 0
) -> torch.Tensor:
    """
    Coil-combine ``x`` with the selected method.

    Parameters
    ----------
    x : torch.Tensor
        The tensor to coil-combine.
    sensitivity_maps : torch.Tensor
        The coil sensitivity maps. Only used by ``"SENSE"``.
    method : str, optional
        One of ``"SENSE"`` (calls ``sense``), ``"RSS"`` (calls ``rss``) or ``"RSS_COMPLEX"`` (calls
        ``rss_complex``). Default is ``"SENSE"``.
    dim : int, optional
        The dimension to coil-combine along. Default is ``0``.

    Returns
    -------
    torch.Tensor
        Coil-combined tensor with the selected method applied.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import coil_combination_method
    >>> import torch
    >>> data = torch.tensor([[[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]], \
    [[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]]])
    >>> combined = coil_combination_method(data, data, method="SENSE")
    >>> combined.shape
    torch.Size([2, 3, 2])
    >>> combined[0, 0]
    tensor([4., 0.])
    """
    if method not in ("SENSE", "RSS", "RSS_COMPLEX"):
        raise ValueError("Output type not supported.")
    if method == "RSS":
        return rss(x, dim)
    if method == "RSS_COMPLEX":
        return rss_complex(x, dim)
    return sense(x, sensitivity_maps, dim)
def complex_abs(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the absolute value (magnitude) of a complex valued input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Either a native complex tensor, or a real tensor whose last dimension of size 2 stores the real and
        imaginary parts.

    Returns
    -------
    torch.Tensor
        ``sqrt(re**2 + im**2)`` per element. For stacked input the last dimension is reduced away.

    Raises
    ------
    ValueError
        If ``x`` is real and its last dimension does not have size 2.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import complex_abs
    >>> import torch
    >>> data = torch.tensor([1+1j, 2+2j, 3+3j])
    >>> complex_abs(data)
    tensor([1.4142, 2.8284, 4.2426])
    """
    # Test for native complex first: a complex tensor whose last dim happens to be 2 must not be mistaken for
    # stacked real/imaginary data.
    if torch.is_complex(x):
        x = torch.view_as_real(x)
    elif x.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    return (x**2).sum(dim=-1).sqrt()
def complex_abs_sq(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the squared absolute value of a complex tensor.

    Parameters
    ----------
    x : torch.Tensor
        Either a native complex tensor, or a real tensor whose last dimension of size 2 stores the real and
        imaginary parts.

    Returns
    -------
    torch.Tensor
        ``re**2 + im**2`` per element (no square root). For stacked input the last dimension is reduced away.

    Raises
    ------
    ValueError
        If ``x`` is real and its last dimension does not have size 2.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import complex_abs_sq
    >>> import torch
    >>> data = torch.tensor([1+1j, 2+2j, 3+3j])
    >>> complex_abs_sq(data)
    tensor([ 2.,  8., 18.])
    """
    # Test for native complex first so a complex tensor with a last dim of 2 is not treated as stacked.
    if torch.is_complex(x):
        x = torch.view_as_real(x)
    elif x.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    return (x**2).sum(dim=-1)
def complex_center_crop(x: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input image or batch of complex images stored with a trailing complex dimension.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor to be center cropped. It should have at least 3 dimensions. The cropping is applied along
        dimensions -3 and -2, and the last (complex) dimension is kept whole.
    shape : Tuple[int, int]
        The output shape. Each entry must be positive and no larger than the corresponding dimension of ``x``.

    Returns
    -------
    torch.Tensor
        The complex center cropped image or batch of images. When a dimension cannot be split evenly, the extra
        element is dropped on the trailing side.

    Raises
    ------
    ValueError
        If ``shape`` is not positive or larger than dimensions -3/-2 of ``x``.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import complex_center_crop
    >>> import torch
    >>> data = torch.tensor([[[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]], \
    [[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]]])
    >>> data.shape
    torch.Size([2, 2, 3, 2])
    >>> complex_center_crop(data, (1, 2)).shape
    torch.Size([2, 1, 2, 2])
    >>> complex_center_crop(data, (1, 2))[0, 0]
    tensor([[1., 1.],
            [2., 2.]])
    """
    if not (0 < shape[0] <= x.shape[-3] and 0 < shape[1] <= x.shape[-2]):
        raise ValueError("Invalid shapes.")
    # Non-negative integer operands: floor division equals truncation, so no tensors are needed for the bounds.
    w_from = (x.shape[-3] - shape[0]) // 2
    h_from = (x.shape[-2] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]
    return x[..., w_from:w_to, h_from:h_to, :]
def complex_conj(x: torch.Tensor) -> torch.Tensor:
    """
    Complex conjugate.

    Native complex tensors are conjugated directly. Real tensors are assumed to store the real and imaginary parts
    in a last dimension of size 2, and the imaginary part is negated.

    Parameters
    ----------
    x : torch.Tensor
        Complex tensor to apply the complex conjugate to: either native complex, or real with a last dimension of
        size 2.

    Returns
    -------
    torch.Tensor
        Result of complex conjugate, in the same representation as the input.

    Raises
    ------
    ValueError
        If ``x`` is real and its last dimension does not have size 2.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import complex_conj
    >>> import torch
    >>> complex_conj(torch.tensor([1+1j, 2+2j, 3+3j]))
    tensor([1.-1.j, 2.-2.j, 3.-3.j])
    >>> complex_conj(torch.tensor([[1., 2.], [3., -4.]]))
    tensor([[ 1., -2.],
            [ 3.,  4.]])
    """
    if torch.is_complex(x):
        # conj_physical materializes the result instead of returning a lazily-conjugated view of ``x``.
        return torch.conj_physical(x)
    if x.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Complex multiplication.

    Both tensors must use the same representation: either both native complex, or both real with the real and
    imaginary parts stored in a last dimension of size 2. Standard broadcasting applies.

    Parameters
    ----------
    x : torch.Tensor
        First complex tensor to multiply.
    y : torch.Tensor
        Second complex tensor to multiply.

    Returns
    -------
    torch.Tensor
        Result of complex multiplication, in the same representation as the inputs.

    Raises
    ------
    ValueError
        If one tensor is native complex and the other is not, or if real inputs do not both have a last dimension
        of size 2.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import complex_mul
    >>> import torch
    >>> complex_mul(torch.tensor([1+1j, 2+2j, 3+3j]), torch.tensor([4+4j, 5+5j, 6+6j]))
    tensor([0.+8.j, 0.+20.j, 0.+36.j])
    >>> complex_mul(torch.tensor([[1., 2.]]), torch.tensor([[3., 4.]]))
    tensor([[-5., 10.]])
    """
    x_is_complex = torch.is_complex(x)
    if x_is_complex != torch.is_complex(y):
        raise ValueError("Both tensors must be either native complex or stacked real/imaginary.")
    if x_is_complex:
        return x * y
    if not x.shape[-1] == y.shape[-1] == 2:
        raise ValueError("Tensors do not have separate complex dim.")
    re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
    im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
    return torch.stack((re, im), dim=-1)
def crop_to_acs(acs_mask: torch.Tensor, kspace: torch.Tensor) -> torch.Tensor:
    r"""Crops k-space to the autocalibration region given the acs_mask.

    The crop is the bounding box of the non-zero entries of ``acs_mask``.

    Parameters
    ----------
    acs_mask : torch.Tensor
        Autocalibration mask of shape (height, width). It must contain at least one non-zero entry, otherwise the
        ``min``/``max`` reductions below raise.
    kspace : torch.Tensor
        K-space of shape (coil, height, width, *). Any number of trailing dimensions (including none) is kept whole.

    Returns
    -------
    torch.Tensor
        Cropped k-space of shape (coil, height', width', *), where height' and width' are the extent of the
        bounding box of ``acs_mask``.
    """
    nonzero_idxs = torch.nonzero(acs_mask)
    x, y = nonzero_idxs[..., 0], nonzero_idxs[..., 1]
    xl, xr = x.min(), x.max()
    yl, yr = y.min(), y.max()
    # Trailing dimensions are left implicit, so native-complex (coil, h, w) k-space works as well.
    return kspace[:, xl : xr + 1, yl : yr + 1]
def expand_op(x: torch.Tensor, sensitivity_maps: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Expand a coil-combined image to a multi-coil image by multiplying it with each coil sensitivity map.

    Parameters
    ----------
    x : torch.Tensor
        The coil-combined image.
    sensitivity_maps : torch.Tensor
        The sensitivity maps.
    dim : int
        The coil dimension to insert into ``x`` before broadcasting against ``sensitivity_maps``. Default is ``1``.

    Returns
    -------
    torch.Tensor
        The multi-coil image. When both inputs are real tensors whose last dimension has size 2 (stacked
        real/imaginary), a true complex product is computed with ``complex_mul``. Otherwise a plain broadcast ``*``
        is used, which is already the complex product for native complex tensors.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import expand_op
    >>> data = torch.rand(1, 200, 200, 2)
    >>> sens = torch.rand(1, 30, 200, 200, 2)
    >>> expand_op(data, sens).shape
    torch.Size([1, 30, 200, 200, 2])
    """
    x = torch.unsqueeze(x, dim=dim)
    stacked_complex = (
        not torch.is_complex(x)
        and not torch.is_complex(sensitivity_maps)
        and x.shape[-1] == sensitivity_maps.shape[-1] == 2
    )
    if stacked_complex:
        # An element-wise ``*`` on stacked tensors would give (a*c, b*d), not the complex product.
        return complex_mul(x, sensitivity_maps)
    return x * sensitivity_maps
def is_none(x: Union[Any, None]) -> bool:
    """
    Check if input is None or a textual "None".

    Besides ``None`` itself, any value whose string form contains the word "none" (case-insensitive, as a whole
    word) counts, so ``"None"``, ``"none"`` and wrapped forms such as ``['None']`` return ``True``. Strings that
    merely contain the letters, such as ``"nonempty"``, do not.

    Parameters
    ----------
    x : Union[Any, None]
        Input to check.

    Returns
    -------
    bool
        True if x is None or textually "None", False otherwise.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import is_none
    >>> is_none(None)
    True
    >>> is_none("None")
    True
    >>> is_none(["None"])
    True
    >>> is_none("nonempty")
    False
    """
    if x is None:
        return True
    # Word boundaries stop prefixes such as "nonexistent"/"nonempty" from matching, as a plain substring test did.
    return re.search(r"\bnone\b", str(x).lower()) is not None
def mask_center(
    x: torch.Tensor, mask_from: Optional[int], mask_to: Optional[int], mask_type: str = "2D"
) -> torch.Tensor:
    """
    Keep only a central window of the input and zero everything else.

    The window ``[mask_from, mask_to)`` is taken along dimension 3 for ``"1D"`` and along dimension 2 for ``"2D"``.
    The result is a new tensor; ``x`` is not modified.

    Parameters
    ----------
    x : torch.Tensor
        The input image or batch of images. It needs at least 4 dimensions for ``"1D"`` and 3 for ``"2D"``.
    mask_from : Optional[int]
        Start of the window. If a list is given, its first element is used.
    mask_to : Optional[int]
        End of the window (exclusive). If a list is given, its first element is used.
    mask_type : str, optional
        Type of mask to apply. Can be either ``1D`` or ``2D``. Default is ``2D``.

    Returns
    -------
    torch.Tensor
        A tensor shaped like ``x``, equal to ``x`` inside the window and zero elsewhere.

    Raises
    ------
    ValueError
        If ``mask_type`` is neither ``"1D"`` nor ``"2D"``.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import mask_center
    >>> import torch
    >>> data = torch.ones(1, 2, 3, 2)
    >>> mask_center(data, 1, 2)[0, 0]
    tensor([[0., 0.],
            [1., 1.],
            [0., 0.]])
    """
    filled = torch.zeros_like(x)
    start = mask_from[0] if isinstance(mask_from, list) else mask_from
    stop = mask_to[0] if isinstance(mask_to, list) else mask_to
    window = slice(start, stop)
    if mask_type == "1D":
        region = (slice(None), slice(None), slice(None), window)
    elif mask_type == "2D":
        region = (slice(None), slice(None), window)
    else:
        raise ValueError(f"Unknown mask type {mask_type}")
    filled[region] = x[region]
    return filled
def normalize_inplace(x: torch.Tensor, normalization_type: str = "max") -> torch.Tensor:
    """
    Normalize the input data using statistics of the whole tensor at once. This is different from the
    ``Normalizer`` transformation that normalizes the data in a non batch-wise manner.

    Despite the name, a new tensor is returned and ``x`` is not modified. All statistics are computed on
    ``torch.abs(x)``, so complex inputs are normalized by their magnitude.

    Parameters
    ----------
    x : torch.Tensor
        The input data.
    normalization_type : str
        One of ``"max"``, ``"minmax"``, ``"mean_std"``, ``"mean_var"`` or ``"grayscale"``. Any other value returns
        ``x`` unchanged. Default is ``"max"``.

    Returns
    -------
    torch.Tensor
        The normalized data.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import normalize_inplace
    >>> normalize_inplace(torch.tensor([1., 2., 4.]))
    tensor([0.2500, 0.5000, 1.0000])
    >>> normalize_inplace(torch.tensor([1., 2., 4.]), "minmax")
    tensor([0.0000, 0.3333, 1.0000])
    """
    if normalization_type == "max":
        return x / torch.max(torch.abs(x))
    if normalization_type == "minmax":
        magnitude = torch.abs(x)
        min_value = torch.min(magnitude)
        return (x - min_value) / (torch.max(magnitude) - min_value)
    if normalization_type in ("mean_std", "mean_var"):
        # Mean and spread both come from the original data. Previously the spread was measured on the mean-shifted
        # data, which under abs() is not the data's std/var and does not match what ``unnormalize`` reverses.
        magnitude = torch.abs(x)
        spread = torch.std(magnitude) if normalization_type == "mean_std" else torch.var(magnitude)
        return (x - torch.mean(magnitude)) / spread
    if normalization_type == "grayscale":
        # Shift so the minimum is zero, then scale by the shifted maximum so values span [0, 255].
        x = x - torch.min(torch.abs(x))
        x = x / torch.max(torch.abs(x))
        return x * 255.0
    return x
def parse_list_and_keep_last(x: Union[Any, List[Any]]) -> Any:
    """
    Repeatedly take the last element of (nested) lists until a non-list value is reached.

    Parameters
    ----------
    x : Union[Any, List[Any]]
        A value or an arbitrarily nested list. Only ``list`` instances are unwrapped; tuples and other sequences are
        returned as-is.

    Returns
    -------
    Any
        The innermost last element, or ``x`` itself if it is not a list.

    Raises
    ------
    IndexError
        If an empty list is encountered while unwrapping.

    Examples
    --------
    >>> parse_list_and_keep_last([1, [2, [3, 4]]])
    4
    >>> parse_list_and_keep_last(5)
    5
    """
    while isinstance(x, list):
        x = x[-1]
    return x
def reshape_fortran(x, shape) -> torch.Tensor:
    """
    Reshape a tensor using Fortran (column-major) element order, like ``numpy.reshape(..., order="F")``.

    Implemented by reversing the axes, reshaping to the reversed target shape in torch's native row-major order, and
    reversing the axes again. Adapted from https://stackoverflow.com/a/63964246

    Parameters
    ----------
    x : torch.Tensor
        Input tensor to be reshaped.
    shape : Sequence[int]
        Shape to reshape the tensor to.

    Returns
    -------
    torch.Tensor
        Reshaped tensor.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import reshape_fortran
    >>> import torch
    >>> data = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
    >>> reshape_fortran(data, (3, 2, 2))
    tensor([[[ 1,  5],
             [10,  9]],
    <BLANKLINE>
            [[ 7, 11],
             [ 2,  6]],
    <BLANKLINE>
            [[ 4,  3],
             [ 8, 12]]])
    >>> reshape_fortran(data, (3, 2, 2)).shape
    torch.Size([3, 2, 2])
    """
    input_axes_reversed = tuple(range(x.dim() - 1, -1, -1))
    output_axes_reversed = tuple(range(len(shape) - 1, -1, -1))
    row_major = x.permute(input_axes_reversed)
    reshaped = row_major.reshape(tuple(shape[::-1]))
    return reshaped.permute(output_axes_reversed)
def rnn_weights_init(module: torch.nn.Module, std_init_range: float = 0.02, xavier: bool = True) -> None:
    r"""Initialize weights in Recurrent Neural Network.

    Intended for ``module.apply(rnn_weights_init)``. Only ``Linear``, ``Embedding`` and ``LayerNorm`` submodules are
    touched; the parameters of other modules (e.g. the recurrent weights of ``torch.nn.GRU`` itself) are left as-is.

    Parameters
    ----------
    module : torch.nn.Module
        Module to initialize, in place.
    std_init_range : float
        Standard deviation of normal initializer. Default is ``0.02``.
    xavier : bool
        If True, xavier initializer will be used in Linear layers as in [Vaswani2017]_. Otherwise, normal initializer
        will be used. Default is ``True``.

    References
    ----------
    .. [Vaswani2017] Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention
        is all you need. Advances in neural information processing systems. 2017;30.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import rnn_weights_init
    >>> rnn = torch.nn.GRU(10, 20, 2)
    >>> rnn.apply(rnn_weights_init)
    GRU(10, 20, num_layers=2)
    """
    if isinstance(module, torch.nn.Linear):
        if xavier:
            torch.nn.init.xavier_uniform_(module.weight)
        else:
            torch.nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, 0.0)
    elif isinstance(module, torch.nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
    elif isinstance(module, torch.nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) (and bias=False on newer torch) stores these parameters as None.
        if module.weight is not None:
            torch.nn.init.constant_(module.weight, 1.0)
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, 0.0)
def save_predictions(
    predictions: Dict[str, np.ndarray],
    out_dir: Union[str, Path],
    key: str = "reconstructions",
    file_format: str = "h5",
) -> None:
    """
    Save predictions to selected format.

    Each entry is written to its own HDF5 file ``out_dir / fname``, overwriting any existing file, with the array
    stored under the dataset name ``key``.

    Parameters
    ----------
    predictions : Dict[str, np.ndarray]
        A dictionary mapping input filenames to corresponding predictions.
    out_dir : Union[str, Path]
        The output directory to save the predictions to. It is created, including parents, if missing.
    key : str, optional
        The key to save the predictions under. Default is ``reconstructions``.
    file_format : str, optional
        The format to save the predictions in. Only ``h5`` is supported. Default is ``h5``.

    Raises
    ------
    ValueError
        If ``file_format`` is not ``h5``.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import save_predictions
    >>> import numpy as np
    >>> from pathlib import Path
    >>> data = {"test.h5": np.array([[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]])}
    >>> data["test.h5"].shape
    (2, 3, 2)
    >>> save_predictions(data, Path("predictions"), key="reconstructions", file_format="h5")
    >>> save_predictions(data, "predictions", key="segmentations", file_format="h5")
    """
    if file_format != "h5":
        raise ValueError(f"Output format {file_format} is not supported.")
    # Accept plain strings too; ``str`` has no ``mkdir`` and does not support the ``/`` join below.
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    for fname, preds in predictions.items():
        with h5py.File(out_dir / fname, "w") as hf:
            hf.create_dataset(key, data=preds)
def sense(x: torch.Tensor, sensitivity_maps: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Coil-combination according to the SENSitivity Encoding (SENSE) method [Pruessmann1999]_.

    Each coil image is multiplied by the complex conjugate of its sensitivity map (using ``complex_mul`` and
    ``complex_conj``), and the products are summed over the coil dimension.

    Parameters
    ----------
    x : torch.Tensor
        The tensor to coil-combine.
    sensitivity_maps : torch.Tensor
        The coil sensitivity maps.
    dim : int, optional
        The dimension to coil-combine along. Default is ``0``.

    Returns
    -------
    torch.Tensor
        Coil-combined tensor with SENSE applied.

    References
    ----------
    .. [Pruessmann1999] Pruessmann KP, Weiger M, Scheidegger MB, Boesiger P. SENSE: Sensitivity encoding for fast MRI.
        Magn Reson Med 1999; 42:952-962.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import sense
    >>> import torch
    >>> data = torch.tensor([[[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]], \
    [[[1., 1.], [2., 2.], [3., 3.]], [[1., 1.], [2., 2.], [3., 3.]]]])
    >>> combined = sense(data, data)
    >>> combined.shape
    torch.Size([2, 3, 2])
    >>> combined[0, 0]
    tensor([4., 0.])
    """
    conjugated_maps = complex_conj(sensitivity_maps)
    per_coil = complex_mul(x, conjugated_maps)
    return per_coil.sum(dim)
def to_tensor(x: np.ndarray) -> torch.Tensor:
    """
    Converts a numpy array to a torch tensor. For complex arrays, the real and imaginary parts are stacked along the
    last dimension.

    Real arrays with non-negative strides share memory with the returned tensor (via ``torch.from_numpy``).
    Negatively-strided arrays (e.g. from ``np.flip`` or ``a[::-1]``) are copied first, because ``torch.from_numpy``
    cannot wrap them. Array-likes are converted with ``np.asarray``.

    Parameters
    ----------
    x : np.ndarray
        Input numpy array (or array-like) to be converted to torch.

    Returns
    -------
    torch.Tensor
        Torch tensor version of input.

    Examples
    --------
    >>> from atommic.collections.common.parts.utils import to_tensor
    >>> import numpy as np
    >>> data = np.array([[1+1j, 2+2j, 3+3j], [4+4j, 5+5j, 6+6j]])
    >>> data.shape
    (2, 3)
    >>> to_tensor(data).shape
    torch.Size([2, 3, 2])
    >>> to_tensor(np.arange(3.0)[::-1])
    tensor([2., 1., 0.], dtype=torch.float64)
    """
    x = np.asarray(x)
    if np.iscomplexobj(x):
        # np.stack always allocates a fresh, positively-strided array.
        x = np.stack((x.real, x.imag), axis=-1)
    elif any(stride < 0 for stride in x.strides):
        x = x.copy()
    return torch.from_numpy(x)
def unnormalize(x: torch.Tensor, attrs: Dict, normalization_type: str = "max") -> torch.Tensor:
    """
    Undo a normalization using the statistics stored in ``attrs``.

    Parameters
    ----------
    x : torch.Tensor
        The normalized data.
    attrs : Dict
        Normalization statistics. Read keys: ``"max"`` (max), ``"max"``/``"min"`` (minmax), ``"std"``/``"mean"``
        (mean_std), ``"var"``/``"mean"`` (mean_var). ``"grayscale"`` needs none.
    normalization_type : str
        One of ``"max"``, ``"minmax"``, ``"mean_std"``, ``"mean_var"`` or ``"grayscale"``. Any other value returns
        ``x`` unchanged. Default is ``"max"``.

    Returns
    -------
    torch.Tensor
        The unnormalized data.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.utils import unnormalize
    >>> unnormalize(torch.tensor([0.5]), {"max": 4.0})
    tensor([2.])
    """
    # Each entry lazily rebuilds the original scale, so only the keys of the selected type are read.
    restorers = (
        ("max", lambda: x * attrs["max"]),
        ("minmax", lambda: x * (attrs["max"] - attrs["min"]) + attrs["min"]),
        ("mean_std", lambda: x * attrs["std"] + attrs["mean"]),
        ("mean_var", lambda: x * attrs["var"] + attrs["mean"]),
        ("grayscale", lambda: x / 255.0),
    )
    for name, restore in restorers:
        if normalization_type == name:
            return restore()
    return x
def zero_nan_inf(x: torch.Tensor) -> torch.Tensor:
    """
    Replace a tensor containing any NaN or infinite element by a scalar zero.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor (typically a loss or metric value).

    Returns
    -------
    torch.Tensor
        ``x`` unchanged if all its elements are finite. Otherwise a 0-dim zero tensor with the same dtype and on the
        same device as ``x``, so it can be combined with other tensors on that device.
    """
    if torch.isnan(x).any() or torch.isinf(x).any():
        return torch.zeros((), dtype=x.dtype, device=x.device)
    return x