# Source code for atommic.collections.common.parts.fft

# coding=utf-8
__author__ = "Dimitris Karkalousos"

from typing import List, Sequence, Union

import numpy as np
import torch
from omegaconf import ListConfig

# Names exported by ``from atommic.collections.common.parts.fft import *``. The ``roll`` and
# ``roll_one_dim`` helpers are defined in this module but are not listed here.
__all__ = ["fft2", "ifft2", "fftshift", "ifftshift"]


def fft2(
    x: torch.Tensor,
    centered: bool = False,
    normalization: str = "backward",
    spatial_dims: Sequence[int] = None,
) -> torch.Tensor:
    r"""Apply the 2-dimensional Fast Fourier Transform.

    Parameters
    ----------
    x : torch.Tensor
        Input data. Either a complex tensor, or a real tensor whose last dimension has size 2 and holds the
        (real, imaginary) parts in the ``torch.view_as_real`` layout. Any real tensor whose last dimension has
        size 2 is interpreted this way. Any other real tensor is transformed as real-valued data.
    centered : bool
        If True, the origin of the data is assumed to be at the center of ``spatial_dims``. ``ifftshift`` is
        applied before the transform and ``fftshift`` after it, so the zero-frequency component ends up in the
        center of the spectrum. Default is ``False``.
    normalization : str or None
        Normalization mode forwarded to ``torch.fft.fft2``. For the forward transform these correspond to:

        * ``"forward"`` - normalize by ``1/n``
        * ``"backward"`` - no normalization
        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)

        where ``n`` is the number of elements in the transformed dimensions. ``"none"`` (any case) and
        ``None`` select PyTorch's default, ``"backward"``. Calling ``ifft2`` with the same mode yields an
        overall ``1/n`` scaling across the pair, which makes it the exact inverse. Default is ``"backward"``.
    spatial_dims : Sequence[int], optional
        Dimensions to transform. They are indexed on the complex view of ``x``, i.e. without the trailing
        real/imaginary dimension of a real-viewed input. A single ``int`` or any sequence of ints (list,
        tuple, ``omegaconf.ListConfig``, ...) is accepted. Default is the last two dimensions.

    Returns
    -------
    torch.Tensor
        The 2D FFT of the input as a real tensor with a trailing dimension of size 2 (``torch.view_as_real``
        layout). This holds whether the input was complex or real-viewed.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import fft2
    >>> data = torch.randn(2, 3, 4, 5, 2)
    >>> fft2(data).shape
    torch.Size([2, 3, 4, 5, 2])
    >>> fft2(data, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
    torch.Size([2, 3, 4, 5, 2])

    .. note::
        ``torch.fft.fft2`` operates on complex tensors. Real-viewed inputs are therefore converted with
        ``torch.view_as_complex`` before the transform, and the result is always returned through
        ``torch.view_as_real``. ``torch.fft.fft2`` applies 1D FFTs along each dimension in turn, so with the
        default dims it equals ``fft(fft(data, dim=-2), dim=-1)``.
        Source: https://pytorch.org/docs/stable/fft.html#torch.fft.fft2
    """
    # Check the dtype first. A complex tensor whose last dimension merely happens to have size 2 must not
    # be passed to view_as_complex, which only accepts real floating-point tensors and would raise.
    if not torch.is_complex(x) and x.shape[-1] == 2:
        x = torch.view_as_complex(x)

    if spatial_dims is None:
        spatial_dims = [-2, -1]
    elif isinstance(spatial_dims, int):
        spatial_dims = [spatial_dims]
    else:
        # Normalize tuples, ListConfig, etc. to a plain list. This lets the shift helpers treat them as a
        # collection of dims rather than a single dim.
        spatial_dims = list(spatial_dims)

    # The string "none" (any case) and Python None both defer to PyTorch's default normalization.
    norm = None if normalization is None or normalization.lower() == "none" else normalization

    if centered:
        x = ifftshift(x, dim=spatial_dims)
    x = torch.fft.fft2(x, dim=spatial_dims, norm=norm)
    if centered:
        x = fftshift(x, dim=spatial_dims)

    return torch.view_as_real(x)
