Source code for atommic.collections.reconstruction.losses.na

# coding=utf-8
__author__ = "Dimitris Karkalousos"

import math

import torch
import torch.nn.functional as F

from atommic.core.classes.loss import Loss


class NoiseAwareLoss(Loss):
    """Computes the Noise Aware loss between two tensors.

    For every element the loss is::

        v    = sigma**2 / (mask + eps)
        loss = (target - pred)**2 / (2 * v) + log(2 * v * sqrt(2 * pi))

    and the returned value is the mean over all elements. When ``sigma`` is not given (``None`` or ``<= 0``), it is
    estimated from the residuals as ``median(|target - pred|) / 0.6745``, which is consistent for zero-mean Gaussian
    residuals (the median of ``|N(0, s**2)|`` is ``0.6745 * s``).

    NOTE(review): the log term ``log(2 * v * sqrt(2 * pi))`` differs from the textbook Gaussian negative
    log-likelihood term ``0.5 * log(2 * pi * v)``. It weights ``log v`` twice as heavily and adds ``log 2``.

    .. note::
        Extends :class:`atommic.core.classes.loss.Loss`.

    Examples
    --------
    >>> from atommic.collections.reconstruction.losses.na import NoiseAwareLoss
    >>> import torch
    >>> loss = NoiseAwareLoss()
    >>> loss(target=torch.rand(1, 1, 256, 256), pred=torch.rand(1, 1, 256, 256))  # doctest: +SKIP
    """

    def forward(
        self, target: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor = None, sigma: float = 0.0
    ) -> torch.Tensor:
        """Forward pass of :class:`NoiseAwareLoss`.

        Parameters
        ----------
        target : torch.Tensor
            The target tensor.
        pred : torch.Tensor
            The predicted tensor. It is cast to ``target.dtype`` before use.
        mask : torch.Tensor, optional
            Weighting mask, broadcastable with ``target``. Elements where the mask is 1 use the noise variance
            ``sigma**2``. Elements where it is 0 get a variance inflated by ``1 / eps``, so their data term is
            effectively ignored. If None, all elements are considered (a mask of ones is used).
        sigma : float, optional
            The noise standard deviation. If None or ``<= 0`` (the default), it is estimated from the residuals as
            ``median(|target - pred|) / 0.6745``.

        Returns
        -------
        torch.Tensor
            0-dim tensor holding the mean noise aware loss.
        """
        # Small constant guarding the mask denominator and the variance floor.
        eps = 1e-8

        pred = pred.to(target.dtype)
        if mask is None:
            mask = torch.ones_like(target)
        mask = mask.to(target.dtype)

        # Element-wise squared error.
        mse = F.mse_loss(target, pred, reduction="none")

        # Honour an explicitly supplied noise level; otherwise use a robust estimate from the residuals.
        if sigma is not None and sigma > 0:
            noise_std = torch.as_tensor(sigma, dtype=target.dtype, device=target.device)
        else:
            noise_std = torch.median(torch.abs(target - pred)) / 0.6745

        # Floor the variance. Otherwise a perfect (or mostly perfect) prediction, whose median residual is 0,
        # would produce 0 / 0 = NaN in the data term and log(0) = -inf in the log term.
        variance = torch.clamp(noise_std**2, min=eps)

        # Masked-in elements (mask == 1) keep the variance sigma**2, while masked-out elements get a huge variance.
        # Dividing by (1 - mask + eps) would instead inflate the variance of every masked-in element, i.e. of all
        # elements when mask is None.
        noise_var = variance / (mask + eps)

        loss = mse / (2 * noise_var) + torch.log(2 * noise_var * math.sqrt(2 * math.pi))
        return loss.mean()