Source code for atommic.collections.common.parts.coil_sensitivity_maps

# coding=utf-8
__author__ = "Dimitris Karkalousos"

# Taken and adapted from: https://github.com/NKI-AI/direct/blob/main/direct/algorithms/mri_algorithms.py

from abc import ABC
from typing import Callable, Optional, Sequence

import numpy as np
import torch

from atommic.collections.common.parts.fft import ifft2
from atommic.collections.common.parts.utils import crop_to_acs


class MaximumEigenvaluePowerMethod(ABC):
    """Estimate the largest eigenvalue of an operator with the Power Method.

    Starting from an initial vector, the method repeatedly applies the forward operator and rescales the result.
    After ``max_iter`` iterations, ``x`` holds the current eigenvector estimate and ``max_eig`` the current
    estimate of the largest eigenvalue.

    When no ``norm_func`` is given, the eigenvalue estimate is the Rayleigh quotient
    ``sum(A(x) * conj(x)) / sum(x * conj(x))`` taken over all elements of ``x``. Otherwise it is
    ``norm_func(A(x))``, which allows e.g. a per-pixel norm for batched problems.
    """

    def __init__(
        self,
        forward_operator: Callable,
        norm_func: Optional[Callable] = None,
        max_iter: int = 30,
    ):
        """Inits :class:`MaximumEigenvaluePowerMethod`.

        Parameters
        ----------
        forward_operator : Callable
            The forward operator for the problem. Called as ``forward_operator(x)``.
        norm_func : Callable, optional
            An optional function for normalizing the eigenvector. It is called on ``forward_operator(x)`` and its
            result is used both as the eigenvalue estimate and as the divisor for the next iterate.
            Default is ``None``.
        max_iter : int, optional
            Maximum number of iterations to run the algorithm. Default is ``30``.
        """
        super().__init__()
        self.forward_operator = forward_operator
        self.norm_func = norm_func
        self.max_iter = max_iter
        self.iter = 0
        # Populated by `fit` / `update`; declared here so the attributes always exist.
        self.x: Optional[torch.Tensor] = None
        self.max_eig: Optional[torch.Tensor] = None

    def _update(self) -> None:
        """Perform a single update step of the maximum eigenvalue guess and corresponding eigenvector."""
        y = self.forward_operator(self.x)
        if self.norm_func is None:
            # Rayleigh quotient over all elements of the current iterate.
            self.max_eig = (y * self.x.conj()).sum() / (self.x * self.x.conj()).sum()  # type: ignore
        else:
            self.max_eig = self.norm_func(y)
        self.x = y / self.max_eig

    def _done(self) -> bool:
        """Check if the iteration budget has been used up."""
        return self.iter >= self.max_iter

    def _fit(self, x: torch.Tensor) -> None:
        """Sets initial maximum eigenvector guess."""
        self.x = x

    def update(self) -> None:
        """Update the algorithm's parameters and increment the iteration count."""
        self._update()
        self.iter += 1

    def done(self) -> bool:
        """Check if the algorithm has finished, i.e. ``max_iter`` iterations have been performed."""
        return self._done()

    def fit(self, *args, **kwargs) -> None:
        """Fit the algorithm.

        The iteration counter is reset on every call, so the same instance can be fitted more than once.

        Parameters
        ----------
        *args : tuple
            Tuple of arguments for `_fit` method.
        **kwargs : dict
            Keyword arguments for `_fit` method.
        """
        # Without this reset a second call would see `done()` immediately and silently skip all iterations.
        self.iter = 0
        self._fit(*args, **kwargs)
        while not self.done():
            self.update()
