Common MRI Metrics
- class atommic.collections.common.metrics.global_average_loss_metric.GlobalAverageLossMetric(*args: Any, **kwargs: Any)
Bases: Metric

This class averages the loss across multiple processes when a distributed backend is used. It computes the true average, not a running average. It does not accumulate gradients, so the averaged loss cannot be used for optimization.
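The difference between a true average and a cruder combination of per-process results shows up when processes see different numbers of measurements. The following pure-Python sketch (no torch or torch.distributed involved) assumes each process keeps two counters, a loss sum and a measurement count, and that synchronization sums both across processes before dividing once. The `RankState` class, the helper functions and the rank data are hypothetical and only illustrate the arithmetic.

```python
from dataclasses import dataclass


@dataclass
class RankState:
    """Hypothetical per-process state: sum of losses and number of measurements."""
    loss_sum: float
    num_measurements: int


def true_global_average(states):
    """Sum-reduce both counters over all processes, then divide once."""
    total_loss = sum(s.loss_sum for s in states)
    total_count = sum(s.num_measurements for s in states)
    return total_loss / total_count


def mean_of_process_means(states):
    """Naive alternative: average each process's mean, ignoring how many measurements it saw."""
    means = [s.loss_sum / s.num_measurements for s in states]
    return sum(means) / len(means)


# Two processes that saw different numbers of measurements.
ranks = [RankState(loss_sum=3.0, num_measurements=3), RankState(loss_sum=10.0, num_measurements=2)]
print(true_global_average(ranks))    # 13 / 5 = 2.6
print(mean_of_process_means(ranks))  # (1.0 + 5.0) / 2 = 3.0
```

Only the first result equals the mean over all five individual measurements, which is why both the sum and the count have to be reduced rather than the per-process means.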
Note
If take_avg_loss is True, the loss argument of the update() method has to be a mean loss. If take_avg_loss is False, the loss argument of update() has to be a sum of losses. See the PyTorch Lightning Metrics documentation for instructions on using the metric.

Examples
>>> import torch
>>> from atommic.collections.common.metrics.global_average_loss_metric import GlobalAverageLossMetric
>>> metric = GlobalAverageLossMetric()
>>> metric.update(torch.tensor(1.0), torch.tensor(1))
>>> metric.update(torch.tensor(2.0), torch.tensor(1))
>>> metric.compute()
tensor(1.5000)
>>> metric.update(torch.tensor(3.0), torch.tensor(1))
>>> metric.compute()
tensor(2.0000)
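The two take_avg_loss modes differ only in how update() interprets its loss argument. The sketch below is a plain-Python model of that arithmetic, not the library code. It assumes that in mean mode the loss is multiplied by num_measurements before being added to the running sum, and in sum mode it is added as is. It does not use torch or synchronize anything. The class name `AverageLossSketch` is hypothetical, and returning NaN before any update is a choice made for this sketch.

```python
import math


class AverageLossSketch:
    """Pure-Python stand-in for the update()/compute() arithmetic (no torch, no process sync)."""

    def __init__(self, take_avg_loss: bool = True):
        self.take_avg_loss = take_avg_loss
        self.loss_sum = 0.0
        self.num_measurements = 0

    def update(self, loss: float, num_measurements: int) -> None:
        if self.take_avg_loss:
            # `loss` is a mean, so scale it back to a sum before accumulating.
            self.loss_sum += loss * num_measurements
        else:
            # `loss` is already a sum of losses.
            self.loss_sum += loss
        self.num_measurements += num_measurements

    def compute(self) -> float:
        # Assumption of this sketch: report NaN before any measurement has been seen.
        if self.num_measurements == 0:
            return math.nan
        return self.loss_sum / self.num_measurements


# Three single measurements fed as means (count 1 each)...
mean_mode = AverageLossSketch(take_avg_loss=True)
for value in (1.0, 2.0, 3.0):
    mean_mode.update(value, 1)

# ...and the same measurements fed as a sum of two losses plus one single loss.
sum_mode = AverageLossSketch(take_avg_loss=False)
sum_mode.update(1.0 + 2.0, 2)
sum_mode.update(3.0, 1)

print(mean_mode.compute(), sum_mode.compute())  # 2.0 2.0
```

Rescaling a mean by its count is what lets updates carrying different num_measurements values be weighted correctly in mean mode.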
- __init__(dist_sync_on_step=False, process_group=None, take_avg_loss=True)
Initializes GlobalAverageLossMetric.

Parameters
  - dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() call before returning the value at the step. Default is False.
  - process_group (Any, optional) – Specify the process group on which synchronization is called. Default is None, which selects the entire world.
  - take_avg_loss (bool) – If True, the loss argument of the update() method has to be a mean loss. If False, the loss argument of update() has to be a sum of losses. Default is True.
- full_state_update: bool = True
- update(loss, num_measurements)
Updates loss_sum and num_measurements.

Parameters
  - loss (torch.Tensor) – A zero-dimensional float torch.Tensor holding either the sum or the average of the losses for the processed examples. See the take_avg_loss parameter of __init__().
  - num_measurements (torch.Tensor) – A zero-dimensional integer torch.Tensor holding the number of loss measurements whose sum or mean is passed in loss.
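In practice both update() arguments are derived from one batch of per-example losses. The sketch below shows that derivation in plain Python, with floats and ints standing in for the zero-dimensional tensors. The helper `update_args` and the batch values are hypothetical, and the accumulation loop assumes mean-mode losses are rescaled by their count, so the final ratio should match the mean over every individual example even when batch sizes differ.

```python
from statistics import fmean


def update_args(per_example_losses, take_avg_loss=True):
    """Hypothetical helper: build the (loss, num_measurements) pair for one batch."""
    count = len(per_example_losses)
    total = sum(per_example_losses)
    loss = total / count if take_avg_loss else total
    return loss, count


batches = [[0.5, 1.5, 1.0, 1.0], [4.0, 2.0]]  # unequal batch sizes

loss_sum, num_measurements = 0.0, 0
for batch in batches:
    loss, count = update_args(batch, take_avg_loss=True)
    loss_sum += loss * count  # mean mode: rescale the batch mean to a sum
    num_measurements += count

all_losses = [x for batch in batches for x in batch]
print(loss_sum / num_measurements)  # 10.0 / 6
print(fmean(all_losses))            # same value
```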