Source code for atommic.collections.reconstruction.losses.ssim
# coding=utf-8
__author__ = "Dimitris Karkalousos"
# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI
import torch
import torch.nn.functional as F
from atommic.core.classes.loss import Loss
[docs]class SSIMLoss(Loss):
"""Computes the (1-) SSIM loss between two tensors.
Examples
--------
>>> from atommic.collections.reconstruction.losses.ssim import SSIMLoss
>>> import torch
>>> loss = SSIMLoss(win_size=7, k1=0.01, k2=0.03)
>>> loss(X=torch.rand(1, 1, 256, 256), Y=torch.rand(1, 1, 256, 256), data_range=torch.tensor([1.]))
tensor(0.9872)
"""
[docs] def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
"""Inits :class:`SSIMLoss`.
Parameters
----------
win_size : int, optional
Window size for SSIM calculation.
k1 : float, optional
k1 parameter for SSIM calculation.
k2 : float, optional
k2 parameter for SSIM calculation.
"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
NP = win_size**2
self.cov_norm = NP / (NP - 1)
[docs] def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor = None):
"""Forward pass of :class:`SSIMLoss`.
Parameters
----------
X : torch.Tensor
First input tensor.
Y : torch.Tensor
Second input tensor.
data_range : torch.Tensor
Data range of the input tensors. If ``None``, it is computed as the maximum range of the input tensors.
Default is ``None``.
"""
if not isinstance(self.w, torch.Tensor): # type: ignore # pylint: disable=access-member-before-definition
raise AssertionError
# This is necessary to first assign self.w to CUDA and then in case of fp32 to avoid RuntimeError: Inference
# tensors cannot be saved for backward.
self.w = self.w.to(Y).clone() # type: ignore
if data_range is None:
data_range = torch.tensor([max(X.max() - X.min(), Y.max() - Y.min())]).to(Y)
if isinstance(data_range, int):
data_range = torch.tensor([data_range]).to(Y)
data_range = data_range[:, None, None, None]
C1 = (self.k1 * data_range) ** 2
C2 = (self.k2 * data_range) ** 2
ux = F.conv2d(X, self.w)
uy = F.conv2d(Y, self.w)
uxx = F.conv2d(X * X, self.w)
uyy = F.conv2d(Y * Y, self.w)
uxy = F.conv2d(X * Y, self.w)
vx = self.cov_norm * (uxx - ux * ux)
vy = self.cov_norm * (uyy - uy * uy)
vxy = self.cov_norm * (uxy - ux * uy)
A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux**2 + uy**2 + C1, vx + vy + C2)
D = B1 * B2
S = (A1 * A2) / D
return 1 - S.mean()