Losses
ATOMMIC provides a number of loss functions for training models. These are all subclasses of torch.nn.Module
and can be used in the same way as any other PyTorch loss function.
The following losses are available for reconstruction, quantitative MRI (qMRI), segmentation, and multitask models:
MSELoss: A loss function based on the Mean Squared Error (MSE). It can be used for any task and it calls
torch.nn.MSELoss.
L1Loss: A loss function based on the Mean Absolute Error (MAE). It can be used for any task and it calls
torch.nn.L1Loss.
SSIMLoss: A loss function based on the Structural Similarity Index (SSIM). It can be used for any task and it is based on [Wang2004].
- Wang2004
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality assessment: from error visibility to structural similarity. IEEE transactions on image processing, 13(4), 600-612.
NoiseAwareLoss: A custom loss function that is aware of the noise level in the data. It can be used for any task and it is based on [Oh2021].
- Oh2021
Oh, Y., Kim, B., & Ham, B. (2021). Background-aware pooling and noise-aware loss for weakly-supervised semantic segmentation. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 6913-6922).
SinkhornDistance: An entropy-regularized approximation of the Wasserstein distance that is differentiable and can therefore be used as a loss function. It can be used for any task and it is based on [Cuturi2013].
- Cuturi2013
Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 26.
CrossEntropyLoss: A loss function based on the cross-entropy between the predicted and the ground truth segmentation. It can be used for segmentation tasks and it is a wrapper around
torch.nn.CrossEntropyLoss.
Dice: A loss function based on the Dice coefficient. It can be used for segmentation tasks and it is a wrapper around
monai.losses.DiceLoss to support multi-class and multi-label tasks. It is based on [Milletari2016].
- Milletari2016
Milletari, F., Navab, N., & Ahmadi, S. A. (2016, October). V-net: Fully convolutional neural networks for volumetric medical image segmentation. In 2016 fourth international conference on 3D vision (3DV) (pp. 565-571). IEEE.
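To make the definitions above concrete, here is a minimal, framework-agnostic sketch of three of the listed losses (MSE, MAE, and Dice) on flat Python sequences. It is not ATOMMIC's implementation, which wraps torch.nn and monai.losses and operates on tensors. The function names mse, mae, and dice_loss and the eps smoothing term are assumptions made for illustration, and the Dice formula shown (plain sums in the denominator) is only one common variant.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equally sized sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error between two equally sized sequences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def dice_loss(pred: Sequence[float], target: Sequence[float], eps: float = 1e-8) -> float:
    """1 - Dice coefficient for masks with values in [0, 1].

    eps (hypothetical smoothing term) avoids division by zero for empty masks.
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


# Toy reconstruction values: only the last element differs, by 2.
recon_pred, recon_target = [1.0, 2.0, 3.0], [1.0, 2.0, 5.0]
print(mse(recon_pred, recon_target))  # squared error 4 averaged over 3 elements
print(mae(recon_pred, recon_target))  # absolute error 2 averaged over 3 elements

# Toy binary masks: one overlapping pixel, three foreground pixels in total.
print(dice_loss([1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]))
```

MSE penalizes the single large error more heavily than MAE does, while the Dice loss depends only on the overlap between the two masks, which is why it is listed for segmentation tasks.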
AggregatorLoss
The AggregatorLoss class is used to combine multiple losses into a single loss function.
Note
The AggregatorLoss is not a loss function itself, but a wrapper around multiple loss functions. ATOMMIC models
use it to combine losses by assigning a weight to each loss function. The weights must sum to 1.0.
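The weighting rule described in the note can be sketched in a few lines of plain Python. This is an illustrative assumption of how such an aggregation could work, not the AggregatorLoss source code: the helper aggregate_losses, the toy mse and l1 functions, and the choice to raise ValueError on invalid weights are all hypothetical.

```python
import math
from typing import Callable, Dict, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def aggregate_losses(
    losses: Dict[str, LossFn],
    weights: Dict[str, float],
    pred: Sequence[float],
    target: Sequence[float],
) -> float:
    """Return the weighted sum of several losses (hypothetical helper)."""
    if set(losses) != set(weights):
        raise ValueError("every loss needs exactly one weight")
    # Compare with a tolerance, since floating-point sums are rarely exact.
    if not math.isclose(sum(weights.values()), 1.0, abs_tol=1e-6):
        raise ValueError("loss weights must sum to 1.0")
    return sum(weights[name] * fn(pred, target) for name, fn in losses.items())


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


total = aggregate_losses(
    losses={"mse": mse, "l1": l1},
    weights={"mse": 0.5, "l1": 0.5},
    pred=[1.0, 2.0],
    target=[1.0, 4.0],
)
print(total)  # 0.5 * 2.0 (MSE) + 0.5 * 1.0 (MAE) = 1.5
```

Checking the weights up front means a misconfigured set of weights fails immediately instead of silently rescaling the training objective.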
AggregatorLoss is configurable via YAML with Hydra. For example:
model:
  reconstruction_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
This will create a loss function for the reconstruction task that is a weighted sum of the MSE, MAE, SSIM,
NoiseAware and Wasserstein losses, each with a weight of 0.2.
model:
  segmentation_loss:
    cross_entropy: 0.5
    dice: 0.5
This will create a loss function for the segmentation task that is a weighted sum of the CrossEntropy and Dice
losses.
model:
  reconstruction_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
  segmentation_loss:
    cross_entropy: 0.5
    dice: 0.5
  total_reconstruction_loss_weight: 0.5
  total_segmentation_loss_weight: 0.5
This will create a loss function for multitask models that is a weighted sum of the reconstruction and the
segmentation losses. Each of these is itself a weighted sum of its component losses, and the two totals are then
weighted by the total_reconstruction_loss_weight and total_segmentation_loss_weight parameters, respectively.
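The two-level weighting can be illustrated with a small arithmetic sketch. The dictionary below mirrors only the two top-level keys from the configuration, the per-task loss values 0.8 and 0.4 are made-up numbers, and the multitask_loss function is a hypothetical helper rather than part of ATOMMIC.

```python
# Top-level weights, mirroring the two keys from the YAML example.
config = {
    "total_reconstruction_loss_weight": 0.5,
    "total_segmentation_loss_weight": 0.5,
}


def multitask_loss(reconstruction_loss: float, segmentation_loss: float, cfg: dict) -> float:
    """Combine two already-aggregated task losses (hypothetical helper)."""
    return (
        cfg["total_reconstruction_loss_weight"] * reconstruction_loss
        + cfg["total_segmentation_loss_weight"] * segmentation_loss
    )


# Assumed per-task totals, each already a weighted sum of its component losses.
total = multitask_loss(reconstruction_loss=0.8, segmentation_loss=0.4, cfg=config)
print(round(total, 6))  # 0.5 * 0.8 + 0.5 * 0.4, approximately 0.6
```

Raising one of the two top-level weights shifts the optimization toward that task without touching the per-loss weights inside reconstruction_loss or segmentation_loss.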
model:
  quantitative_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
This will create a loss function for the qMRI task that is a weighted sum of the MSE, MAE, SSIM, NoiseAware and
Wasserstein losses.