# Source code for atommic.collections.common.metrics.global_average_loss_metric
# coding=utf-8
__author__ = "Dimitris Karkalousos"
# Taken and adapted from:
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/global_average_loss_metric.py
import torch
from torchmetrics import Metric
# Public names exported by ``from ... import *``.
__all__ = ["GlobalAverageLossMetric"]
class GlobalAverageLossMetric(Metric):
    """Average a loss across processes when a distributed backend is used.

    The true average over all recorded measurements is computed, not a running average. Incoming losses are detached
    from the autograd graph, so the averaged loss cannot be used for optimization.

    .. note::
        If ``take_avg_loss`` is ``True``, the ``loss`` argument of :meth:`update` has to be a mean loss. If
        ``take_avg_loss`` is ``False``, the ``loss`` argument of :meth:`update` has to be a sum of losses. See the
        TorchMetrics documentation for general metric usage instructions.

    Examples
    --------
    >>> import torch
    >>> from atommic.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric
    >>> metric = GlobalAverageLossMetric()
    >>> metric.update(torch.tensor(1.0), torch.tensor(1))
    >>> metric.update(torch.tensor(2.0), torch.tensor(1))
    >>> metric.compute()
    tensor(1.5000, dtype=torch.float64)
    >>> metric.update(torch.tensor(3.0), torch.tensor(1))
    >>> metric.compute()
    tensor(2., dtype=torch.float64)
    """

    full_state_update: bool = True

    def __init__(self, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
        """Inits :class:`GlobalAverageLossMetric`.

        Parameters
        ----------
        dist_sync_on_step : bool
            Synchronize metric state across processes at each :meth:`forward` call before returning the value at the
            step. Default is ``False``.
        process_group : Any, optional
            Process group on which synchronization is called. ``None`` selects the entire world. Default is ``None``.
        take_avg_loss : bool
            If ``True``, the ``loss`` argument of :meth:`update` has to be a mean loss. If ``False``, it has to be a
            sum of losses. Default is ``True``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group)
        # Both states are summed across processes on synchronization. The loss sum is kept in float64, and update()
        # casts incoming losses to that dtype before accumulating them.
        self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
        self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum')
        self.take_avg_loss = take_avg_loss

    def update(self, loss, num_measurements) -> None:  # pylint: disable=arguments-differ
        """Updates :attr:`loss_sum` and :attr:`num_measurements`.

        Parameters
        ----------
        loss : torch.Tensor or float
            A zero-dimensional float value which is either the sum or the average of the losses of the processed
            examples (see the ``take_avg_loss`` parameter of :meth:`__init__`). It is detached before accumulation.
        num_measurements : torch.Tensor or int
            A zero-dimensional integer value with the number of loss measurements whose sum or mean is given in
            ``loss``.
        """
        # Detach and cast to the accumulator's dtype/device first, so that a float32 mean loss is scaled by the
        # count in float64 instead of float32. ``torch.as_tensor`` also accepts plain Python numbers.
        loss = torch.as_tensor(loss).detach().to(dtype=self.loss_sum.dtype, device=self.loss_sum.device)
        num_measurements = torch.as_tensor(num_measurements, device=self.num_measurements.device)
        if self.take_avg_loss:
            # ``loss`` is a mean: turn it back into the sum of the individual losses.
            loss = loss * num_measurements
        self.loss_sum = self.loss_sum + loss
        self.num_measurements = self.num_measurements + num_measurements

    def compute(self) -> torch.Tensor:
        """Returns the mean loss, or NaN if no measurements have been recorded.

        Returns
        -------
        torch.Tensor
            ``loss_sum / num_measurements``. For zero measurements, a NaN tensor with the same dtype, device and
            shape as ``loss_sum`` is returned.
        """
        if self.num_measurements.eq(0):
            # Avoid 0/0 and keep the dtype/device consistent with the regular result (previously this returned a
            # float32 CPU tensor regardless of where the state lived).
            return torch.full_like(self.loss_sum, float("nan"))
        return self.loss_sum / self.num_measurements