Source code for atommic.collections.multitask.rs.parts.transforms

# coding=utf-8
__author__ = "Dimitris Karkalousos"

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from atommic.collections.common.parts.fft import ifft2
from atommic.collections.common.parts.transforms import (
    N2R,
    SSDU,
    Composer,
    Cropper,
    EstimateCoilSensitivityMaps,
    GeometricDecompositionCoilCompression,
    Masker,
    NoisePreWhitening,
    Normalizer,
    ZeroFillingPadding,
)
from atommic.collections.common.parts.utils import add_coil_dim_if_singlecoil
from atommic.collections.common.parts.utils import coil_combination_method as coil_combination_method_func
from atommic.collections.common.parts.utils import is_none, to_tensor
from atommic.collections.motioncorrection.parts.motionsimulation import MotionSimulation

# Public API of this module: only the transform class is exported by ``import *``.
__all__ = ["RSMRIDataTransforms"]


[docs]class RSMRIDataTransforms: """Data transforms for accelerated-MRI reconstruction and MRI segmentation. Returns ------- RSMRIDataTransforms Preprocessed data for accelerated-MRI reconstruction and MRI segmentation. """
def __init__(
    self,
    complex_data: bool = True,
    segmentation_mode: str = "multilabel",
    dataset_format: Optional[str] = None,
    apply_prewhitening: bool = False,
    find_patch_size: bool = True,
    prewhitening_scale_factor: float = 1.0,
    prewhitening_patch_start: int = 10,
    prewhitening_patch_length: int = 30,
    apply_gcc: bool = False,
    gcc_virtual_coils: int = 10,
    gcc_calib_lines: int = 24,
    gcc_align_data: bool = True,
    apply_random_motion: bool = False,
    random_motion_type: str = "gaussian",
    random_motion_percentage: Sequence[int] = (10, 10),
    random_motion_angle: int = 10,
    random_motion_translation: int = 10,
    random_motion_center_percentage: float = 0.02,
    random_motion_num_segments: int = 8,
    random_motion_random_num_segments: bool = True,
    random_motion_non_uniform: bool = False,
    estimate_coil_sensitivity_maps: bool = False,
    coil_sensitivity_maps_type: str = "ESPIRiT",
    coil_sensitivity_maps_gaussian_sigma: float = 0.0,
    coil_sensitivity_maps_espirit_threshold: float = 0.05,
    coil_sensitivity_maps_espirit_kernel_size: int = 6,
    coil_sensitivity_maps_espirit_crop: float = 0.95,
    coil_sensitivity_maps_espirit_max_iters: int = 30,
    coil_combination_method: Optional[str] = "SENSE",
    dimensionality: int = 2,
    mask_func: Optional[List] = None,
    shift_mask: bool = False,
    mask_center_scale: Optional[float] = 0.02,
    partial_fourier_percentage: float = 0.0,
    remask: bool = False,
    ssdu: bool = False,
    ssdu_mask_type: str = "Gaussian",
    ssdu_rho: float = 0.4,
    ssdu_acs_block_size: Sequence[int] = (4, 4),
    ssdu_gaussian_std_scaling_factor: float = 4.0,
    ssdu_outer_kspace_fraction: float = 0.0,
    ssdu_export_and_reuse_masks: bool = False,
    n2r: bool = False,
    n2r_supervised_rate: float = 0.0,
    n2r_probability: float = 0.0,
    n2r_std_devs: Optional[Tuple[float, float]] = None,
    n2r_rhos: Optional[Tuple[float, float]] = None,
    n2r_use_mask: bool = False,
    unsupervised_masked_target: bool = False,
    crop_size: Optional[Tuple[int, int]] = None,
    kspace_crop: bool = False,
    crop_before_masking: bool = True,
    kspace_zero_filling_size: Optional[Tuple] = None,
    normalize_inputs: bool = True,
    normalization_type: str = "max",
    kspace_normalization: bool = False,
    fft_centered: bool = False,
    fft_normalization: str = "backward",
    spatial_dims: Optional[Sequence[int]] = None,
    coil_dim: Optional[int] = 0,
    consecutive_slices: int = 1,
    use_seed: bool = True,
):
    """Inits :class:`RSMRIDataTransforms`.

    For magnitude-only data (``complex_data=False``) every k-space related option must be disabled
    (``coil_combination_method``, ``mask_func``, ``kspace_zero_filling_size`` and ``coil_dim`` set to ``None``;
    ``kspace_crop``, ``apply_prewhitening``, ``apply_gcc`` and ``apply_random_motion`` set to ``False``),
    otherwise a ``ValueError`` is raised. In that mode only cropping and normalization are built.

    Parameters
    ----------
    complex_data : bool, optional
        Whether to use complex data. If ``False`` the data are assumed to be magnitude only. Default is ``True``.
    segmentation_mode : str, optional
        Defines the segmentation labels model, either ``multiclass`` or ``multilabel``. In ``multiclass`` mode,
        only one class is assigned per voxel. In ``multilabel`` mode, multiple (overlapping) classes are allowed
        per voxel. Default is ``multilabel``.
    dataset_format : str, optional
        The format of the dataset. For example, ``'custom_dataset'`` or ``'public_dataset_name'``.
        Default is ``None``.
    apply_prewhitening : bool, optional
        Apply prewhitening. If ``True`` then the prewhitening arguments are used. Default is ``False``.
    find_patch_size : bool, optional
        Find optimal patch size (automatically) to calculate psi. If ``False``, the patch is defined by
        ``prewhitening_patch_start`` and ``prewhitening_patch_length``. Default is ``True``.
    prewhitening_scale_factor : float, optional
        Prewhitening scale factor. Default is ``1.0``.
    prewhitening_patch_start : int, optional
        Prewhitening patch start. Default is ``10``.
    prewhitening_patch_length : int, optional
        Prewhitening patch length. Default is ``30``.
    apply_gcc : bool, optional
        Apply Geometric Decomposition Coil Compression. If ``True`` then the GCC arguments are used.
        Default is ``False``.
    gcc_virtual_coils : int, optional
        GCC virtual coils. Default is ``10``.
    gcc_calib_lines : int, optional
        GCC calibration lines. Default is ``24``.
    gcc_align_data : bool, optional
        GCC align data. Default is ``True``.
    apply_random_motion : bool, optional
        Simulate random motion in k-space. Default is ``False``.
    random_motion_type : str, optional
        Random motion type. It can be one of the following: ``piecewise_transient``, ``piecewise_constant``,
        ``gaussian``. Default is ``gaussian``.
    random_motion_percentage : Sequence[int], optional
        Random motion percentage. For example, 10%-20% motion can be defined as ``(10, 20)``.
        Default is ``(10, 10)``.
    random_motion_angle : int, optional
        Random motion angle. Default is ``10``.
    random_motion_translation : int, optional
        Random motion translation. Default is ``10``.
    random_motion_center_percentage : float, optional
        Random motion center percentage. Default is ``0.02``.
    random_motion_num_segments : int, optional
        Random motion number of segments to partition the k-space. Default is ``8``.
    random_motion_random_num_segments : bool, optional
        Whether to randomly generate the number of segments. Default is ``True``.
    random_motion_non_uniform : bool, optional
        Random motion non-uniform sampling. Default is ``False``.
    estimate_coil_sensitivity_maps : bool, optional
        Automatically estimate coil sensitivity maps. If ``True`` then the coil sensitivity maps arguments are
        used. Note that this is different from the ``estimate_coil_sensitivity_maps_with_nn`` argument, which
        uses a neural network to estimate the coil sensitivity maps. This option estimates the maps with methods
        such as ``ESPIRiT`` (Eigenvalue to Self-Consistent Parallel Imaging Reconstruction Technique), ``RSS``
        (Root Sum of Squares) or ``UNit`` (uniform coil sensitivity map). Default is ``False``.
    coil_sensitivity_maps_type : str, optional
        Coil sensitivity maps type. It can be one of the following: ``ESPIRiT``, ``RSS`` or ``UNit``. The value
        is lower-cased before being passed to the estimator. Default is ``ESPIRiT``.
    coil_sensitivity_maps_gaussian_sigma : float, optional
        Coil sensitivity maps Gaussian sigma. Default is ``0.0``.
    coil_sensitivity_maps_espirit_threshold : float, optional
        Coil sensitivity maps ESPIRiT threshold. Default is ``0.05``.
    coil_sensitivity_maps_espirit_kernel_size : int, optional
        Coil sensitivity maps ESPIRiT kernel size. Default is ``6``.
    coil_sensitivity_maps_espirit_crop : float, optional
        Coil sensitivity maps ESPIRiT crop. Default is ``0.95``.
    coil_sensitivity_maps_espirit_max_iters : int, optional
        Coil sensitivity maps ESPIRiT max iterations. Default is ``30``.
    coil_combination_method : str, optional
        Coil combination method. Must be ``None`` for non-complex data. Default is ``"SENSE"``.
    dimensionality : int, optional
        Dimensionality. Default is ``2``.
    mask_func : Optional[List["MaskFunc"]], optional
        Mask function to retrospectively undersample the k-space. Default is ``None``.
    shift_mask : bool, optional
        Whether to shift the mask. This needs to be set alongside with the ``fft_centered`` argument.
        Default is ``False``.
    mask_center_scale : Optional[float], optional
        Center scale of the mask. This defines how densely sampled the center of k-space will be.
        Default is ``0.02``.
    partial_fourier_percentage : float, optional
        Percentage of k-space to drop in order to simulate a half scan. Default is ``0.0``.
    remask : bool, optional
        Use the same mask. Default is ``False``.
    ssdu : bool, optional
        Whether to apply Self-Supervised Data Undersampling (SSDU) masks. Default is ``False``.
    ssdu_mask_type : str, optional
        Mask type, either ``"Gaussian"`` (Gaussian sampling) or ``"Uniform"`` (uniform sampling).
        Default is ``"Gaussian"``.
    ssdu_rho : float, optional
        Split ratio for training and loss masks. Default is ``0.4``.
    ssdu_acs_block_size : Sequence[int], optional
        Keeps a small acs region fully-sampled for training masks, if there is no acs region. The small acs
        block should be set to zero. Default is ``(4, 4)``.
    ssdu_gaussian_std_scaling_factor : float, optional
        Scaling factor for standard deviation of the Gaussian noise. Ignored if ``Uniform`` is selected.
        Default is ``4.0``.
    ssdu_outer_kspace_fraction : float, optional
        Fraction of the outer k-space to be kept/unmasked. Default is ``0.0``.
    ssdu_export_and_reuse_masks : bool, optional
        Whether to export and reuse the masks. Default is ``False``.
    n2r : bool, optional
        Whether to apply Noise to Reconstruction (N2R) masks. Default is ``False``.
    n2r_supervised_rate : float, optional
        A float between 0 and 1. This controls what fraction of the subjects should be loaded for Noise to
        Reconstruction (N2R) supervised loss, if N2R is enabled. Default is ``0.0``.
    n2r_probability : float, optional
        Probability of applying N2R. Default is ``0.0``.
    n2r_std_devs : Tuple[float, float], optional
        Standard deviations for the noise. Default is ``None``.
    n2r_rhos : Tuple[float, float], optional
        Rho values for the noise. Default is ``None``.
    n2r_use_mask : bool, optional
        Whether to use a mask for N2R. Default is ``False``.
    unsupervised_masked_target : bool, optional
        Whether to use the masked initial estimation for unsupervised learning. Default is ``False``.
    crop_size : Optional[Tuple[int, int]], optional
        Center crop size. It applies cropping in image space. Default is ``None``.
    kspace_crop : bool, optional
        Whether to crop in k-space. Must be ``False`` for non-complex data. Default is ``False``.
    crop_before_masking : bool, optional
        Whether to crop before masking. Default is ``True``.
    kspace_zero_filling_size : Optional[Tuple], optional
        Target size for zero filling in k-space; ``None`` disables zero filling. Default is ``None``.
    normalize_inputs : bool, optional
        Whether to normalize the inputs. Default is ``True``.
    normalization_type : str, optional
        Normalization type. Can be ``max`` or ``mean`` or ``minmax``. Default is ``max``.
    kspace_normalization : bool, optional
        Whether to normalize the k-space. Default is ``False``.
    fft_centered : bool, optional
        Whether to center the FFT. Default is ``False``.
    fft_normalization : str, optional
        FFT normalization. Default is ``"backward"``.
    spatial_dims : Sequence[int], optional
        Spatial dimensions. If ``None``, ``[-2, -1]`` is used. Default is ``None``.
    coil_dim : int, optional
        Coil dimension. Default is ``0``, meaning that the coil dimension is the first dimension before
        applying batch. Must be ``None`` for non-complex data.
    consecutive_slices : int, optional
        Consecutive slices. Default is ``1``.
    use_seed : bool, optional
        Whether to use seed. Default is ``True``.

    Raises
    ------
    ValueError
        If ``complex_data`` is ``False`` and any k-space specific option is enabled.
    """
    self.complex_data = complex_data
    self.dataset_format = dataset_format

    self.fft_centered = fft_centered
    self.fft_normalization = fft_normalization
    self.spatial_dims = spatial_dims if spatial_dims is not None else [-2, -1]
    # For 2D data the coil axis index is shifted down by one; a ``None`` coil_dim is kept as is.
    self.coil_dim = coil_dim - 1 if dimensionality == 2 and not is_none(coil_dim) else coil_dim

    if not self.complex_data:
        # Magnitude-only data: every k-space specific option has to be switched off.
        if not is_none(coil_combination_method):
            raise ValueError("Coil combination method for non-complex data should be None.")
        if not is_none(mask_func):
            raise ValueError("Mask function for non-complex data should be None.")
        self.kspace_crop = kspace_crop
        if self.kspace_crop:
            raise ValueError("K-space crop for non-complex data should be False.")
        if not is_none(kspace_zero_filling_size):
            raise ValueError("K-space zero filling size for non-complex data should be None.")
        if not is_none(coil_dim):
            raise ValueError("Coil dimension for non-complex data should be None.")
        if apply_prewhitening:
            raise ValueError("Prewhitening for non-complex data cannot be applied.")
        if apply_gcc:
            raise ValueError("GCC for non-complex data cannot be applied.")
        if apply_random_motion:
            raise ValueError("Random motion for non-complex data cannot be applied.")

        # These transforms do not exist for magnitude data, but ``__repr__`` reports them. Define them as
        # ``None`` so that printing the transform does not raise ``AttributeError``.
        self.prewhitening = None
        self.masking = None
        self.ssdu_masking = None
        self.kspace_zero_filling = None
    else:
        self.prewhitening = (
            NoisePreWhitening(
                find_patch_size=find_patch_size,
                patch_size=[
                    prewhitening_patch_start,
                    prewhitening_patch_length + prewhitening_patch_start,
                    prewhitening_patch_start,
                    prewhitening_patch_length + prewhitening_patch_start,
                ],
                scale_factor=prewhitening_scale_factor,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            if apply_prewhitening
            else None
        )

        self.gcc = (
            GeometricDecompositionCoilCompression(
                virtual_coils=gcc_virtual_coils,
                calib_lines=gcc_calib_lines,
                align_data=gcc_align_data,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            if apply_gcc
            else None
        )

        self.random_motion = (
            MotionSimulation(
                motion_type=random_motion_type,
                angle=random_motion_angle,
                translation=random_motion_translation,
                center_percentage=random_motion_center_percentage,
                motion_percentage=random_motion_percentage,
                num_segments=random_motion_num_segments,
                random_num_segments=random_motion_random_num_segments,
                non_uniform=random_motion_non_uniform,
                spatial_dims=self.spatial_dims,
            )
            if apply_random_motion
            else None
        )

        self.coil_sensitivity_maps_estimator = (
            EstimateCoilSensitivityMaps(
                coil_sensitivity_maps_type=coil_sensitivity_maps_type.lower(),
                gaussian_sigma=coil_sensitivity_maps_gaussian_sigma,
                espirit_threshold=coil_sensitivity_maps_espirit_threshold,
                espirit_kernel_size=coil_sensitivity_maps_espirit_kernel_size,
                espirit_crop=coil_sensitivity_maps_espirit_crop,
                espirit_max_iters=coil_sensitivity_maps_espirit_max_iters,
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
                coil_dim=self.coil_dim,
            )
            if estimate_coil_sensitivity_maps
            else None
        )

        self.kspace_zero_filling = (
            ZeroFillingPadding(
                zero_filling_size=kspace_zero_filling_size,  # type: ignore
                fft_centered=self.fft_centered,
                fft_normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            if not is_none(kspace_zero_filling_size)
            else None
        )

        self.shift_mask = shift_mask
        self.masking = Masker(
            mask_func=mask_func,  # type: ignore
            spatial_dims=self.spatial_dims,
            shift_mask=shift_mask,
            partial_fourier_percentage=partial_fourier_percentage,
            center_scale=mask_center_scale,  # type: ignore
            dimensionality=dimensionality,
            remask=remask,
            dataset_format=self.dataset_format,
        )

        self.ssdu = ssdu
        self.ssdu_masking = (
            SSDU(
                mask_type=ssdu_mask_type,
                rho=ssdu_rho,
                acs_block_size=ssdu_acs_block_size,
                gaussian_std_scaling_factor=ssdu_gaussian_std_scaling_factor,
                outer_kspace_fraction=ssdu_outer_kspace_fraction,
                export_and_reuse_masks=ssdu_export_and_reuse_masks,
            )
            if self.ssdu
            else None
        )

        self.n2r = n2r
        self.n2r_supervised_rate = n2r_supervised_rate
        self.n2r_masking = (
            N2R(
                probability=n2r_probability,
                std_devs=n2r_std_devs,  # type: ignore
                rhos=n2r_rhos,  # type: ignore
                use_mask=n2r_use_mask,
            )
            if self.n2r
            else None
        )

        self.unsupervised_masked_target = unsupervised_masked_target

        self.kspace_crop = kspace_crop
        self.crop_before_masking = crop_before_masking

        self.coil_combination_method = coil_combination_method

        # Wrap the (possibly ``None``) transforms in composers so that they can be called unconditionally.
        self.prewhitening = Composer([self.prewhitening])  # type: ignore
        self.coils_shape_transforms = Composer(
            [
                self.gcc,  # type: ignore
                self.kspace_zero_filling,  # type: ignore
            ]
        )
        self.random_motion = Composer([self.random_motion])  # type: ignore

    # Cropping and normalization are shared by complex and magnitude-only data.
    self.cropping = (
        Cropper(
            cropping_size=crop_size,  # type: ignore
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        if not is_none(crop_size)
        else None
    )

    self.normalization_type = normalization_type
    self.normalization = (
        Normalizer(
            normalization_type=self.normalization_type,
            kspace_normalization=kspace_normalization,
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        if normalize_inputs
        else None
    )

    # Built from the raw transforms, before they are wrapped individually below.
    self.crop_normalize = Composer(
        [
            self.cropping,  # type: ignore
            self.normalization,  # type: ignore
        ]
    )

    self.consecutive_slices = consecutive_slices
    self.segmentation_mode = segmentation_mode

    self.cropping = Composer([self.cropping])  # type: ignore
    self.normalization = Composer([self.normalization])  # type: ignore

    self.use_seed = use_seed
def __call__( self, kspace: np.ndarray, imspace: np.ndarray, sensitivity_map: np.ndarray, mask: np.ndarray, initial_prediction_reconstruction: np.ndarray, segmentation_labels: np.ndarray, attrs: Dict, fname: str, slice_idx: int, ) -> Tuple[ Union[torch.Tensor, List[torch.Tensor]], Union[List[torch.Tensor], torch.Tensor], torch.Tensor, Union[List[torch.Tensor], torch.Tensor], Union[List[torch.Tensor], torch.Tensor], torch.tensor, torch.tensor, str, int, Union[List[Union[float, torch.Tensor, Any]]], Dict, ]: """Calls :class:`RSMRIDataTransforms`. Parameters ---------- kspace : np.ndarray The fully-sampled kspace, if exists. Otherwise, the subsampled kspace. imspace : np.ndarray The image space for segmentation, if exists. sensitivity_map : np.ndarray The coil sensitivity map. mask : np.ndarray The subsampling mask, if exists, meaning that the data are either prospectively undersampled or the mask is stored and loaded. initial_prediction_reconstruction : np.ndarray The initial prediction, if exists. Otherwise, it will be estimated with the chosen coil combination method. segmentation_labels : np.ndarray The segmentation labels. attrs : Dict The attributes, if stored in the data. fname : str The file name. slice_idx : int The slice index. 
""" initial_prediction_reconstruction = ( to_tensor(initial_prediction_reconstruction) if initial_prediction_reconstruction is not None and initial_prediction_reconstruction.size != 0 else torch.tensor([]) ) if not self.complex_data: kspace = torch.empty([]) kspace_pre_normalization_vars = None sensitivity_map = torch.empty([]) sensitivity_pre_normalization_vars = None masked_kspace = torch.empty([]) mask = torch.empty([]) acc = torch.empty([]) ( initial_prediction_reconstruction, initial_prediction_pre_normalization_vars, ) = self.__initialize_prediction__(imspace, kspace, sensitivity_map) if "min" in attrs: initial_prediction_pre_normalization_vars["min"] = attrs["min"] if "max" in attrs: initial_prediction_pre_normalization_vars["max"] = attrs["max"] if "mean" in attrs: initial_prediction_pre_normalization_vars["mean"] = attrs["mean"] if "std" in attrs: initial_prediction_pre_normalization_vars["std"] = attrs["std"] noise_prediction_pre_normalization_vars = None target_reconstruction = initial_prediction_reconstruction target_pre_normalization_vars = initial_prediction_pre_normalization_vars else: kspace, masked_kspace, mask, kspace_pre_normalization_vars, acc = self.__process_kspace__( # type: ignore kspace, mask, attrs, fname ) sensitivity_map, sensitivity_pre_normalization_vars = self.__process_coil_sensitivities_map__( sensitivity_map, kspace ) target_reconstruction, target_pre_normalization_vars = self.__initialize_prediction__( torch.empty([]), kspace, sensitivity_map ) target_prediction_pre_normalization_vars = None if self.n2r and len(masked_kspace) > 1: ( initial_prediction_reconstruction, initial_prediction_pre_normalization_vars, ) = self.__initialize_prediction__( initial_prediction_reconstruction, masked_kspace[0], sensitivity_map ) if isinstance(masked_kspace, list) and not masked_kspace[1][0].dim() < 2: noise_prediction, noise_prediction_pre_normalization_vars = self.__initialize_prediction__( None, masked_kspace[1], sensitivity_map ) else: 
noise_prediction = torch.tensor([]) noise_prediction_pre_normalization_vars = None initial_prediction_reconstruction = [initial_prediction_reconstruction, noise_prediction] else: ( initial_prediction_reconstruction, initial_prediction_pre_normalization_vars, ) = self.__initialize_prediction__(initial_prediction_reconstruction, masked_kspace, sensitivity_map) noise_prediction_pre_normalization_vars = None if self.unsupervised_masked_target: target_reconstruction, target_prediction_pre_normalization_vars = ( initial_prediction_reconstruction, noise_prediction_pre_normalization_vars, ) else: target_reconstruction, target_prediction_pre_normalization_vars = self.__initialize_prediction__( None if self.ssdu else target_prediction_pre_normalization_vars, kspace, sensitivity_map ) if not is_none(segmentation_labels) and segmentation_labels.ndim > 1: segmentation_labels = self.cropping(torch.from_numpy(segmentation_labels)) # type: ignore else: segmentation_labels = torch.empty([]) # if segmentation_labels is Bool type, convert to float if segmentation_labels.dtype == torch.bool: segmentation_labels = segmentation_labels.float() segmentation_labels = torch.abs(segmentation_labels) if self.segmentation_mode == "multiclass": # Ensures background class is explicitly added when performing multiclass segmentation -> final total # number of classes should be N + 1 if self.consecutive_slices > 1 and not torch.all(torch.sum(segmentation_labels[0], dim=0) == 1): segmentation_labels_bg = torch.zeros( (segmentation_labels.shape[0], segmentation_labels.shape[2], segmentation_labels.shape[3]) ) segmentation_labels_new = torch.zeros( ( segmentation_labels.shape[0], segmentation_labels.shape[1] + 1, segmentation_labels.shape[2], segmentation_labels.shape[3], ) ) for i in range(target_reconstruction.shape[0]): idx_background = torch.where(torch.sum(segmentation_labels[i], dim=0) == 0) segmentation_labels_bg[i][idx_background] = 1 segmentation_labels_new[i] = torch.concat( 
(segmentation_labels_bg[i].unsqueeze(0), segmentation_labels[i]), dim=0 ) segmentation_labels = segmentation_labels_new elif not torch.all(torch.sum(segmentation_labels, dim=0) == 1): segmentation_labels_bg = torch.zeros((segmentation_labels.shape[-2], segmentation_labels.shape[-1])) idx_background = torch.where(torch.sum(segmentation_labels, dim=0) == 0) segmentation_labels_bg[idx_background] = 1 segmentation_labels = torch.concat((segmentation_labels_bg.unsqueeze(0), segmentation_labels), dim=0) attrs.update( self.__parse_normalization_vars__( kspace_pre_normalization_vars, sensitivity_pre_normalization_vars, initial_prediction_pre_normalization_vars, noise_prediction_pre_normalization_vars, target_pre_normalization_vars, ) ) attrs["fname"] = fname attrs["slice_idx"] = slice_idx return ( kspace, masked_kspace, sensitivity_map, mask, initial_prediction_reconstruction, target_reconstruction, segmentation_labels, fname, slice_idx, acc, attrs, ) def __repr__(self) -> str: """Representation of :class:`RSMRIDataTransforms`.""" return ( f"Preprocessing transforms initialized for {self.__class__.__name__}: " f"prewhitening = {self.prewhitening}, " f"masking = {self.masking}, " f"SSDU masking = {self.ssdu_masking}, " f"kspace zero-filling = {self.kspace_zero_filling}, " f"cropping = {self.cropping}, " f"normalization = {self.normalization}, " ) def __str__(self) -> str: """String representation of :class:`RSMRIDataTransforms`.""" return self.__repr__() def __process_kspace__( # noqa: MC0001 self, kspace: np.ndarray, mask: Union[np.ndarray, None], attrs: Dict, fname: str ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor], Union[List[torch.Tensor], torch.Tensor], int]: """Apply the preprocessing transforms to the kspace. Parameters ---------- kspace : torch.Tensor The kspace. mask : torch.Tensor The mask, if None, the mask is generated. attrs : Dict The attributes, if stored in the file. fname : str The file name. 
Returns ------- Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor], Union[List[torch.Tensor], torch.Tensor], int] The transformed (fully-sampled) kspace, the masked kspace, the mask, the attributes and the acceleration factor. """ kspace = to_tensor(kspace) kspace = add_coil_dim_if_singlecoil(kspace, dim=self.coil_dim) kspace = self.coils_shape_transforms(kspace, apply_backward_transform=True) kspace = self.prewhitening(kspace) # type: ignore if self.crop_before_masking: kspace = self.cropping(kspace, apply_backward_transform=not self.kspace_crop) # type: ignore masked_kspace, mask, acc = self.masking( self.random_motion(kspace), # type: ignore mask, ( attrs["padding_left"] if "padding_left" in attrs else 0, attrs["padding_right"] if "padding_right" in attrs else 0, ), tuple(map(ord, fname)) if self.use_seed else None, # type: ignore ) if not self.crop_before_masking: kspace = self.cropping(kspace, apply_backward_transform=not self.kspace_crop) # type: ignore masked_kspace = self.cropping(masked_kspace, apply_backward_transform=not self.kspace_crop) # type: ignore mask = self.cropping(mask) # type: ignore init_kspace = kspace init_masked_kspace = masked_kspace init_mask = mask if isinstance(kspace, list): kspaces = [] pre_normalization_vars = [] for i in range(len(kspace)): # pylint: disable=consider-using-enumerate if not is_none(self.normalization.__repr__()): _kspace, _pre_normalization_vars = self.normalization( # type: ignore kspace[i], apply_backward_transform=True ) else: _kspace = kspace[i] is_complex = _kspace.shape[-1] == 2 if is_complex: _kspace = torch.view_as_complex(_kspace) _pre_normalization_vars = { "min": torch.min(torch.abs(_kspace)), "max": torch.max(torch.abs(_kspace)), "mean": torch.mean(torch.abs(_kspace)), "std": torch.std(torch.abs(_kspace)), "var": torch.var(torch.abs(_kspace)), } if is_complex: _kspace = torch.view_as_real(_kspace) kspaces.append(_kspace) pre_normalization_vars.append(_pre_normalization_vars) kspace = kspaces 
else: if not is_none(self.normalization.__repr__()): kspace, pre_normalization_vars = self.normalization( # type: ignore kspace, apply_backward_transform=True ) else: is_complex = kspace.shape[-1] == 2 if is_complex: kspace = torch.view_as_complex(kspace) pre_normalization_vars = { # type: ignore "min": torch.min(torch.abs(kspace)), "max": torch.max(torch.abs(kspace)), "mean": torch.mean(torch.abs(kspace)), "std": torch.std(torch.abs(kspace)), "var": torch.var(torch.abs(kspace)), } if is_complex: kspace = torch.view_as_real(kspace) if isinstance(masked_kspace, list): masked_kspaces = [] masked_pre_normalization_vars = [] for i in range(len(masked_kspace)): # pylint: disable=consider-using-enumerate if not is_none(self.normalization.__repr__()): _masked_kspace, _masked_pre_normalization_vars = self.normalization( # type: ignore masked_kspace[i], apply_backward_transform=True ) else: _masked_kspace = masked_kspace[i] is_complex = _masked_kspace.shape[-1] == 2 if is_complex: _masked_kspace = torch.view_as_complex(_masked_kspace) _masked_pre_normalization_vars = { "min": torch.min(torch.abs(_masked_kspace)), "max": torch.max(torch.abs(_masked_kspace)), "mean": torch.mean(torch.abs(_masked_kspace)), "std": torch.std(torch.abs(_masked_kspace)), "var": torch.var(torch.abs(_masked_kspace)), } if is_complex: _masked_kspace = torch.view_as_real(_masked_kspace) masked_kspaces.append(_masked_kspace) masked_pre_normalization_vars.append(_masked_pre_normalization_vars) masked_kspace = masked_kspaces else: if not is_none(self.normalization.__repr__()): masked_kspace, masked_pre_normalization_vars = self.normalization( masked_kspace, apply_backward_transform=True ) else: is_complex = masked_kspace.shape[-1] == 2 if is_complex: masked_kspace = torch.view_as_complex(masked_kspace) masked_pre_normalization_vars = { "min": torch.min(torch.abs(masked_kspace)), "max": torch.max(torch.abs(masked_kspace)), "mean": torch.mean(torch.abs(masked_kspace)), "std": 
torch.std(torch.abs(masked_kspace)), "var": torch.var(torch.abs(masked_kspace)), } if is_complex: masked_kspace = torch.view_as_real(masked_kspace) if self.ssdu: kspace, masked_kspace, mask = self.__self_supervised_data_undersampling__( # type: ignore kspace, masked_kspace, mask, fname ) n2r_pre_normalization_vars = None if self.n2r and (not attrs["n2r_supervised"] or self.ssdu): n2r_masked_kspace, n2r_mask = self.__noise_to_reconstruction__(init_kspace, init_masked_kspace, init_mask) if self.ssdu: if isinstance(mask, list): for i in range(len(mask)): # pylint: disable=consider-using-enumerate if init_mask[i].dim() != mask[i][0].dim(): # type: ignore # find dimensions == 1 in mask[i][0] and add them to init_mask unitary_dims = [j for j in range(mask[i][0].dim()) if mask[i][0].shape[j] == 1] # unsqueeze init_mask to the index of the unitary dimensions for j in unitary_dims: init_mask[i] = init_mask[i].unsqueeze(j) # type: ignore masked_kspace[i] = init_masked_kspace[i] mask[i][0] = init_mask[i] else: if init_mask.dim() != mask[0].dim(): # type: ignore # find dimensions == 1 in mask[0] and add them to init_mask unitary_dims = [j for j in range(mask[0].dim()) if mask[0].shape[j] == 1] # unsqueeze init_mask to the index of the unitary dimensions for j in unitary_dims: init_mask = init_mask.unsqueeze(j) # type: ignore masked_kspace = init_masked_kspace mask[0] = init_mask if "None" not in self.normalization.__repr__(): if isinstance(masked_kspace, list): masked_kspaces = [] masked_pre_normalization_vars = [] for i in range(len(masked_kspace)): # pylint: disable=consider-using-enumerate _masked_kspace, _masked_pre_normalization_vars = self.normalization( # type: ignore masked_kspace[i], apply_backward_transform=True ) masked_kspaces.append(_masked_kspace) masked_pre_normalization_vars.append(_masked_pre_normalization_vars) masked_kspace = masked_kspaces else: masked_kspace, masked_pre_normalization_vars = self.normalization( # type: ignore masked_kspace, 
apply_backward_transform=True ) if isinstance(n2r_masked_kspace, list): n2r_masked_kspaces = [] n2r_pre_normalization_vars = [] for i in range(len(n2r_masked_kspace)): # pylint: disable=consider-using-enumerate _n2r_masked_kspace, _n2r_pre_normalization_vars = self.normalization( # type: ignore n2r_masked_kspace[i], apply_backward_transform=True ) n2r_masked_kspaces.append(_n2r_masked_kspace) n2r_pre_normalization_vars.append(_n2r_pre_normalization_vars) n2r_masked_kspace = n2r_masked_kspaces else: n2r_masked_kspace, n2r_pre_normalization_vars = self.normalization( # type: ignore n2r_masked_kspace, apply_backward_transform=True ) else: masked_pre_normalization_vars = None # type: ignore n2r_pre_normalization_vars = None # type: ignore masked_kspace = [masked_kspace, n2r_masked_kspace] mask = [mask, n2r_mask] if self.normalization_type == "grayscale": if isinstance(mask, list): masks = [] for i in range(len(mask)): # pylint: disable=consider-using-enumerate _mask, _ = self.normalization(mask[i], apply_backward_transform=False) # type: ignore masks.append(_mask) mask = masks else: mask, _ = self.normalization(mask, apply_backward_transform=False) # type: ignore pre_normalization_vars = { # type: ignore "kspace_pre_normalization_vars": pre_normalization_vars, "masked_kspace_pre_normalization_vars": masked_pre_normalization_vars, "noise_masked_kspace_pre_normalization_vars": n2r_pre_normalization_vars, } return kspace, masked_kspace, mask, pre_normalization_vars, acc # type: ignore def __noise_to_reconstruction__( self, kspace: torch.Tensor, masked_kspace: torch.Tensor, mask: Union[List, torch.Tensor], ) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]: """Apply the noise-to-reconstruction transform. Parameters ---------- kspace : torch.Tensor The fully-sampled kspace. masked_kspace : torch.Tensor The undersampled kspace. mask : Union[List, torch.Tensor] The undersampling mask. 
def __self_supervised_data_undersampling__(  # noqa: MC0001
    self,
    kspace: torch.Tensor,
    masked_kspace: Union[List, torch.Tensor],
    mask: Union[List, torch.Tensor],
    fname: str,
) -> Tuple[
    Union[List[torch.Tensor], torch.Tensor],
    Union[List[torch.Tensor], torch.Tensor],
    Union[List[List[torch.Tensor]], List[torch.Tensor]],
]:
    """Self-supervised data undersampling.

    Every undersampling mask is split by ``self.ssdu_masking`` into a train mask and a loss mask. The loss mask is
    applied to the fully-sampled kspace and the train mask to the undersampled kspace.

    Parameters
    ----------
    kspace : torch.Tensor
        The fully-sampled kspace.
    masked_kspace : Union[List, torch.Tensor]
        The undersampled kspace. Indexed per mask when ``mask`` is a list.
    mask : Union[List, torch.Tensor]
        The undersampling mask, or a list of undersampling masks. The list passed by the caller is not modified.
    fname : str
        The filename of the current sample, forwarded to ``self.ssdu_masking``.

    Returns
    -------
    kspace : Union[List[torch.Tensor], torch.Tensor]
        The kspace with the loss mask applied.
    masked_kspace : Union[List[torch.Tensor], torch.Tensor]
        The kspace with the train mask applied.
    mask : list, [torch.Tensor, torch.Tensor]
        The train and loss masks. A list of such pairs when ``mask`` is a list.
    """

    def split_masks(sampling_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the train and loss masks of one sampling mask, shaped to broadcast against ``kspace``."""
        is_1d = sampling_mask.squeeze().dim() == 1
        if self.shift_mask:
            sampling_mask = torch.fft.fftshift(sampling_mask.squeeze(-1), dim=(-2, -1)).unsqueeze(-1)
        sampling_mask = sampling_mask.squeeze()
        if is_1d:
            # Expand the 1D mask to 2D by repeating it kspace.shape[1] times along a new leading axis.
            sampling_mask = sampling_mask.unsqueeze(0).repeat_interleave(kspace.shape[1], dim=0)

        train_mask, loss_mask = self.ssdu_masking(  # type: ignore  # pylint: disable=not-callable
            kspace, sampling_mask, fname
        )

        if self.shift_mask:
            train_mask = torch.fft.fftshift(train_mask, dim=(0, 1))
            loss_mask = torch.fft.fftshift(loss_mask, dim=(0, 1))

        if is_1d:
            train_mask = train_mask.unsqueeze(0).unsqueeze(-1)
            loss_mask = loss_mask.unsqueeze(0).unsqueeze(-1)
        else:
            # find unitary dims in mask and unsqueeze to broadcast
            for axis in [axis for axis, size in enumerate(sampling_mask.shape) if size == 1]:
                train_mask = train_mask.unsqueeze(axis)
                loss_mask = loss_mask.unsqueeze(axis)

        if train_mask.dim() != kspace.dim():
            # find kspace axes whose size matches no train_mask axis and unsqueeze there to broadcast
            missing_axes = [axis for axis, size in enumerate(kspace.shape) if size not in train_mask.shape]
            for axis in missing_axes:
                train_mask = train_mask.unsqueeze(axis)
                loss_mask = loss_mask.unsqueeze(axis)

        return train_mask, loss_mask

    if not isinstance(mask, list):
        train_mask, loss_mask = split_masks(mask)
        return kspace * loss_mask + 0.0, masked_kspace * train_mask + 0.0, [train_mask, loss_mask]

    kspaces = []
    masked_kspaces = []
    masks = []
    for idx, single_mask in enumerate(mask):
        train_mask, loss_mask = split_masks(single_mask)
        kspaces.append(kspace * loss_mask + 0.0)
        masked_kspaces.append(masked_kspace[idx] * train_mask + 0.0)
        masks.append([train_mask, loss_mask])
    return kspaces, masked_kspaces, masks
def __process_coil_sensitivities_map__(
    self, sensitivity_map: np.ndarray, kspace: torch.Tensor
) -> Tuple[torch.Tensor, Dict]:
    """Preprocesses the coil sensitivities map.

    Parameters
    ----------
    sensitivity_map : np.ndarray
        The coil sensitivities map. May be None or empty, in which case a map of ones is used unless an estimator
        is configured.
    kspace : torch.Tensor
        The kspace. When it is a list, its first element defines the shape of the fallback map of ones.

    Returns
    -------
    Tuple[torch.Tensor, Dict]
        The preprocessed coil sensitivities map and the normalization variables.
    """
    # This condition is necessary in case of auto estimation of sense maps.
    if self.coil_sensitivity_maps_estimator is not None:
        sensitivity_map = self.coil_sensitivity_maps_estimator(kspace)
    elif sensitivity_map is not None and sensitivity_map.size != 0:
        sensitivity_map = to_tensor(sensitivity_map)
        sensitivity_map = self.coils_shape_transforms(sensitivity_map, apply_forward_transform=True)
        sensitivity_map = self.cropping(sensitivity_map, apply_forward_transform=self.kspace_crop)  # type: ignore
    else:
        # If no sensitivity map is provided, either the data is singlecoil or the sense net is used.
        # Initialize the sensitivity map to 1 to assure for the singlecoil case.
        sensitivity_map = torch.ones_like(kspace[0] if isinstance(kspace, list) else kspace)

    if not is_none(repr(self.normalization)):
        sensitivity_map, pre_normalization_vars = self.normalization(  # type: ignore
            sensitivity_map, apply_forward_transform=self.kspace_crop
        )
        return sensitivity_map, pre_normalization_vars

    # No normalization configured: only report magnitude statistics, the map itself is returned untouched.
    # A trailing axis of size 2 is treated as stacked real/imaginary parts.
    has_real_imag_axis = sensitivity_map.shape[-1] == 2
    magnitude = torch.abs(torch.view_as_complex(sensitivity_map) if has_real_imag_axis else sensitivity_map)
    pre_normalization_vars = {
        "min": torch.min(magnitude),
        "max": torch.max(magnitude),
        "mean": torch.mean(magnitude),
        "std": torch.std(magnitude),
        "var": torch.var(magnitude),
    }
    return sensitivity_map, pre_normalization_vars
def __initialize_prediction__(
    self, prediction: Union[np.ndarray, None], kspace: torch.Tensor, sensitivity_map: torch.Tensor
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Union[List[Dict], Dict]]:
    """Predicts a coil-combined image.

    Parameters
    ----------
    prediction : np.ndarray
        The initial estimation, if None, the prediction is initialized.
    kspace : torch.Tensor
        The kspace. When it is a list, one prediction is computed per element and ``prediction`` is ignored.
    sensitivity_map : torch.Tensor
        The sensitivity map.

    Returns
    -------
    Tuple[Union[List[torch.Tensor], torch.Tensor], Union[List[Dict], Dict]]
        The initialized prediction, either a list of coil-combined images or a single coil-combined image and the
        pre-normalization variables (min, max, mean, std, var). An empty kspace list yields two empty lists.
    """

    def coil_combine(y: torch.Tensor) -> torch.Tensor:
        """Transform a kspace to image space and combine its coils."""
        return coil_combination_method_func(
            ifft2(y, self.fft_centered, self.fft_normalization, self.spatial_dims),
            sensitivity_map,
            method=self.coil_combination_method,
            dim=self.coil_dim,
        )

    def crop_and_normalize(image: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Crop an image, then normalize it or, without normalization, collect its magnitude statistics."""
        image = self.cropping(image, apply_forward_transform=self.kspace_crop)  # type: ignore
        if not is_none(repr(self.normalization)):
            image, stats = self.normalization(  # type: ignore
                image, apply_forward_transform=self.kspace_crop
            )
            return image, stats
        if image.shape[-1] == 2:
            image = torch.view_as_complex(image)
        magnitude = torch.abs(image)
        stats = {
            "min": torch.min(magnitude),
            "max": torch.max(magnitude),
            "mean": torch.mean(magnitude),
            "std": torch.std(magnitude),
            "var": torch.var(magnitude),
        }
        return image, stats

    if isinstance(kspace, list):
        predictions = []
        pre_normalization_vars = []
        for y in kspace:
            pred, stats = crop_and_normalize(coil_combine(y))
            predictions.append(pred)
            pre_normalization_vars.append(stats)
        # The first prediction decides whether all predictions are viewed as real; guard against an empty list.
        if predictions and predictions[0].shape[-1] != 2 and torch.is_complex(predictions[0]):
            predictions = [torch.view_as_real(x) for x in predictions]
        return predictions, pre_normalization_vars

    if is_none(prediction) or prediction.ndim < 2:  # type: ignore
        prediction = coil_combine(kspace)
    elif isinstance(prediction, np.ndarray):
        prediction = to_tensor(prediction)

    prediction, pre_normalization_vars = crop_and_normalize(prediction)
    if prediction.shape[-1] != 2 and torch.is_complex(prediction):
        prediction = torch.view_as_real(prediction)
    return prediction, pre_normalization_vars  # type: ignore
def __parse_normalization_vars__(  # noqa: MC0001
    self, kspace_vars, sensitivity_vars, prediction_vars, noise_prediction_vars, target_vars
) -> Dict:
    """
    Parses the normalization variables and returns a unified dictionary.

    Every statistics dictionary contributes the keys ``{prefix}_min``, ``{prefix}_max``, ``{prefix}_mean``,
    ``{prefix}_std`` and ``{prefix}_var``. A list of statistics dictionaries contributes the same keys suffixed with
    ``_{index}``. ``None``, an empty list, or a list whose first element is ``None`` contributes nothing.

    Parameters
    ----------
    kspace_vars : Dict
        The kspace normalization variables, holding the entries ``masked_kspace_pre_normalization_vars``,
        ``noise_masked_kspace_pre_normalization_vars`` and ``kspace_pre_normalization_vars``. Only read when
        ``self.complex_data`` is True.
    sensitivity_vars : Dict
        The sensitivity map normalization variables.
    prediction_vars : Dict
        The prediction normalization variables.
    noise_prediction_vars : Union[Dict, None]
        The noise prediction normalization variables.
    target_vars : Dict
        The target normalization variables.

    Returns
    -------
    Dict
        The normalization variables.
    """
    stat_names = ("min", "max", "mean", "std", "var")
    normalization_vars = {}

    def store(prefix, stats):
        """Copy one statistics dictionary, or a list of them, into ``normalization_vars``."""
        if isinstance(stats, list):
            if not stats or stats[0] is None:
                return
            for idx, entry in enumerate(stats):
                for stat in stat_names:
                    normalization_vars[f"{prefix}_{stat}_{idx}"] = entry[stat]
        elif stats is not None:
            for stat in stat_names:
                normalization_vars[f"{prefix}_{stat}"] = stats[stat]

    if self.complex_data:
        store("masked_kspace", kspace_vars["masked_kspace_pre_normalization_vars"])
        store("noise_masked_kspace", kspace_vars["noise_masked_kspace_pre_normalization_vars"])
        store("kspace", kspace_vars["kspace_pre_normalization_vars"])

    # NOTE(review): only the kspace entries are gated on complex_data here; the remaining inputs are None-safe.
    # Confirm against the callers that non-complex data passes None for the three inputs below.
    store("sensitivity_maps", sensitivity_vars)
    store("prediction", prediction_vars)
    store("noise_prediction", noise_prediction_vars)

    store("target", target_vars)

    return normalization_vars