Source code for atommic.collections.segmentation.losses.cross_entropy

# coding=utf-8
__author__ = "Dimitris Karkalousos"

import torch
from torch import nn


class CrossEntropyLoss(nn.Module):
    """Wrapper around PyTorch's CrossEntropyLoss to support 2D and 3D inputs.

    The loss is always averaged to a scalar before being returned. When a
    predicted log-variance is passed to :meth:`forward` (and more than one
    sample is requested), the loss is evaluated on ``num_samples`` noisy copies
    of the prediction and averaged over them.
    """

    def __init__(
        self,
        num_samples: int = 50,
        ignore_index: int = -100,
        reduction: str = "none",
        label_smoothing: float = 0.0,
        weight: torch.Tensor = None,
    ):
        """Inits :class:`CrossEntropyLoss`.

        Parameters
        ----------
        num_samples : int, optional
            Number of Monte Carlo samples, by default 50
        ignore_index : int, optional
            Index to ignore, by default -100
        reduction : str, optional
            Reduction method forwarded to :class:`torch.nn.CrossEntropyLoss`, by default "none"
        label_smoothing : float, optional
            Label smoothing, by default 0.0
        weight : torch.Tensor, optional
            Weight for each class, by default None
        """
        super().__init__()
        self.mc_samples = num_samples
        self.cross_entropy = torch.nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_index,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, target: torch.Tensor, _input: torch.Tensor, pred_log_var: torch.Tensor = None) -> torch.Tensor:
        """Forward pass of :class:`CrossEntropyLoss`.

        Parameters
        ----------
        target : torch.Tensor
            Target tensor. Shape: (batch_size, num_classes, \\*spatial_dims)
        _input : torch.Tensor
            Prediction tensor. Shape: (batch_size, num_classes, \\*spatial_dims)
        pred_log_var : torch.Tensor, optional
            Prediction log variance tensor. Shape: (batch_size, num_classes, \\*spatial_dims). Default is ``None``,
            in which case no Monte Carlo sampling is performed.

        Returns
        -------
        torch.Tensor
            Scalar loss tensor (the mean over all remaining elements of the loss).
        """
        # In case we do not have a batch dimension, add it
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)
        if target.dim() == 3:
            target = target.unsqueeze(0)

        # The class weights are optional (None by default), so only move them
        # when they exist and live on a different device than the prediction.
        class_weight = self.cross_entropy.weight
        if class_weight is not None and class_weight.device != _input.device:
            self.cross_entropy.weight = class_weight.to(_input.device)

        if self.mc_samples == 1 or pred_log_var is None:
            return self.cross_entropy(_input.float(), target).mean()

        # Draw mc_samples noisy copies of the prediction: input + std * N(0, 1).
        noise = torch.randn([self.mc_samples, *_input.shape], device=_input.device)
        # exp(0.5 * log_var) == sqrt(exp(log_var)) without the intermediate overflow.
        std = torch.exp(0.5 * pred_log_var)
        noisy_pred = _input.unsqueeze(0) + std.unsqueeze(0) * noise
        # Fold the sample axis into the batch axis (sample-major ordering).
        noisy_pred = noisy_pred.reshape(-1, *_input.shape[1:])

        # Repeat the target along a new leading sample axis so that, after
        # folding, it lines up element-for-element with ``noisy_pred``.
        # (``tile`` with a single repeat count would tile the LAST dimension.)
        tiled_target = (
            target.unsqueeze(0).expand(self.mc_samples, *target.shape).reshape(-1, *target.shape[1:])
        )

        loss = self.cross_entropy(noisy_pred, tiled_target)
        # With reduction="none" the loss keeps the folded sample/batch axis:
        # average over the samples first. With "mean"/"sum" it is already a
        # scalar and cannot be reshaped.
        if loss.dim() > 0:
            loss = loss.view(self.mc_samples, -1, *_input.shape[-2:]).mean(0)
        return loss.mean()