def ifft2(
    x: torch.Tensor,
    centered: bool = False,
    normalization: str = "backward",
    spatial_dims: Sequence[int] = None,
) -> torch.Tensor:
    r"""Apply the 2-dimensional Inverse Fast Fourier Transform.

    Parameters
    ----------
    x : torch.Tensor
        Input data. Either a complex tensor, or a real tensor whose last dimension has size 2 and holds the
        (real, imaginary) parts in the ``torch.view_as_real`` layout. Any real tensor whose last dimension has
        size 2 is interpreted this way. Any other real tensor is transformed as real-valued data.
    centered : bool
        If True, the zero-frequency component of the input is assumed to be at the center of
        ``spatial_dims``. ``ifftshift`` is applied before the transform and ``fftshift`` after it.
        Default is ``False``.
    normalization : str or None
        Normalization mode forwarded to ``torch.fft.ifft2``. For the inverse transform these correspond to:

        * ``"forward"`` - no normalization
        * ``"backward"`` - normalize by ``1/n``
        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)

        where ``n`` is the number of elements in the transformed dimensions. ``"none"`` (any case) and
        ``None`` select PyTorch's default, ``"backward"``. Using the same mode for ``fft2`` and ``ifft2``
        yields an overall ``1/n`` scaling across the pair, which makes them exact inverses.
        Default is ``"backward"`` (normalize the inverse by ``1/n``).
    spatial_dims : Sequence[int], optional
        Dimensions to transform. They are indexed on the complex view of ``x``, i.e. without the trailing
        real/imaginary dimension of a real-viewed input. A single ``int`` or any sequence of ints (list,
        tuple, ``omegaconf.ListConfig``, ...) is accepted. Default is the last two dimensions.

    Returns
    -------
    torch.Tensor
        The 2D IFFT of the input as a real tensor with a trailing dimension of size 2 (``torch.view_as_real``
        layout). This holds whether the input was complex or real-viewed.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import ifft2
    >>> data = torch.randn(2, 3, 4, 5, 2)
    >>> ifft2(data).shape
    torch.Size([2, 3, 4, 5, 2])
    >>> ifft2(data, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
    torch.Size([2, 3, 4, 5, 2])

    .. note::
        ``torch.fft.ifft2`` operates on complex tensors. Real-viewed inputs are therefore converted with
        ``torch.view_as_complex`` before the transform, and the result is always returned through
        ``torch.view_as_real``. ``torch.fft.ifft2`` applies 1D IFFTs along each dimension in turn, so with the
        default dims it equals ``ifft(ifft(data, dim=-2), dim=-1)``.
        Source: https://pytorch.org/docs/stable/fft.html#torch.fft.ifft2
    """
    # Check the dtype first. A complex tensor whose last dimension merely happens to have size 2 must not
    # be passed to view_as_complex, which only accepts real floating-point tensors and would raise.
    if not torch.is_complex(x) and x.shape[-1] == 2:
        x = torch.view_as_complex(x)

    if spatial_dims is None:
        spatial_dims = [-2, -1]
    elif isinstance(spatial_dims, int):
        spatial_dims = [spatial_dims]
    else:
        # Normalize tuples, ListConfig, etc. to a plain list. This lets the shift helpers treat them as a
        # collection of dims rather than a single dim.
        spatial_dims = list(spatial_dims)

    # The string "none" (any case) and Python None both defer to PyTorch's default normalization.
    norm = None if normalization is None or normalization.lower() == "none" else normalization

    if centered:
        x = ifftshift(x, dim=spatial_dims)
    x = torch.fft.ifft2(x, dim=spatial_dims, norm=norm)
    if centered:
        x = fftshift(x, dim=spatial_dims)

    return torch.view_as_real(x)
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """Circularly roll ``x`` by ``shift`` positions along a single dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    shift : int
        Number of positions to roll. Negative values roll backwards. The value is reduced modulo the size of
        ``dim``, so a zero-sized ``dim`` raises ``ZeroDivisionError``.
    dim : int
        Dimension along which to roll.

    Returns
    -------
    torch.Tensor
        The rolled tensor. When the effective shift is zero, ``x`` itself is returned (no copy).

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import roll_one_dim
    >>> data = torch.randn(2, 3, 4, 5)
    >>> roll_one_dim(data, 1, 0).shape
    torch.Size([2, 3, 4, 5])

    .. note::
        Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
    """
    length = x.size(dim)
    offset = shift % length
    if not offset:
        return x
    # The last ``offset`` elements move to the front; the first ``length - offset`` follow them.
    cut = length - offset
    head = x.narrow(dim, 0, cut)
    tail = x.narrow(dim, cut, offset)
    return torch.cat((tail, head), dim=dim)
