Source code for nequip.train.metrics
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from torchmetrics import Metric
from nequip.data.stats import _MeanX
[docs]
class MeanAbsoluteError(_MeanX):
    """Mean absolute error.

    Residuals ``preds - target`` are forwarded to ``_MeanX``, which is
    configured with ``torch.abs`` as its modifier.
    """

    def __init__(self, **kwargs):
        # ``torch.abs`` turns the accumulated residuals into absolute errors
        super().__init__(modifier=torch.abs, **kwargs)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the residuals ``preds - target`` for this batch."""
        residual = preds - target
        super().update(residual)

    def __str__(self) -> str:
        return "mae"
[docs]
class MeanSquaredError(_MeanX):
    """Mean squared error.

    Residuals ``preds - target`` are forwarded to ``_MeanX``, which is
    configured with ``torch.square`` as its modifier.
    """

    def __init__(self, **kwargs):
        # ``torch.square`` turns the accumulated residuals into squared errors
        super().__init__(modifier=torch.square, **kwargs)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the residuals ``preds - target`` for this batch."""
        residual = preds - target
        super().update(residual)

    def __str__(self) -> str:
        return "mse"
[docs]
class RootMeanSquaredError(MeanSquaredError):
    """Root mean squared error.

    Reuses the squared-error accumulation of ``MeanSquaredError`` and only
    takes the square root at compute time.
    """

    def compute(self) -> torch.Tensor:
        """Return ``sqrt(sum / count)`` of the accumulated squared residuals."""
        mean_squared = self.sum / self.count
        return mean_squared.sqrt()

    def __str__(self) -> str:
        return "rmse"
[docs]
class MaximumAbsoluteError(Metric):
    """Maximum absolute error.

    Keeps a running maximum of ``|preds - target|`` over all updates. The
    state starts at ``-inf`` and is registered with ``dist_reduce_fx="max"``,
    so ``compute`` returns ``-inf`` if no non-empty batch was ever seen.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # -inf is the identity element of max: any real error replaces it
        self.add_state(
            "max_error", default=torch.tensor(-float("inf")), dist_reduce_fx="max"
        )

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Fold the largest absolute error of this batch into the running max."""
        # reducing an empty tensor with ``.max()`` raises, so skip empty batches
        if preds.numel() == 0:
            return
        batch_max = (preds - target).abs().max()
        self.max_error = torch.maximum(self.max_error, batch_max)

    def compute(self) -> torch.Tensor:
        """Return the running maximum absolute error."""
        return self.max_error

    def __str__(self) -> str:
        return "max_ae"
[docs]
class HuberLoss(_MeanX):
    """Huber loss (see `torch.nn.HuberLoss <https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html>`_)

    Note that ``delta`` takes on the units of the target and prediction tensors.

    Args:
        reduction: ``"mean"`` divides the accumulated loss by ``count`` in
            ``compute``; ``"sum"`` returns the accumulated total.
        delta: threshold on ``|preds - target|`` separating the quadratic
            (``0.5 * x**2``) and linear (``delta * (|x| - 0.5 * delta)``) regimes.
        **kwargs: forwarded to ``_MeanX``.
    """

    def __init__(self, reduction="mean", delta=1.0, **kwargs):
        assert reduction in ["mean", "sum"]

        def _huber(residual):
            magnitude = residual.abs()
            quadratic = 0.5 * residual.square()
            linear = delta * (magnitude - 0.5 * delta)
            return torch.where(magnitude < delta, quadratic, linear)

        super().__init__(modifier=_huber, **kwargs)
        self.reduction = reduction

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the residuals ``preds - target`` for this batch."""
        residual = preds - target
        super().update(residual)

    def compute(self) -> torch.Tensor:
        """Return the accumulated loss, reduced according to ``self.reduction``."""
        if self.reduction == "sum":
            return self.sum
        if self.reduction == "mean":
            return self.sum.div(self.count)
        # unreachable while the constructor's assert is active
        return None

    def __str__(self) -> str:
        return "huber"
[docs]
class StratifiedHuberForceLoss(_MeanX):
    """Stratified Huber loss for vectors (forces)
    (see `torch.nn.HuberLoss <https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html>`_).

    This metrics class implements a stratified/conditional Huber loss, where the Huber ``delta`` parameter is scaled based on the magnitude of the reference vector (i.e. force), by providing a ``delta_dict`` of ``{lower bound: delta parameter}`` where the loss contributions for all vectors with a magnitude between lower bound and the next lower bound are computed as a Huber loss with the corresponding ``delta`` parameter.
    The lower bounds are ordered numerically, so ``delta_dict`` may be given in any order.

    Note that ``delta`` values take on the units of the target and prediction tensors.

    If the first lower bound in ``delta_dict`` is not 0 (typically recommended),
    then a MSELoss (divided by 2; matching Huber loss in the L2 regime (``|x| < delta``),
    see `torch.nn.HuberLoss <https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html>`_) is used for vectors with a magnitude smaller than the first lower bound.

    Args:
        delta_dict: mapping ``{lower bound: delta}`` on the target-vector magnitude.
        reduction: ``"mean"`` divides the accumulated loss by ``count`` in
            ``compute``; ``"sum"`` returns the accumulated total.
        **kwargs: forwarded to ``_MeanX``.
    """

    def __init__(self, delta_dict, reduction="mean", **kwargs):
        if min(delta_dict.keys()) > 0:
            # add a 0 - lower-bound but with infinite delta so always in the L2 regime
            delta_dict = {0: float("inf"), **delta_dict}
        # ``update`` pairs each lower bound with the next one to form a stratum,
        # so store the bounds in ascending order (this also copies the caller's dict)
        self.delta_dict = dict(sorted(delta_dict.items()))  # lower bound -> delta
        assert reduction in ["mean", "sum"]
        assert len(self.delta_dict) >= 2, (
            "At least two delta values are required, otherwise use standard HuberLoss instead."
        )
        super().__init__(**kwargs)
        self.reduction = reduction

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate per-component Huber losses, with ``delta`` chosen per vector.

        Each vector is assigned to the stratum ``[lower, next lower)`` (the last
        stratum is unbounded above) containing ``torch.norm(target, dim=-1)``.
        Vectors in no stratum (e.g. NaN magnitude) contribute zero loss.
        """
        # templated from the `conditional_huber_forces` function from MACE:
        # sort again here so a ``delta_dict`` reassigned after construction is
        # still interpreted in ascending-bound order
        strata = sorted(self.delta_dict.items())
        bounds = [lower for lower, _ in strata]
        deltas = [delta for _, delta in strata]
        # the last stratum has no upper bound
        upper_bounds = bounds[1:] + [None]

        stratified_losses = torch.zeros_like(preds)
        vector_magnitudes = torch.norm(target, dim=-1)
        for lower, upper, delta in zip(bounds, upper_bounds, deltas):
            stratum_mask = vector_magnitudes >= lower
            if upper is not None:
                stratum_mask = stratum_mask & (vector_magnitudes < upper)
            stratified_losses[stratum_mask] = torch.nn.functional.huber_loss(
                target[stratum_mask],
                preds[stratum_mask],
                reduction="none",
                delta=delta,
            )
        super().update(stratified_losses)

    def compute(self) -> torch.Tensor:
        """Return the accumulated loss, reduced according to ``self.reduction``."""
        if self.reduction == "mean":
            return self.sum.div(self.count)
        elif self.reduction == "sum":
            return self.sum

    def __str__(self) -> str:
        return "stratified huber"