Source code for atommic.collections.common.losses.aggregator

# coding=utf-8
__author__ = "Dimitris Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/aggregator.py

from typing import List, Optional

import torch

from atommic.core.classes.common import typecheck
from atommic.core.classes.loss import Loss
from atommic.core.neural_types.elements import LossType
from atommic.core.neural_types.neural_type import NeuralType

__all__ = ["AggregatorLoss"]


class AggregatorLoss(Loss):
    """Aggregates multiple losses into a single loss.

    Examples
    --------
    >>> from atommic.collections.common.losses.aggregator import AggregatorLoss
    >>> loss = AggregatorLoss(num_inputs=2)
    >>> loss(loss_1=torch.tensor(1.0), loss_2=torch.tensor(2.0))
    tensor(3.)
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports: one ``loss_<i>`` port per aggregated loss (1-based)."""
        return {f"loss_{i + 1}": NeuralType(elements_type=LossType()) for i in range(self._num_losses)}

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, num_inputs: int = 2, weights: Optional[List[float]] = None):
        """Inits :class:`AggregatorLoss`.

        Parameters
        ----------
        num_inputs : int
            Number of losses to be summed. They are passed to ``forward`` as ``loss_1`` ... ``loss_<num_inputs>``.
        weights : List[float], optional
            Weights to be applied to each loss; ``weights[i]`` scales ``loss_<i + 1>``. If None, all losses are
            weighted equally (plain sum).

        Raises
        ------
        ValueError
            If ``weights`` is given and its length differs from ``num_inputs``.
        """
        super().__init__()
        self._num_losses = num_inputs
        if weights is not None and len(weights) != num_inputs:
            raise ValueError("Length of weights should be equal to the number of inputs (num_inputs)")
        # Keep a private copy so later mutation of the caller's list cannot silently change the weighting.
        self._weights = list(weights) if weights is not None else None

    @typecheck()
    def forward(self, **kwargs):
        """Computes the (optionally weighted) sum of the losses.

        Parameters
        ----------
        **kwargs : torch.Tensor
            The losses to aggregate, passed as ``loss_1=..., loss_2=..., ...``.

        Returns
        -------
        torch.Tensor
            ``sum_i weights[i] * loss_<i + 1>``, or the unweighted sum when no weights were given.

        Raises
        ------
        ValueError
            If no losses are passed.
        """
        # Order inputs by the integer suffix of ``loss_<i>``. A plain lexicographic sort would place ``loss_10``
        # between ``loss_1`` and ``loss_2`` and misalign the weights as soon as there are ten or more losses.
        values = [kwargs[name] for name in sorted(kwargs, key=self._input_order_key)]
        if not values:
            raise ValueError("AggregatorLoss.forward requires at least one loss input.")
        loss = torch.zeros_like(values[0])
        for loss_idx, loss_value in enumerate(values):
            if self._weights is not None:
                loss = loss.add(loss_value, alpha=self._weights[loss_idx])
            else:
                loss = loss.add(loss_value)
        return loss

    @staticmethod
    def _input_order_key(name: str):
        """Sort key placing ``loss_<i>`` names in increasing numeric order; other names sort last, alphabetically."""
        _, _, suffix = name.rpartition("_")
        # ``isdecimal`` (not ``isdigit``) guarantees ``int(suffix)`` succeeds.
        if suffix.isdecimal():
            return 0, int(suffix), name
        return 1, 0, name