def roll(x: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]) -> torch.Tensor:
    """Circularly roll ``x`` along several dimensions, like ``np.roll`` for PyTorch tensors.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    shift : List[int]
        Roll amount for each entry of ``dim``, paired positionally.
    dim : Union[List[int], Sequence[int]]
        Dimensions to roll. An ``omegaconf.ListConfig`` is converted to a plain list first.

    Returns
    -------
    torch.Tensor
        The rolled tensor.

    Raises
    ------
    ValueError
        If ``shift`` and ``dim`` have different lengths.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import roll
    >>> data = torch.randn(2, 3, 4, 5)
    >>> roll(data, [1, 2], [0, 1]).shape
    torch.Size([2, 3, 4, 5])

    .. note::
        Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")
    axes = list(dim) if isinstance(dim, ListConfig) else dim
    rolled = x
    # Each dimension is rolled independently, one after another.
    for amount, axis in zip(shift, axes):
        rolled = roll_one_dim(rolled, amount, axis)
    return rolled
def fftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
    """Similar to ``np.fft.fftshift`` but applies to PyTorch Tensors.

    Rolls each selected dimension by ``size // 2``, which moves the zero-frequency component to the center.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    dim : Union[int, List[int], Sequence[int]], optional
        Dimension(s) to shift. A single ``int`` or any sequence of ints (list, tuple, ``torch.Size``,
        ``omegaconf.ListConfig``, ...) is accepted. Default ``None`` shifts every dimension of ``x``.

    Returns
    -------
    torch.Tensor
        The shifted tensor, with the same shape as ``x``.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import fftshift
    >>> fftshift(torch.arange(5))
    tensor([3, 4, 0, 1, 2])
    >>> fftshift(torch.randn(2, 3, 4, 5), dim=(-2, -1)).shape
    torch.Size([2, 3, 4, 5])

    .. note::
        Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
    """
    if dim is None:
        dims = list(range(x.dim()))
    elif isinstance(dim, (int, np.integer)):
        dims = [int(dim)]
    else:
        # Convert any sequence (tuple, torch.Size, ListConfig, ...) to a list. Wrapping it as a single
        # element would later index x.shape with a tuple and fail.
        dims = list(dim)
    shifts = [x.shape[d] // 2 for d in dims]
    return roll(x, shifts, dims)
def ifftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
    """Similar to ``np.fft.ifftshift`` but applies to PyTorch Tensors.

    Rolls each selected dimension by ``(size + 1) // 2``, which undoes ``fftshift`` for odd as well as even
    sizes.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    dim : Union[int, List[int], Sequence[int]], optional
        Dimension(s) to shift. A single ``int`` or any sequence of ints (list, tuple, ``torch.Size``,
        ``omegaconf.ListConfig``, ...) is accepted. Default ``None`` shifts every dimension of ``x``.

    Returns
    -------
    torch.Tensor
        The shifted tensor, with the same shape as ``x``.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.parts.fft import ifftshift
    >>> ifftshift(torch.arange(5))
    tensor([2, 3, 4, 0, 1])
    >>> ifftshift(torch.randn(2, 3, 4, 5), dim=(-2, -1)).shape
    torch.Size([2, 3, 4, 5])

    .. note::
        Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
    """
    if dim is None:
        dims = list(range(x.dim()))
    elif isinstance(dim, (int, np.integer)):
        dims = [int(dim)]
    else:
        # Convert any sequence (tuple, torch.Size, ListConfig, ...) to a list. Wrapping it as a single
        # element would later index x.shape with a tuple and fail.
        dims = list(dim)
    # (size + 1) // 2 rather than size // 2, so that odd-sized dims are shifted back exactly.
    shifts = [(x.shape[d] + 1) // 2 for d in dims]
    return roll(x, shifts, dims)