Source code for atommic.collections.segmentation.losses.focal

# coding=utf-8
__author__ = "Tim Paquaij, Dimitris Karkalousos"

# Taken and adapted from:
# https://github.com/Project-MONAI/MONAI/blob/46a5272196a6c2590ca2589029eed8e4d56ff008/monai/losses/focal_loss.py

import warnings
from typing import Any, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor

from atommic.collections.segmentation.losses.utils import one_hot
from atommic.core.classes.loss import Loss


class FocalLoss(Loss):
    """Focal loss for segmentation, down-weighting well-classified elements.

    FocalLoss is an extension of BCEWithLogitsLoss that down-weights the loss from high confidence correct
    predictions.

    Reimplementation of the Focal Loss described in:

    - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
    - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
      Zhu et al., Medical Physics 2018
    """

    def __init__(
        self,
        include_background: bool = True,
        gamma: float = 2.0,
        alpha: Optional[float] = None,
        weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None,
        reduction: str = "mean",
        use_softmax: bool = False,
        to_onehot_y: bool = False,
        num_segmentation_classes: Optional[int] = None,
    ) -> None:
        """Inits :class:`FocalLoss`

        Parameters
        ----------
        include_background : bool
            Whether to include the computation on the first channel of the predicted output. Default is ``True``.
        gamma : float, optional
            Value of the exponent gamma in the definition of the Focal loss. Default is ``2.0``.
        alpha : float, optional
            Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is ``None``.
        weight : Sequence[float] or float or int or torch.Tensor, optional
            Weight for each class. A scalar is applied to every class; a sequence must have one entry per class
            (excluding the background class when ``include_background`` is ``False``). Default is ``None``.
        reduction : str
            Specifies the reduction to apply:
            ``none``: no reduction will be applied.
            ``mean``: NaN-aware average, taken axis by axis from the last spatial dimension down to the channel
            dimension and finally the batch dimension.
            ``sum``: sum over all elements.
            Default is ``mean``.
        use_softmax : bool, optional
            Option to compute the focal loss as a categorical cross-entropy. Default is ``False``.
        to_onehot_y : bool
            Whether to convert `y` into the one-hot format. Default is ``False``.
        num_segmentation_classes : int, optional
            Total number of segmentation classes. Used by :meth:`forward` to decide whether a 3D tensor is missing
            its batch or its channel dimension. Default is ``None``.
        """
        super().__init__()
        self.include_background = include_background
        self.gamma = gamma
        self.alpha = alpha
        self.weight = torch.as_tensor(weight) if weight is not None else None
        # Registered as a buffer so it follows the module across devices; forward() never overwrites it.
        self.register_buffer("class_weight", self.weight)
        self.reduction = reduction
        self.use_softmax = use_softmax
        self.to_onehot_y = to_onehot_y
        self.num_segmentation_classes = num_segmentation_classes

    @staticmethod
    def _nan_aware_mean(loss: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
        """Average ``loss`` axis by axis (last axis first), skipping entries flagged as invalid.

        Parameters
        ----------
        loss : torch.Tensor
            Loss values in which invalid (NaN) entries have already been replaced by zero.
        valid : torch.Tensor
            Tensor of the same shape as ``loss`` holding 1 for valid entries and 0 for invalid ones.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(1,)`` holding the averaged loss.
        """
        t_zero = torch.zeros(1, device=loss.device, dtype=loss.dtype)
        for dim in range(loss.dim() - 1, -1, -1):
            count = valid.sum(dim=dim)
            # clamp keeps the discarded branch finite, so no NaN/inf leaks into the gradient through torch.where
            loss = torch.where(count > 0, loss.sum(dim=dim) / count.clamp(min=1), t_zero)
            # Carry forward an indicator, not the raw count: the next axis must be divided by the number of
            # valid slices, otherwise the result shrinks with the size of every axis already reduced.
            valid = (count > 0).float()
        return loss

    def forward(  # noqa: MC0001
        self, target: torch.Tensor, _input: torch.Tensor
    ) -> Tuple[Union[Tensor, Any], Tensor]:
        """Forward pass of :class:`FocalLoss`.

        Note that the target is the first argument and the prediction the second.

        Parameters
        ----------
        target : torch.Tensor
            Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
        _input : torch.Tensor
            Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(focal_score, loss)``: the element-wise focal loss (class weights applied, NaNs kept, no reduction)
            and the loss after applying ``reduction``.

        Raises
        ------
        ValueError
            If target and prediction shapes differ after preprocessing, if the number of class weights does not
            match the number of classes, or if any class weight is negative.
        """
        unsqueezed = False
        if _input.dim() == 3:
            # if _input.shape[-3] == self.num_segmentation_classes then we need dummy batch dim, else dummy
            # channel dim
            if _input.shape[-3] == self.num_segmentation_classes:
                _input = _input.unsqueeze(0)
            else:
                _input = _input.unsqueeze(1)
        if target.dim() == 3:
            # if target.shape[-3] == self.num_segmentation_classes then we need dummy batch dim, else dummy
            # channel dim
            if target.shape[-3] == self.num_segmentation_classes:
                target = target.unsqueeze(0)
            else:
                target = target.unsqueeze(1)
            unsqueezed = True

        n_pred_ch = _input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y = True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background = False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                _input = _input[:, 1:]

        if target.shape != _input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({_input.shape})")

        _input = _input.float()
        target = target.float()

        # Work on a local copy so the configured hyper-parameter is not rewritten by a forward pass.
        alpha = self.alpha
        if self.use_softmax:
            if not self.include_background and alpha is not None:
                alpha = None
                warnings.warn("`include_background = False`, `alpha` ignored when using softmax.")
            loss = softmax_focal_loss(_input, target, self.gamma, alpha)
        else:
            loss = sigmoid_focal_loss(_input, target, self.gamma, alpha)

        num_of_classes = target.shape[1]
        if self.class_weight is not None and num_of_classes != 1:
            # Local tensor: the registered buffer keeps its original shape, so state_dict stays stable.
            class_weight = self.class_weight
            # make sure the lengths of weights are equal to the number of classes
            if class_weight.ndim == 0:
                class_weight = class_weight.repeat(num_of_classes)
            elif class_weight.shape[0] != num_of_classes:
                raise ValueError(
                    "the length of the `weight` sequence should be the same as the number of classes. "
                    "If `include_background = False`, the weight should not include the background category "
                    "class 0."
                )
            if class_weight.min() < 0:
                raise ValueError("the value/values of the `weight` should be no less than 0.")
            # apply class_weight to loss, broadcasting over the batch and spatial dimensions
            broadcast_dims = [-1] + [1] * len(target.shape[2:])
            loss = class_weight.to(loss).reshape(broadcast_dims) * loss

        focal_score = loss.clone()

        # some elements might be Nan (if ground truth y was missing (zeros)), we need to account for it
        nans = torch.isnan(loss)
        not_nans = (~nans).float()
        loss[nans] = 0

        if self.reduction == "mean":
            loss = self._nan_aware_mean(loss, not_nans)
        elif self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction == "none" and unsqueezed:
            # NOTE(review): squeeze(1) only undoes a dummy *channel* dimension; a dummy batch dimension added
            # above is left in place. Kept as-is because callers may rely on the current output shape.
            loss = loss.squeeze(1)
            focal_score = focal_score.squeeze(1)

        return focal_score, loss
def softmax_focal_loss(
    _input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None  # noqa: MC0001
) -> torch.Tensor:
    """Element-wise focal loss computed from softmax probabilities over the class dimension.

    Parameters
    ----------
    _input : torch.Tensor
        Prediction tensor (logits). Shape: (batch_size, num_classes, *spatial_dims)
    target : torch.Tensor
        Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
    gamma : float
        Value of the exponent gamma in the definition of the Focal loss. Default is ``2``.
    alpha : float, optional
        Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is ``None``.

    Returns
    -------
    torch.Tensor
        Unreduced focal loss, same shape as ``_input``.
    """
    log_probs = _input.log_softmax(1)
    # (1 - p) ** gamma shrinks the contribution of classes that are already predicted with high probability
    modulating_factor = (1 - log_probs.exp()).pow(gamma)
    focal: torch.Tensor = -modulating_factor * log_probs * target

    if alpha is None:
        return focal

    # (1-alpha) for the background class (channel 0) and alpha for every other class
    num_foreground = target.shape[1] - 1
    per_class_alpha = torch.tensor([1 - alpha] + [alpha] * num_foreground).to(focal)
    # reshape to (num_classes, 1, ..., 1) so it broadcasts over the batch and spatial dimensions
    per_class_alpha = per_class_alpha.view([-1] + [1] * len(target.shape[2:]))
    return per_class_alpha * focal


def sigmoid_focal_loss(
    _input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None  # noqa: MC0001
) -> torch.Tensor:
    """Element-wise focal loss computed from independent per-channel sigmoid probabilities.

    Parameters
    ----------
    _input : torch.Tensor
        Prediction tensor (logits). Shape: (batch_size, num_classes, *spatial_dims)
    target : torch.Tensor
        Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
    gamma : float
        Value of the exponent gamma in the definition of the Focal loss. Default is ``2``.
    alpha : float, optional
        Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is ``None``.

    Returns
    -------
    torch.Tensor
        Unreduced focal loss, same shape as ``_input``.
    """
    # Binary cross entropy with logits, written out explicitly; equivalent to
    # F.binary_cross_entropy_with_logits(_input, target, reduction='none').
    # See also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
    cross_entropy: torch.Tensor = _input - _input * target - F.logsigmoid(_input)

    # log(1 - pt): the argument is -i when t == 1 and +i when t == 0, so the sigmoid gives
    # (1 - p) for positives and p for negatives. Staying in log space reduces the chance of overflow.
    signed_logits = -_input * (target * 2 - 1)
    log_one_minus_pt = F.logsigmoid(signed_logits)

    # exp(gamma * log(1 - pt)) == (1 - pt) ** gamma
    focal = (log_one_minus_pt * gamma).exp() * cross_entropy

    if alpha is None:
        return focal

    # alpha if t == 1; (1-alpha) if t == 0
    balance = target * alpha + (1 - target) * (1 - alpha)
    return balance * focal