Losses
ATOMMIC provides several loss functions for training models. All of them are subclasses of torch.nn.Module and can be used like any other PyTorch loss function.
The following losses are available for reconstruction, qMRI, segmentation and multitask tasks. Each entry notes which tasks it applies to:
MSELoss: A loss function based on the Mean Squared Error (MSE). It can be used for any task and calls torch.nn.MSELoss.
L1Loss: A loss function based on the Mean Absolute Error (MAE). It can be used for any task and calls torch.nn.L1Loss.
SSIMLoss: A loss function based on the Structural Similarity Index (SSIM). It can be used for any task and is based on [Wang2004].
- Wang2004
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality assessment: From error visibility to structural similarity. IEEE Transactions on Image Processing, 13(4), 600-612.
NoiseAwareLoss: A custom loss function that accounts for the noise level in the data. It can be used for any task and is based on [Oh2021].
- Oh2021
Oh, Y., Kim, B., & Ham, B. (2021). Background-aware pooling and noise-aware loss for weakly-supervised semantic segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 6913-6922).
SinkhornDistance: An entropy-regularized approximation of the Wasserstein distance. It is smooth and fast to compute, so it can be used as a differentiable loss function for any task. It is based on [Cuturi2013].
- Cuturi2013
Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems 26 (NIPS 2013).
CrossEntropyLoss: A loss function based on the cross-entropy between the predicted and the ground-truth segmentation. It is intended for segmentation tasks and wraps torch.nn.CrossEntropyLoss.
Dice: A loss function based on the Dice coefficient. It is intended for segmentation tasks and wraps monai.losses.DiceLoss to support multi-class and multi-label tasks. It is based on [Milletari2016].
- Milletari2016
Milletari, F., Navab, N., & Ahmadi, S. A. (2016, October). V-Net: Fully convolutional neural networks for volumetric medical image segmentation. In 2016 Fourth International Conference on 3D Vision (3DV) (pp. 565-571). IEEE.
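To make the definitions above concrete, the sketch below implements the underlying formulas in plain Python on flat lists of numbers. It is an illustration only, not ATOMMIC's implementation: the real losses operate on torch tensors, and the function names (`mse_loss`, `l1_loss`, `global_ssim_loss`, `soft_dice_loss`, `sinkhorn_distance`) are hypothetical. Two simplifications are deliberate and assumed: SSIM is computed over the whole signal as a single window instead of being averaged over local windows as in [Wang2004], and the Sinkhorn function works on two discrete histograms with an explicit cost matrix, following the iterative matrix scaling of [Cuturi2013].

```python
import math
from typing import List, Sequence

# Plain-Python sketches for illustration only (hypothetical names);
# ATOMMIC's losses operate on torch tensors and are not implemented this way.


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean Squared Error: mean of (pred - target)^2."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean Absolute Error: mean of |pred - target|."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def global_ssim_loss(x: Sequence[float], y: Sequence[float], data_range: float = 1.0) -> float:
    """1 - SSIM, treating the whole signal as one window (a simplification;
    Wang et al. average SSIM over local windows)."""
    n = len(x)
    c1 = (0.01 * data_range) ** 2  # stabilizer with K1 = 0.01
    c2 = (0.03 * data_range) ** 2  # stabilizer with K2 = 0.03
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return 1.0 - ssim


def soft_dice_loss(prob: Sequence[float], mask: Sequence[float], smooth: float = 1e-5) -> float:
    """1 - soft Dice coefficient for a single class (probabilities vs. binary mask)."""
    intersection = sum(p * g for p, g in zip(prob, mask))
    return 1.0 - (2.0 * intersection + smooth) / (sum(prob) + sum(mask) + smooth)


def sinkhorn_distance(a: List[float], b: List[float], cost: List[List[float]],
                      eps: float = 0.1, n_iters: int = 200) -> float:
    """Entropy-regularized optimal transport cost between histograms a and b.

    cost[i][j] is the cost of moving mass from bin i of a to bin j of b;
    eps controls the strength of the entropic regularization.
    """
    kernel = [[math.exp(-c / eps) for c in row] for row in cost]
    u = [1.0] * len(a)
    v = [1.0] * len(b)
    for _ in range(n_iters):
        # Alternately rescale rows and columns so the plan matches marginals a and b.
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(len(b))) for i in range(len(a))]
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(len(a))) for j in range(len(b))]
    # Transport plan P = diag(u) K diag(v); the distance is sum_ij P_ij * cost_ij.
    return sum(
        u[i] * kernel[i][j] * v[j] * cost[i][j]
        for i in range(len(a))
        for j in range(len(b))
    )
```

For example, `mse_loss([1, 2, 3], [1, 2, 5])` is 4/3 and `l1_loss` on the same inputs is 2/3. `sinkhorn_distance([1.0, 0.0], [0.0, 1.0], [[0.0, 1.0], [1.0, 0.0]])` is approximately 1.0, because all of the mass has to move one bin at unit cost.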
AggregatorLoss
The AggregatorLoss class combines multiple losses into a single loss function.
Note
AggregatorLoss is not a loss function itself but a wrapper around multiple loss functions. ATOMMIC models use it to combine losses by assigning a weight to each loss function. The weights must sum to 1.0.
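The weighting rule in the note can be sketched as a weighted sum with a check on the weights. The class below, `WeightedLossSum`, is hypothetical: it illustrates the idea behind AggregatorLoss under the assumption that each loss maps a prediction and a target to a scalar, and it does not reproduce ATOMMIC's actual class or constructor. The `mse` and `mae` helpers are likewise hypothetical stand-ins for the real losses.

```python
import math
from typing import Callable, Dict, Sequence, Tuple

# Assumed interface: a loss maps (prediction, target) to a scalar.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


class WeightedLossSum:
    """Hypothetical illustration of weighted loss aggregation (not ATOMMIC's API).

    Computes sum_i w_i * loss_i(pred, target) and rejects weights that do not sum to 1.0.
    """

    def __init__(self, losses: Dict[str, Tuple[float, LossFn]]):
        total = sum(weight for weight, _ in losses.values())
        # Compare with a tolerance: weights such as 0.2 are not exact in binary floating point.
        if not math.isclose(total, 1.0, rel_tol=1e-9, abs_tol=1e-9):
            raise ValueError(f"loss weights must sum to 1.0, got {total}")
        self.losses = losses

    def __call__(self, pred: Sequence[float], target: Sequence[float]) -> float:
        return sum(weight * fn(pred, target) for weight, fn in self.losses.values())


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error over paired values."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)
```

With `WeightedLossSum({"mse": (0.5, mse), "l1": (0.5, mae)})`, calling the result on `[1, 2, 3]` and `[1, 2, 5]` gives 0.5 * 4/3 + 0.5 * 2/3 = 1.0 (up to floating-point rounding). The tolerance in the weight check matters: five weights of 0.2 sum to 1.0 only approximately in floating point, so an exact `== 1.0` test could reject a valid configuration.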
AggregatorLoss is configurable via YAML with Hydra. For example:
model:
  reconstruction_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
This creates a loss function for the reconstruction task that is a weighted sum of the MSE, MAE (l1), SSIM, NoiseAware and Wasserstein losses.
model:
  segmentation_loss:
    cross_entropy: 0.5
    dice: 0.5
This creates a loss function for the segmentation task that is a weighted sum of the CrossEntropy and Dice losses.
model:
  reconstruction_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
  segmentation_loss:
    cross_entropy: 0.5
    dice: 0.5
  total_reconstruction_loss_weight: 0.5
  total_segmentation_loss_weight: 0.5
This creates a loss function for multitask training that is a weighted sum of the reconstruction and segmentation losses. The weights of the two losses are set by the total_reconstruction_loss_weight and total_segmentation_loss_weight parameters, respectively.
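Written out, the multitask configuration amounts to a nested weighted sum. In the notation below, which is introduced here for illustration, $w_i$ are the per-loss weights listed under reconstruction_loss and segmentation_loss, and $\lambda_{\text{rec}}$ and $\lambda_{\text{seg}}$ stand for total_reconstruction_loss_weight and total_segmentation_loss_weight. This is a reading of the description above, not a transcription of ATOMMIC's code.

```latex
\begin{aligned}
\mathcal{L}_{\text{rec}} &= \sum_{i} w^{\text{rec}}_{i}\,\mathcal{L}_{i},
\qquad
\mathcal{L}_{\text{seg}} = \sum_{j} w^{\text{seg}}_{j}\,\mathcal{L}_{j},\\
\mathcal{L}_{\text{total}} &= \lambda_{\text{rec}}\,\mathcal{L}_{\text{rec}} + \lambda_{\text{seg}}\,\mathcal{L}_{\text{seg}}\\
&= 0.5 \cdot 0.2\,\bigl(\mathcal{L}_{\text{MSE}} + \mathcal{L}_{\text{MAE}} + \mathcal{L}_{\text{SSIM}} + \mathcal{L}_{\text{NoiseAware}} + \mathcal{L}_{\text{Wasserstein}}\bigr)
 + 0.5 \cdot 0.5\,\bigl(\mathcal{L}_{\text{CE}} + \mathcal{L}_{\text{Dice}}\bigr)
\end{aligned}
```

With these example values, each individual reconstruction loss contributes with an effective weight of 0.1 and each segmentation loss with an effective weight of 0.25.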
model:
  quantitative_loss:
    mse: 0.2
    l1: 0.2
    ssim: 0.2
    noise_aware: 0.2
    wasserstein: 0.2
This creates a loss function for the qMRI task that is a weighted sum of the MSE, MAE (l1), SSIM, NoiseAware and Wasserstein losses.