class EspiritCalibration(ABC, torch.nn.Module):
    """Estimates coil sensitivity maps with the ESPIRiT calibration method as described in [1]_.

    The implementation is adapted from the ESPIRiT calibration in [2]_.

    References
    ----------
    .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue
        approach to autocalibrating parallel MRI: where SENSE meets GRAPPA. Magn Reson Med. 2014 Mar;71(3):990-1001.
        doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.
    .. [2] https://github.com/mikgroup/sigpy/blob/1817ff849d34d7cbbbcb503a1b310e7d8f95c242/sigpy/mri/app.py#L388-L491
    """

    def __init__(
        self,
        threshold: float = 0.05,
        kernel_size: int = 6,
        crop: float = 0.95,
        max_iter: int = 100,
        fft_centered: bool = False,
        fft_normalization: str = "backward",
        spatial_dims: Sequence[int] = (-2, -1),
    ):
        """Inits :class:`EspiritCalibration`.

        Parameters
        ----------
        threshold : float, optional
            Threshold for the calibration matrix. Singular vectors whose singular value is not larger than
            ``threshold`` times the largest singular value are discarded. Default is ``0.05``.
        kernel_size : int, optional
            Kernel size for the calibration matrix. Default is ``6``.
        crop : float, optional
            Output eigenvalue cropping threshold. Pixels whose largest eigenvalue is not above ``crop`` are set to
            zero. Default is ``0.95``.
        max_iter : int, optional
            Power method iterations. Default is ``100``.
        fft_centered : bool, optional
            Whether to center the FFT. Default is ``False``.
        fft_normalization : str, optional
            Normalization to apply to the FFT. Default is ``"backward"``.
        spatial_dims : Sequence[int], optional
            Spatial dimensions of the input. Default is ``(-2, -1)``.
        """
        super().__init__()
        self.threshold = threshold
        self.kernel_size = kernel_size
        self.crop = crop
        self.max_iter = max_iter
        self.fft_centered = fft_centered
        self.fft_normalization = fft_normalization
        self.spatial_dims = spatial_dims

    def calculate_sensitivity_map(self, acs_mask: torch.Tensor, kspace: torch.Tensor) -> torch.Tensor:
        """Calculates sensitivity map given as input the `acs_mask` and the `k-space`.

        Parameters
        ----------
        acs_mask : torch.Tensor
            Autocalibration mask. It is squeezed and handed to ``crop_to_acs`` together with the k-space.
        kspace : torch.Tensor
            K-space. The indexing below treats the first axis as coils, the last axis (size 2) as real/imaginary
            parts and the two axes in between as spatial axes; presumably ``(coils, height, width, 2)``.

        Returns
        -------
        sensitivity_map : torch.Tensor
            Coil sensitivity maps with the same shape, dtype and device as ``kspace``. Coils that are treated as
            padding are left at zero.
        """
        ndim = kspace.ndim - 2
        spatial_size = tuple(kspace.shape[1:-1])

        # Used in case the k-space is padded (e.g. for batches): coils whose entries sum to zero are skipped.
        non_padded_dim = kspace.sum(dim=tuple(range(1, kspace.ndim))).bool()
        # Plain int so that it can be used safely in shapes below.
        num_coils = int(non_padded_dim.sum())
        patch_numel = self.kernel_size**ndim

        acs_kspace_cropped = torch.view_as_complex(crop_to_acs(acs_mask.squeeze(), kspace[non_padded_dim]))

        # Get calibration matrix. With the coils as the channel axis, `unfold` returns
        # (1, num_coils * patch_numel, num_patches), where every column is one sliding window across all coils.
        # After the transpose each row is therefore one calibration-matrix row, ordered (coil, kernel_y, kernel_x).
        calibration_matrix = (
            torch.nn.functional.unfold(acs_kspace_cropped.unsqueeze(0), kernel_size=self.kernel_size, stride=1)
            .transpose(1, 2)
            .reshape(-1, num_coils * patch_numel)
        )

        # The rows of `vh` are the right singular vectors, so thresholding must select rows. The reduced SVD is
        # enough because rows without an associated singular value would be discarded anyway.
        _, s, vh = torch.linalg.svd(calibration_matrix, full_matrices=False)
        vh = vh[s > (self.threshold * s.max())]

        # Get kernels
        kernels = vh.reshape(-1, num_coils, *([self.kernel_size] * ndim))

        # Get covariance matrix in image domain
        covariance = torch.zeros(
            spatial_size[::-1] + (num_coils, num_coils),
            dtype=kernels.dtype,
            device=kernels.device,
        )
        # The zero-padding that brings a kernel to the full spatial size is identical for every kernel.
        pad_h = spatial_size[0] - self.kernel_size
        pad_w = spatial_size[1] - self.kernel_size
        pad = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
        for kernel in kernels:
            kernel_padded = torch.nn.functional.pad(kernel, pad)
            img_kernel = ifft2(kernel_padded, self.fft_centered, self.fft_normalization, self.spatial_dims)
            img_kernel = torch.view_as_complex(img_kernel)
            # Reverse the axes so that the coil axis is last and add a trailing axis: (..., num_coils, 1).
            aH = img_kernel.permute(*range(img_kernel.ndim - 1, -1, -1)).unsqueeze(-1)
            a = aH.transpose(-1, -2).conj()
            # Accumulate the per-pixel outer product over all kernels.
            covariance = covariance + aH @ a
        covariance = covariance * (float(np.prod(spatial_size)) / patch_numel)

        initial_guess = torch.ones(
            (*spatial_size[::-1], num_coils, 1),
            dtype=kernels.dtype,
            device=kernels.device,
        )

        def forward(x):
            return covariance @ x

        def normalize(x):
            # Per-pixel l2 norm over the coil axis.
            return (x.abs() ** 2).sum(dim=-2, keepdim=True) ** 0.5

        power_method = MaximumEigenvaluePowerMethod(forward, max_iter=self.max_iter, norm_func=normalize)
        power_method.fit(x=initial_guess)

        temp_sensitivity_map = power_method.x.squeeze(-1)
        # Undo the axis reversal so that the coil axis is first again.
        temp_sensitivity_map = temp_sensitivity_map.permute(*range(temp_sensitivity_map.ndim - 1, -1, -1)).squeeze(-1)
        # Eigenvectors are only defined up to a per-pixel phase. As in the reference implementation [2], remove that
        # ambiguity by rotating all coils such that the first coil has zero phase. Multiplying each coil by its own
        # conjugate phase instead would reduce the maps to their magnitude.
        reference_coil = temp_sensitivity_map[0]
        temp_sensitivity_map = temp_sensitivity_map * reference_coil.conj() / reference_coil.abs()

        max_eig = power_method.max_eig.squeeze()
        max_eig = max_eig.permute(*range(max_eig.ndim - 1, -1, -1))
        # Zero out pixels where the dominant eigenvalue does not exceed the crop threshold.
        temp_sensitivity_map = temp_sensitivity_map * (max_eig > self.crop)

        sensitivity_map = torch.zeros_like(kspace)
        sensitivity_map[non_padded_dim] = torch.view_as_real(temp_sensitivity_map)
        return sensitivity_map

    def forward(self, acs_mask: torch.Tensor, kspace: torch.Tensor) -> torch.Tensor:
        """Forward pass of :class:`EspiritCalibration`.

        Parameters
        ----------
        acs_mask : torch.Tensor
            Autocalibration mask.
        kspace : torch.Tensor
            K-space.

        Returns
        -------
        sensitivity_map : torch.Tensor
            Coil sensitivity maps.
        """
        return self.calculate_sensitivity_map(acs_mask, kspace)