Source code for atommic.collections.motioncorrection.parts.motionsimulation

# coding=utf-8
__author__ = "Dimitris Karkalousos"

# Taken and adapted from https://github.com/bduffy0/motion-correction/blob/master/layer/motion_sim.py

import math
import random
from typing import Any, Dict, Optional, Sequence, Tuple

import torch

from atommic.collections.common.parts import utils


[docs]def get_center_rect(image: torch.tensor, center_percentage: float = 0.02, dim: int = 0) -> torch.tensor: """Get a center rectangle of a given dimension. Parameters ---------- image : torch.tensor The image to get the center rectangle from. center_percentage : float The percentage of the image to take as the center rectangle. dim : int The dimension to take the center rectangle from. Returns ------- torch.tensor The center rectangle. """ shape = (image[0].item(), image[1].item()) mask = torch.zeros(shape) half_pct = center_percentage / 2 center = [int(x / 2) for x in shape] mask = torch.swapaxes(mask, 0, dim) mask[:, center[1] - math.ceil(shape[1] * half_pct) : math.ceil(center[1] + shape[1] * half_pct)] = 1 mask = torch.swapaxes(mask, 0, dim) return mask
[docs]def segment_array_by_locs(shape: Sequence[int], locations: Sequence[int]) -> torch.tensor: """Generate a segmentation mask based on a list of locations. Parameters ---------- shape : Sequence[int] The shape of the array to segment. locations : Sequence[int] The locations to segment the array into. Returns ------- torch.tensor The segmentation mask. """ mask_out = torch.zeros(torch.prod(shape), dtype=int) for i in range(len(locations) - 1): loc = [locations[i], locations[i + 1]] mask_out[loc[0] : loc[1]] = i + 1 return mask_out.reshape(shape)
[docs]def segments_to_random_indices(shape: Sequence[int], seg_lengths: Sequence[int]) -> torch.tensor: """Generate a segmentation mask based on a list of locations. Parameters ---------- shape : Sequence[int] The shape of the array to segment. seg_lengths : Sequence[int] The lengths of the segments to generate. Returns ------- torch.tensor The segmentation mask. """ random_indices = torch.randint(low=0, high=shape, size=(sum(seg_lengths),)).sort()[0] seg_mask = torch.zeros(shape).type(torch.int) seg_new_indices = torch.cumsum(torch.tensor(seg_lengths), 0).tolist() seg_new_indices = [0] + seg_new_indices for i in range(len(seg_new_indices) - 1): seg_mask[random_indices[seg_new_indices[i] : seg_new_indices[i + 1]]] = i + 1 return seg_mask
[docs]def segments_to_random_blocks(shape: Sequence[int], seg_lengths: Sequence[int]) -> torch.tensor: """Generate a segmentation mask based on a list of locations. Parameters ---------- shape : Sequence[int] The shape of the array to segment. seg_lengths : Sequence[int] The lengths of the segments to generate. Returns ------- torch.tensor The segmentation mask. """ seg_mask = torch.zeros(shape).type(torch.int) seg_lengths_sorted = sorted(seg_lengths, reverse=True) for i, seg_len in enumerate(seg_lengths_sorted): loc = torch.randint(low=0, high=seg_mask.size()[0], size=(1,)) while (sum(seg_mask[loc : loc + seg_len]) != 0) or (loc + seg_len > seg_mask.size()[0]): loc = torch.randint(low=0, high=seg_mask.size()[0], size=(1,)) seg_mask[loc : loc + seg_len] = i + 1 return seg_mask
[docs]def create_rand_partition(im_length: int, num_segments: int): """Create a random partition of an array. Parameters ---------- im_length : int The length of the array to partition. num_segments : int The number of segments to partition the array into. Returns ------- list The partition locations. """ rand_segment_locs = sorted(list(torch.randint(im_length, size=(num_segments,)))) rand_segment_locs[0] = 0 rand_segment_locs[-1] = None return rand_segment_locs
[docs]def create_rotation_matrix_3d(angles: Sequence[float]) -> torch.tensor: """Create a 3D rotation matrix. Parameters ---------- angles : Sequence[float] The angles to rotate the matrix by. Returns ------- torch.tensor The rotation matrix. """ mat1 = torch.FloatTensor( [ [1.0, 0.0, 0.0], [0.0, math.cos(angles[0]), math.sin(angles[0])], [0.0, -math.sin(angles[0]), math.cos(angles[0])], ] ) mat2 = torch.FloatTensor( [ [math.cos(angles[1]), 0.0, -math.sin(angles[1])], [0.0, 1.0, 0.0], [math.sin(angles[1]), 0.0, math.cos(angles[1])], ] ) mat3 = torch.FloatTensor( [ [math.cos(angles[2]), math.sin(angles[2]), 0.0], [-math.sin(angles[2]), math.cos(angles[2]), 0.0], [0.0, 0.0, 1.0], ] ) return (mat1 @ mat2) @ mat3
[docs]def translate_kspace(freq_domain: torch.tensor, translations: torch.tensor) -> torch.tensor: """Translate a k-space array. Parameters ---------- freq_domain : torch.tensor The k-space array to translate. translations : torch.tensor The translations to apply to the k-space array. Returns ------- torch.tensor The translated k-space array. """ lin_spaces = [torch.linspace(-0.5, 0.5, x) for x in freq_domain.shape[:-1]] meshgrids = torch.meshgrid(*lin_spaces, indexing="ij") grid_coords = torch.stack([mg.flatten() for mg in meshgrids], 0) phase_shift = torch.multiply(grid_coords, translations).sum(axis=0) # phase shift is added exp_phase_shift = torch.exp(-2j * math.pi * phase_shift).to(freq_domain.device) motion_kspace = torch.view_as_real( torch.multiply(exp_phase_shift, torch.view_as_complex(freq_domain).flatten()).reshape(freq_domain.shape[:-1]) ) return motion_kspace
[docs]class MotionSimulation: """Simulates random translations and rotations in the frequency domain. Examples -------- >>> from atommic.collections.motioncorrection.parts import MotionSimulation >>> import torch >>> motion_simulation = MotionSimulation() >>> kspace = torch.randn(1, 1, 256, 256, 2) >>> motion_kspace = motion_simulation(kspace) >>> motion_kspace.shape torch.Size([1, 1, 256, 256, 2]) """
[docs] def __init__( self, motion_type: str = "piecewise_transient", angle: float = 0, translation: float = 10, center_percentage: float = 0.02, motion_percentage: Sequence[float] = (15, 20), num_segments: int = 8, random_num_segments: bool = False, non_uniform: bool = False, spatial_dims: Sequence[int] = (-2, -1), ): """Inits :class:`MotionSimulation`. Parameters ---------- motion_type : str The motion_type of motion to simulate. angle : float The angle to rotate the k-space array by. translation : float The translation to apply to the k-space array. center_percentage : float The percentage of the k-space array to center the motion. motion_percentage : Sequence[float] The percentage of the k-space array to apply the motion. num_segments : int The number of segments to partition the k-space array into. random_num_segments : bool Whether to randomly generate the number of segments. non_uniform : bool Whether to use non-uniform sampling. spatial_dims : Sequence[int] The spatial dimensions to apply the motion to. """ self.motion_type = motion_type self.angle, self.translation = angle, translation self.center_percentage = center_percentage if motion_percentage[1] == motion_percentage[0]: motion_percentage[1] += 1 # type: ignore elif motion_percentage[1] < motion_percentage[0]: raise ValueError("Uniform is not defined when low>= high.") self.motion_percentage = motion_percentage self.spatial_dims = spatial_dims self._spatial_dims = random.choice(spatial_dims) self.num_segments = num_segments self.random_num_segments = random_num_segments if non_uniform: raise NotImplementedError("NUFFT is not implemented. This is a feature to be added in the future.") self.trajectory = None self.params: Dict[Any, Any] = {}
def _calc_dimensions(self, shape): """Calculate the dimensions to apply the motion to. Parameters ---------- shape : Sequence[int] The shape of the image. Returns ------- Sequence[int] The dimensions to apply the motion to. """ pe_dims = [0, 1, 2] pe_dims.pop(self._spatial_dims) self.phase_encoding_dims = pe_dims shape = list(shape) if shape[-1] == 2: shape = shape[:-1] self.shape = shape.copy() shape.pop(self._spatial_dims) self.phase_encoding_shape = torch.tensor(shape) self.num_phase_encoding_steps = self.phase_encoding_shape[0] * self.phase_encoding_shape[1] self._spatial_dims = len(self.shape) - 1 if self._spatial_dims == -1 else self._spatial_dims def _generate_random_segments(self): """Generate random segments.""" pct_corrupt = torch.distributions.Uniform(*[x / 100 for x in self.motion_percentage]).sample((1, 1)) corrupt_matrix_shape = torch.tensor([int(x * math.sqrt(pct_corrupt)) for x in self.phase_encoding_shape]) if torch.prod(corrupt_matrix_shape) == 0: corrupt_matrix_shape = [1, 1] if self.motion_type in {"gaussian"}: num_segments = torch.prod(corrupt_matrix_shape) else: if not self.random_num_segments: num_segments = self.num_segments else: num_segments = random.randint(1, self.num_segments) # segment a smaller vector occupying pct_corrupt percent of the space if self.motion_type in {"piecewise_transient", "piecewise_constant"}: seg_locs = create_rand_partition(torch.prod(corrupt_matrix_shape), num_segments=num_segments) else: seg_locs = list(range(num_segments)) rand_segmentation = segment_array_by_locs(shape=torch.prod(corrupt_matrix_shape), locations=seg_locs) seg_lengths = [(rand_segmentation == seg_num).sum() for seg_num in torch.unique(rand_segmentation)] # assign segments to a vector with same number of elements as pe-steps if self.motion_type in {"piecewise_transient", "gaussian"}: seg_vector = segments_to_random_indices(torch.prod(self.phase_encoding_shape), seg_lengths) else: seg_vector = segments_to_random_blocks(torch.prod(self.phase_encoding_shape), seg_lengths) # reshape to phase encoding shape with a random order reshape_order = random.choice(["F", "C"]) if reshape_order == "F": seg_array = utils.reshape_fortran( seg_vector, (self.phase_encoding_shape[0].item(), self.phase_encoding_shape[1].item()) ) else: seg_array = seg_vector.reshape((self.phase_encoding_shape[0].item(), self.phase_encoding_shape[1].item())) self.order = reshape_order # mask center k-space mask_not_including_center = ( get_center_rect( self.phase_encoding_shape, center_percentage=self.center_percentage, dim=1 if reshape_order == "C" else 0, ) == 0 ) self.seg_array = seg_array * mask_not_including_center self.num_segments = num_segments def _get_motion_trajectory(self, translation_rotation=None, random_segments=True): """Obtain a motion trajectory. Returns ------- torch.tensor The random trajectory. """ if random_segments: self._generate_random_segments() else: raise NotImplementedError("Custom segments (masks) not supported") if not translation_rotation: translations, rotations = self._simulate_random_trajectory() else: (translations, rotations) = translation_rotation if translations.shape[0] != self.num_segments: translations = torch.cat((torch.tensor([[0, 0, 0]]), translations), dim=0) if rotations.shape[0] != self.num_segments: rotations = torch.cat((torch.tensor([[0, 0, 0]]), rotations), dim=0) # if segment==0, then no motion translations[0, :] = 0 rotations[0, :] = 0 # lookup values for each segment translations_pe = [translations[:, i][self.seg_array.long()] for i in range(3)] rotations_pe = [rotations[:, i][self.seg_array.long()] for i in range(3)] # reshape and convert to radians translations = torch.stack( [torch.broadcast_to(x.unsqueeze(self._spatial_dims), self.shape) for x in translations_pe], 0 ) rotations = torch.stack( [torch.broadcast_to(x.unsqueeze(self._spatial_dims), self.shape) for x in rotations_pe], 0 ) rotations = rotations * (math.pi / 180.0) # convert to radians self.translations = translations.reshape(3, -1) self.rotations = rotations.reshape(3, -1).reshape(3, -1) def _simulate_random_trajectory(self): """Simulate a random trajectory.""" # generate random translations and rotations rand_translations = torch.distributions.normal.Normal(loc=0, scale=self.translation).sample( (self.num_segments, 3) ) rand_rotations = torch.distributions.normal.Normal(loc=0, scale=self.angle).sample((self.num_segments, 3)) return rand_translations, rand_rotations
[docs] def forward( self, kspace, translations_rotations: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, apply_backward_transform: bool = False, # pylint: disable=unused-argument apply_forward_transform: bool = False, # pylint: disable=unused-argument ) -> torch.Tensor: """Forward pass of :class:`MotionSimulation`. Parameters ---------- kspace : torch.Tensor The kspace to apply the motion to. translations_rotations : Optional[Tuple[torch.Tensor, torch.Tensor]] The translations and rotations to apply to the kspace. If None, a random trajectory is generated. apply_backward_transform : bool Placeholder for the backward transform. Generalizes the Composer, but not used. apply_forward_transform : bool Placeholder for the forward transform. Generalizes the Composer, but not used. Returns ------- torch.Tensor The kspace with the motion applied. """ self._calc_dimensions(kspace.shape) self._get_motion_trajectory(translations_rotations) motion_kspace = translate_kspace(freq_domain=kspace, translations=self.translations) return motion_kspace
def __call__(self, *args, **kwargs): """Call :class:`MotionSimulation`.""" return self.forward(*args, **kwargs)