FFT

atommic.collections.common.parts.fft.fft2(x: torch.Tensor, centered: bool = False, normalization: str = 'backward', spatial_dims: Optional[Sequence[int]] = None) → torch.Tensor

Apply 2-dimensional Fast Fourier Transform.

Parameters
  • x (torch.Tensor) – Complex-valued input data, given as a real tensor whose last dimension (of size 2) holds the real and imaginary parts.

  • centered (bool) – Whether to center the FFT. If True, the FFT is shifted so that the zero-frequency component is at the center of the spectrum. Default is False.

  • normalization (str) –

    Normalization mode. For the forward transform (fft2()), these correspond to:
    • forward - normalize by 1/n

    • backward - no normalization

    • ortho - normalize by 1/sqrt(n) (making the FFT orthonormal)

    where n is the logical FFT size, i.e. the product of the sizes of the transformed dimensions. Calling the backward transform (ifft2()) with the same normalization mode applies an overall normalization of 1/n across the two transforms, which is what makes ifft2() the exact inverse of fft2(). Default is backward (no normalization).

  • spatial_dims (Sequence[int]) – Dimensions along which to apply the FFT. Default is the last two dimensions. If the tensor is viewed as real, the last dimension is assumed to hold the real and imaginary parts.
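
The scaling conventions listed above can be checked directly with torch.fft.fft2 (this sketch calls PyTorch, not the atommic wrapper). For a constant complex image the whole spectrum sits in the DC bin, whose value is the sum of the inputs times the mode's scale factor, so it should come out as n, 1 and sqrt(n) for backward, forward and ortho respectively.

```python
import math

import torch

# A constant 4 x 5 complex image: all of its energy lands in the DC bin [0, 0]
image = torch.ones(4, 5, dtype=torch.complex64)
n = image.shape[-2] * image.shape[-1]  # logical FFT size = 20

expected = {"backward": float(n), "forward": 1.0, "ortho": math.sqrt(n)}
measured = {}
for mode, want in expected.items():
    # DC value = sum of the inputs, scaled according to the normalization mode
    measured[mode] = torch.fft.fft2(image, norm=mode)[0, 0].real.item()
    print(f"{mode:>8}: DC = {measured[mode]:.4f} (expected {want:.4f})")
```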

Returns

The 2D FFT of the input.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import fft2
>>> data = torch.randn(2, 3, 4, 5, 2)
>>> fft2(data).shape
torch.Size([2, 3, 4, 5, 2])
>>> fft2(data, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
torch.Size([2, 3, 4, 5, 2])

Note

torch.fft.fft2 operates on complex tensors, while this function works with real tensors whose last dimension holds the real and imaginary parts. The input is therefore converted to a complex tensor with torch.view_as_complex before the transform, and the result is converted back to a real tensor with torch.view_as_real.
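
The conversion described above can be sketched as follows. fft2_sketch is a hypothetical helper, not atommic's implementation. It assumes that centering means an ifftshift before and an fftshift after the transform (a common convention for centered spectra, not confirmed by this page), and its spatial_dims index the complex view, so the default (-2, -1) refers to the two dimensions just before the real/imaginary axis.

```python
from typing import Optional, Sequence

import torch


def fft2_sketch(
    x: torch.Tensor,
    centered: bool = False,
    normalization: str = "backward",
    spatial_dims: Optional[Sequence[int]] = None,
) -> torch.Tensor:
    """Hypothetical real-view 2D FFT wrapper (a sketch, not atommic's code)."""
    # Dimensions index the complex view; default to its last two dimensions
    dims = tuple(spatial_dims) if spatial_dims is not None else (-2, -1)
    # (..., 2) real tensor -> (...) complex tensor
    xc = torch.view_as_complex(x.contiguous())
    if centered:
        # Assumed convention: move the centered zero frequency to index 0 first
        xc = torch.fft.ifftshift(xc, dim=dims)
    xc = torch.fft.fft2(xc, dim=dims, norm=normalization)
    if centered:
        # ...and move it back to the center of the spectrum afterwards
        xc = torch.fft.fftshift(xc, dim=dims)
    # (...) complex tensor -> (..., 2) real tensor
    return torch.view_as_real(xc)


data = torch.randn(2, 3, 4, 5, 2)
out = fft2_sketch(data, centered=True, normalization="ortho")
print(out.shape)  # torch.Size([2, 3, 4, 5, 2])
```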

The 2D FFT is separable, so torch.fft.fft2 over the last two dimensions is equivalent to two 1D FFTs applied in sequence: fft(fft(data, dim=-2), dim=-1).
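
The separability statement above can be verified numerically with PyTorch alone; this is a plain check of torch.fft, independent of the atommic wrapper.

```python
import torch

torch.manual_seed(0)
data = torch.randn(3, 4, 5, dtype=torch.complex64)

# 2D FFT over the last two dimensions in a single call
full = torch.fft.fft2(data)

# The same transform as two 1D FFTs, one per dimension
separable = torch.fft.fft(torch.fft.fft(data, dim=-2), dim=-1)

print(torch.allclose(full, separable, atol=1e-4))  # True
```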

Source: https://pytorch.org/docs/stable/fft.html#torch.fft.fft2

atommic.collections.common.parts.fft.ifft2(x: torch.Tensor, centered: bool = False, normalization: str = 'backward', spatial_dims: Optional[Sequence[int]] = None) → torch.Tensor

Apply 2-dimensional Inverse Fast Fourier Transform.

Parameters
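  • x (torch.Tensor) – Complex-valued input data, given as a real tensor whose last dimension (of size 2) holds the real and imaginary parts.

  • centered (bool) – Whether to center the IFFT. If True, the zero-frequency component of the input spectrum is assumed to be at the center, and the transform is shifted accordingly. Default is False.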

  • normalization (str) –

    Normalization mode. For the backward transform (ifft2()), these correspond to:
    • forward - no normalization

    • backward - normalize by 1/n

    • ortho - normalize by 1/sqrt(n) (making the IFFT orthonormal)

    where n is the logical IFFT size, i.e. the product of the sizes of the transformed dimensions. Calling the forward transform (fft2()) with the same normalization mode applies an overall normalization of 1/n across the two transforms, which is what makes ifft2() the exact inverse of fft2(). Default is backward (normalize by 1/n).

  • spatial_dims (Sequence[int]) – Dimensions along which to apply the IFFT. Default is the last two dimensions. If the tensor is viewed as real, the last dimension is assumed to hold the real and imaginary parts.
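
The pairing rule for normalization modes can be checked with torch.fft directly (not the atommic wrapper): using the same mode for the forward and inverse calls recovers the input for every mode, and only ortho additionally preserves the signal energy (Parseval's theorem), because each transform is then unitary.

```python
import torch

torch.manual_seed(0)
image = torch.randn(4, 6, dtype=torch.complex64)

restored = {}
for mode in ("backward", "forward", "ortho"):
    # Same mode on both sides: the overall 1/n factor is split between the two calls
    restored[mode] = torch.fft.ifft2(torch.fft.fft2(image, norm=mode), norm=mode)
    print(mode, torch.allclose(restored[mode], image, atol=1e-5))  # True for each mode

# With "ortho", the total energy of the image and of its spectrum match
spectrum = torch.fft.fft2(image, norm="ortho")
energy_in = image.abs().pow(2).sum()
energy_out = spectrum.abs().pow(2).sum()
print(torch.allclose(energy_in, energy_out, rtol=1e-4))  # True
```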

Returns

The 2D IFFT of the input.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import ifft2
>>> data = torch.randn(2, 3, 4, 5, 2)
>>> ifft2(data).shape
torch.Size([2, 3, 4, 5, 2])
>>> ifft2(data, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
torch.Size([2, 3, 4, 5, 2])

Note

torch.fft.ifft2 operates on complex tensors, while this function works with real tensors whose last dimension holds the real and imaginary parts. The input is therefore converted to a complex tensor with torch.view_as_complex before the transform, and the result is converted back to a real tensor with torch.view_as_real.
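
A centered inverse has to mirror the shifts of the centered forward transform, which matters for odd sizes, where fftshift and ifftshift are not the same operation. The sketch below uses two hypothetical helpers, centered_fft2 and centered_ifft2 (not atommic functions), that convert between the real view and complex tensors as described above and assume the ifftshift-before / fftshift-after convention with ortho normalization.

```python
from typing import Sequence

import torch


def centered_fft2(x: torch.Tensor, dims: Sequence[int] = (-2, -1)) -> torch.Tensor:
    """Hypothetical helper: real view (..., 2) in, centered spectrum as real view out."""
    xc = torch.view_as_complex(x.contiguous())
    xc = torch.fft.ifftshift(xc, dim=dims)  # centered zero frequency -> index 0
    xc = torch.fft.fft2(xc, dim=dims, norm="ortho")
    return torch.view_as_real(torch.fft.fftshift(xc, dim=dims))  # index 0 -> center


def centered_ifft2(k: torch.Tensor, dims: Sequence[int] = (-2, -1)) -> torch.Tensor:
    """Hypothetical helper: inverse of centered_fft2, same shift pattern around ifft2."""
    kc = torch.view_as_complex(k.contiguous())
    kc = torch.fft.ifftshift(kc, dim=dims)
    kc = torch.fft.ifft2(kc, dim=dims, norm="ortho")
    return torch.view_as_real(torch.fft.fftshift(kc, dim=dims))


torch.manual_seed(0)
image = torch.randn(1, 5, 7, 2)  # odd sizes, where fftshift and ifftshift differ
restored = centered_ifft2(centered_fft2(image))
print(torch.allclose(restored, image, atol=1e-5))  # True
```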

The 2D IFFT is separable, so torch.fft.ifft2 over the last two dimensions is equivalent to two 1D IFFTs applied in sequence: ifft(ifft(data, dim=-2), dim=-1).

Source: https://pytorch.org/docs/stable/fft.html#torch.fft.ifft2

atommic.collections.common.parts.fft.fftshift(x: torch.Tensor, dim: Optional[Union[List[int], Sequence[int]]] = None) → torch.Tensor

Similar to np.fft.fftshift but applies to PyTorch Tensors.

Parameters
  • x (torch.Tensor) – Input data.

  • dim (Union[List[int], Sequence[int]]) – Dimensions to shift. If None (the default), all dimensions are shifted, as in np.fft.fftshift.

Returns

The shifted tensor.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import fftshift
>>> data = torch.randn(2, 3, 4, 5)
>>> fftshift(data).shape
torch.Size([2, 3, 4, 5])
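
fftshift moves the zero-frequency term to the middle of each shifted dimension by rolling it by half its length (rounded down). The sketch below is a hypothetical roll-based version, fftshift_sketch, that assumes dim=None means every dimension (as in np.fft.fftshift); it is compared against torch.fft.fftshift rather than the atommic function.

```python
from typing import Optional, Sequence

import torch


def fftshift_sketch(x: torch.Tensor, dim: Optional[Sequence[int]] = None) -> torch.Tensor:
    """Hypothetical fftshift: roll each dimension forward by half its length (rounded down)."""
    dims = list(range(x.dim())) if dim is None else list(dim)
    shifts = [x.shape[d] // 2 for d in dims]
    return torch.roll(x, shifts=shifts, dims=dims)


freqs = torch.fft.fftfreq(5)  # [0.0, 0.2, 0.4, -0.4, -0.2]
centered = fftshift_sketch(freqs)
print(centered)  # -> [-0.4, -0.2, 0.0, 0.2, 0.4], zero frequency in the middle
```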

atommic.collections.common.parts.fft.ifftshift(x: torch.Tensor, dim: Optional[Union[List[int], Sequence[int]]] = None) → torch.Tensor

Similar to np.fft.ifftshift but applies to PyTorch Tensors.

Parameters
  • x (torch.Tensor) – Input data.

  • dim (Union[List[int], Sequence[int]]) – Dimensions to shift. If None (the default), all dimensions are shifted, as in np.fft.ifftshift.

Returns

The shifted tensor.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import ifftshift
>>> data = torch.randn(2, 3, 4, 5)
>>> ifftshift(data).shape
torch.Size([2, 3, 4, 5])
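
ifftshift is the inverse of fftshift: it rolls each dimension back by half its length (rounded down). For even lengths the two functions coincide, but for odd lengths they differ, so only ifftshift undoes fftshift. ifftshift_sketch below is a hypothetical roll-based version, assuming dim=None means every dimension, checked against torch.fft rather than the atommic function.

```python
from typing import Optional, Sequence

import torch


def ifftshift_sketch(x: torch.Tensor, dim: Optional[Sequence[int]] = None) -> torch.Tensor:
    """Hypothetical ifftshift: roll each dimension back by half its length (rounded down)."""
    dims = list(range(x.dim())) if dim is None else list(dim)
    shifts = [-(x.shape[d] // 2) for d in dims]
    return torch.roll(x, shifts=shifts, dims=dims)


torch.manual_seed(0)
x = torch.randn(3, 5)  # odd last dimension
shifted = torch.fft.fftshift(x, dim=-1)

print(torch.equal(ifftshift_sketch(shifted, dim=[-1]), x))  # True: undoes fftshift
print(torch.equal(torch.fft.fftshift(shifted, dim=-1), x))  # False: a second fftshift does not
```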

atommic.collections.common.parts.fft.roll(x: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]) → torch.Tensor

Similar to np.roll but applies to PyTorch Tensors.

Parameters
  • x (torch.Tensor) – Input data.

  • shift (List[int]) – Amount to roll along each dimension in dim.

  • dim (Union[List[int], Sequence[int]]) – Dimensions to roll; must have the same length as shift.

Returns

The rolled tensor.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import roll
>>> data = torch.randn(2, 3, 4, 5)
>>> roll(data, [1, 2], [0, 1]).shape
torch.Size([2, 3, 4, 5])
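
Because rolls along different dimensions are independent, a multi-dimension roll can be built from one single-dimension roll per axis. roll_sketch below is a hypothetical illustration of that idea (it delegates each axis to torch.roll, not to the atommic roll_one_dim) and checks the result against a single torch.roll call.

```python
from typing import List, Sequence

import torch


def roll_sketch(x: torch.Tensor, shift: List[int], dim: Sequence[int]) -> torch.Tensor:
    """Hypothetical multi-dimension roll built from one single-dimension roll per axis."""
    if len(shift) != len(dim):
        raise ValueError("shift and dim must have the same length")
    for s, d in zip(shift, dim):
        # Rolls along different axes commute, so the order of the loop does not matter
        x = torch.roll(x, shifts=s, dims=d)
    return x


data = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
rolled = roll_sketch(data, [1, 2], [0, 1])
print(torch.equal(rolled, torch.roll(data, shifts=(1, 2), dims=(0, 1))))  # True
```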

atommic.collections.common.parts.fft.roll_one_dim(x: torch.Tensor, shift: int, dim: int) → torch.Tensor

Similar to roll, but along a single dimension.

Parameters
  • x (torch.Tensor) – Input data.

  • shift (int) – Amount to roll.

  • dim (int) – Which dimension to roll.

Returns

The rolled tensor.

Return type

torch.Tensor

Examples

>>> import torch
>>> from atommic.collections.common.parts.fft import roll_one_dim
>>> data = torch.randn(2, 3, 4, 5)
>>> roll_one_dim(data, 1, 0).shape
torch.Size([2, 3, 4, 5])
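
A single-dimension roll can be expressed as slicing plus concatenation: the last shift entries along dim move to the front. roll_one_dim_sketch below is a hypothetical illustration of that technique using torch.Tensor.narrow and torch.cat, not necessarily how atommic implements roll_one_dim; negative or oversized shifts are wrapped with a modulo first.

```python
import torch


def roll_one_dim_sketch(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """Hypothetical single-axis roll: move the last `shift` entries along `dim` to the front."""
    size = x.size(dim)
    shift = shift % size  # wrap negative or oversized shifts into [0, size)
    if shift == 0:
        return x
    left = x.narrow(dim, 0, size - shift)
    right = x.narrow(dim, size - shift, shift)
    return torch.cat((right, left), dim=dim)


data = torch.arange(5)
print(roll_one_dim_sketch(data, 2, 0))  # tensor([3, 4, 0, 1, 2])
print(roll_one_dim_sketch(data, -1, 0))  # tensor([1, 2, 3, 4, 0])
```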