Common MRI Metrics

class atommic.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(*args: Any, **kwargs: Any)

Bases: Metric

This class averages the loss across multiple processes when a distributed backend is used. It computes a true average, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.
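To make the "true average" idea concrete, the sketch below compares it with naively averaging per-step mean losses when steps contain different numbers of measurements. It uses plain Python floats instead of tensors and only illustrates the arithmetic; it is not the class's implementation.

```python
# Per-step mean losses and the number of measurements (e.g. samples) behind each one.
batch_means = [1.0, 4.0]
batch_sizes = [3, 1]

# Naive approach: average the per-step means, ignoring how many measurements each covers.
mean_of_means = sum(batch_means) / len(batch_means)

# True average: weight each mean by its measurement count, then divide by the total count.
loss_sum = sum(m * n for m, n in zip(batch_means, batch_sizes))
true_average = loss_sum / sum(batch_sizes)

print(mean_of_means)  # 2.5
print(true_average)   # 1.75
```

The two results differ because the first step covers three measurements and should carry three times the weight of the second.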

Note

If take_avg_loss is True, the loss argument of update() must be a mean loss. If take_avg_loss is False, the loss argument of update() must be a sum of losses. See the PyTorch Lightning Metrics documentation for instructions on using metrics.
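The two conventions can be compared with a small plain-Python stand-in for the update()/compute() contract described above. `AverageLossAccumulator` is a hypothetical, single-process class that uses floats instead of tensors and performs no distributed synchronization; it is a sketch of the documented behavior, not the library's code.

```python
class AverageLossAccumulator:
    """Hypothetical single-process mirror of the documented update()/compute() contract."""

    def __init__(self, take_avg_loss: bool = True) -> None:
        self.take_avg_loss = take_avg_loss
        self.loss_sum = 0.0
        self.num_measurements = 0

    def update(self, loss: float, num_measurements: int) -> None:
        if self.take_avg_loss:
            # `loss` is a mean over `num_measurements` values, so recover their sum.
            self.loss_sum += loss * num_measurements
        else:
            # `loss` is already the sum over `num_measurements` values.
            self.loss_sum += loss
        self.num_measurements += num_measurements

    def compute(self) -> float:
        # True average: total loss divided by total measurements.
        # (This sketch does not guard against zero measurements.)
        return self.loss_sum / self.num_measurements


per_sample_losses = [1.0, 2.0, 3.0, 6.0]
batches = [per_sample_losses[:3], per_sample_losses[3:]]

mean_mode = AverageLossAccumulator(take_avg_loss=True)
sum_mode = AverageLossAccumulator(take_avg_loss=False)
for batch in batches:
    mean_mode.update(sum(batch) / len(batch), len(batch))  # pass a mean loss
    sum_mode.update(sum(batch), len(batch))                 # pass a sum of losses

print(mean_mode.compute(), sum_mode.compute())  # 3.0 3.0
```

Both modes reach the same result as long as the caller passes the form of loss that matches take_avg_loss; mixing them up would scale the accumulated sum incorrectly.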

Examples

>>> import torch
>>> from atommic.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric
>>> metric = GlobalAverageLossMetric()
>>> metric.update(torch.tensor(1.0), torch.tensor(1))
>>> metric.update(torch.tensor(2.0), torch.tensor(1))
>>> metric.compute()
tensor(1.5000)
>>> metric.update(torch.tensor(3.0), torch.tensor(1))
>>> metric.compute()
tensor(2.0000)
__init__(dist_sync_on_step=False, process_group=None, take_avg_loss=True)

Initializes GlobalAverageLossMetric.

Parameters
  • dist_sync_on_step (bool) – Synchronize the metric state across processes at each forward() call before returning the value at the step. Default is False.

  • process_group (Any, optional) – The process group on which synchronization is called. If None, the entire world is used. Default is None.

  • take_avg_loss (bool) – If True, the loss argument of update() must be a mean loss. If False, the loss argument of update() must be a sum of losses. Default is True.
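In a distributed run, each process in the process group holds its own partial loss sum and measurement count. A reasonable mental model, assumed here rather than taken from the library source, is that both states are summed across processes before a single division, which keeps the result a true global average. The sketch below simulates this with a plain list standing in for per-rank states; `rank_states`, `global_average`, and `average_of_rank_averages` are hypothetical names used only for illustration.

```python
from typing import List, Tuple

# Hypothetical per-rank states: (loss_sum, num_measurements) held on each process.
rank_states: List[Tuple[float, int]] = [
    (6.0, 3),  # rank 0: three measurements summing to 6.0
    (4.0, 1),  # rank 1: one measurement of 4.0
]


def global_average(states: List[Tuple[float, int]]) -> float:
    """Sum each state across ranks (like a sum all-reduce), then divide once."""
    total_loss = sum(loss_sum for loss_sum, _ in states)
    total_count = sum(count for _, count in states)
    return total_loss / total_count


def average_of_rank_averages(states: List[Tuple[float, int]]) -> float:
    """For contrast: average each rank's local mean, ignoring per-rank measurement counts."""
    return sum(loss_sum / count for loss_sum, count in states) / len(states)


print(global_average(rank_states))            # 2.5
print(average_of_rank_averages(rank_states))  # 3.0
```

Reducing the sum and the count separately is what lets ranks that processed different numbers of measurements contribute in proportion to their share.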

compute()

Returns the mean loss.

full_state_update: bool = True
update(loss, num_measurements)

Updates loss_sum and num_measurements.

Parameters
  • loss (torch.Tensor) – A zero-dimensional float torch.Tensor holding either the sum or the average of the losses for the processed examples. See the take_avg_loss parameter of __init__().

  • num_measurements (torch.Tensor) – A zero-dimensional integer torch.Tensor containing the number of loss measurements. The loss parameter holds the sum or mean of the results of these measurements.