# Source code for nequip.data.stats
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from torchmetrics import Metric
from nequip.utils.global_dtype import _GLOBAL_DTYPE
from typing import Callable
class _MeanX(Metric):
    """Hidden base class for mean statistics.

    Keeps a running mean of ``modifier(data)`` over all ``update`` calls, using
    the mean part of Welford's one-pass algorithm for running variance.
    Subclasses pick ``modifier`` to obtain other mean statistics, e.g. the mean
    absolute value or the mean square.

    Args:
        modifier (Callable): applied to every batch of data before averaging
    """

    def __init__(self, modifier: Callable = torch.nn.Identity(), **kwargs):
        super().__init__(**kwargs)
        self.modifier = modifier
        # The update itself follows the running-mean part of Welford's
        # algorithm, but the persisted states are a sum and a count so that
        # they can be reduced correctly across devices in distributed runs.
        # Rationale: we never do (large accumulated sum) += (small batch sum)
        # inside an update, which should be more numerically stable; syncing
        # sums across devices assumes the per-device sums are of the same
        # order of magnitude.
        for state_name in ("sum", "count"):
            self.add_state(
                state_name, default=torch.tensor(0), dist_reduce_fx="sum"
            )

    def update(self, data: torch.Tensor) -> None:
        """Fold one batch of data into the running mean.

        Empty tensors are ignored and leave the states untouched.
        """
        batch_size = data.numel()
        if batch_size == 0:
            return

        data = data.to(_GLOBAL_DTYPE)

        # Work with differences of means rather than sums of sums to limit
        # precision loss.
        if torch.is_nonzero(self.count):
            running_mean = self.sum / self.count
        else:
            # nothing accumulated yet
            running_mean = 0
        batch_mean = self.modifier(data).mean()
        offset = batch_mean - running_mean
        total = self.count + batch_size
        updated_mean = running_mean + offset * batch_size / total

        # Persist as (count, sum); the sum is rebuilt from the updated mean.
        self.count = total
        self.sum = updated_mean * total

    def compute(self) -> torch.Tensor:
        """Return the accumulated sum divided by the accumulated count."""
        return self.sum / self.count
[docs]
class Mean(_MeanX):
    """Mean computed in a running fashion.

    The data is averaged as-is: the modifier is an identity module.
    """

    def __init__(self, **kwargs):
        identity = torch.nn.Identity()
        super().__init__(modifier=identity, **kwargs)

    def __str__(self) -> str:
        """Return the short name of this statistic, ``"mean"``."""
        return "mean"
[docs]
class MeanAbsolute(_MeanX):
    """Mean of absolute value computed in a running fashion.

    Each batch is passed through ``torch.abs`` before it is averaged.
    """

    def __init__(self, **kwargs):
        absolute_value = torch.abs
        super().__init__(modifier=absolute_value, **kwargs)

    def __str__(self) -> str:
        """Return the short name of this statistic, ``"mean_abs"``."""
        return "mean_abs"
[docs]
class RootMeanSquare(_MeanX):
    """Root mean square computed in a running fashion.

    Each batch is passed through ``torch.square`` before it is averaged; the
    square root is only taken when the result is computed.
    """

    def __init__(self, **kwargs):
        square = torch.square
        super().__init__(modifier=square, **kwargs)

    def compute(self) -> torch.Tensor:
        """Return the square root of the running mean of squares."""
        return super().compute().sqrt()

    def __str__(self) -> str:
        """Return the short name of this statistic, ``"rms"``."""
        return "rms"
[docs]
class StandardDeviation(Metric):
    """Standard deviation computed in a running fashion with Welford's online algorithm.

    Every ``update`` merges the statistics of one batch into three running
    states: ``mean``, ``M2`` (sum of squared deviations from the mean) and
    ``count``.

    Args:
        squared (bool): if ``True``, returns variance, else returns standard deviation
        unbiased (bool): whether to use the unbiased estimate for standard deviation
    """

    def __init__(self, squared: bool = False, unbiased: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.squared = squared
        self.unbiased = unbiased
        # TODO: implement the correct dist_reduce_fx for distributed use
        self.add_state("M2", default=torch.tensor(0))
        self.add_state("mean", default=torch.tensor(0))
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, data: torch.Tensor) -> None:
        """Merge one batch of data into the running statistics.

        Empty tensors are ignored and leave the states untouched.
        """
        # short circuit if data tensor is empty
        sample_count = data.numel()
        if sample_count > 0:
            data = data.to(_GLOBAL_DTYPE)
            # compute sample stats
            sample_mean = data.mean()
            sample_M2 = (data - sample_mean).square().sum()
            # auxiliary variables
            delta = sample_mean - self.mean
            new_count = self.count + sample_count
            mean_change = delta * sample_count / new_count
            # update
            self.mean = self.mean + mean_change
            # ``self.count`` is still the old count here, so the last term is
            # delta**2 * old_count * sample_count / new_count
            self.M2 = self.M2 + sample_M2 + delta * mean_change * self.count
            self.count = new_count

    def compute(self) -> torch.Tensor:
        """Return the variance if ``squared`` is set, else the standard deviation.

        The result is NaN whenever the estimate is undefined: no data has been
        seen, or ``unbiased`` is set and only a single entry has been seen.
        """
        if self.unbiased:
            # Clamp at zero so that an empty metric divides 0 by 0 (NaN), like
            # every other undefined case, instead of 0 by -1 (which gives -0.0).
            denom = torch.clamp(self.count - 1, min=0)
        else:
            denom = self.count
        variance = self.M2.div(denom)
        return variance if self.squared else torch.sqrt(variance)

    def __str__(self) -> str:
        """Return ``"var"`` if ``squared`` is set, else ``"std"``."""
        return "var" if self.squared else "std"
[docs]
class Max(Metric):
    """Largest entry seen.

    Args:
        abs (bool): whether to use absolute values
    """

    def __init__(self, abs: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.abs = abs
        # start from -inf so that any entry replaces the initial value
        lowest = torch.tensor(float("-inf"))
        self.add_state("max", default=lowest, dist_reduce_fx="max")

    def update(self, data: torch.Tensor) -> None:
        """Raise the running maximum to the largest entry of ``data``.

        Empty tensors are ignored.
        """
        if data.numel() == 0:
            return
        candidates = data.abs() if self.abs else data
        self.max = torch.maximum(self.max, candidates.max())

    def compute(self) -> torch.Tensor:
        """Return the running maximum."""
        return self.max

    def __str__(self) -> str:
        """Return ``"absmax"`` if absolute values are used, else ``"max"``."""
        if self.abs:
            return "absmax"
        return "max"
[docs]
class Min(Metric):
    """Smallest entry seen.

    Args:
        abs (bool): whether to use absolute values
    """

    def __init__(self, abs: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.abs = abs
        # start from +inf so that any entry replaces the initial value
        highest = torch.tensor(float("inf"))
        self.add_state("min", default=highest, dist_reduce_fx="min")

    def update(self, data: torch.Tensor) -> None:
        """Lower the running minimum to the smallest entry of ``data``.

        Empty tensors are ignored.
        """
        if data.numel() == 0:
            return
        candidates = data.abs() if self.abs else data
        self.min = torch.minimum(self.min, candidates.min())

    def compute(self) -> torch.Tensor:
        """Return the running minimum."""
        return self.min

    def __str__(self) -> str:
        """Return ``"absmin"`` if absolute values are used, else ``"min"``."""
        if self.abs:
            return "absmin"
        return "min"
[docs]
class Count(Metric):
    """Total number of entries."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        zero = torch.tensor(0)
        self.add_state("count", default=zero, dist_reduce_fx="sum")

    def update(self, data: torch.Tensor) -> None:
        """Add the number of elements of ``data`` to the running count."""
        n_entries = data.numel()
        self.count += n_entries

    def compute(self) -> torch.Tensor:
        """Return the running count."""
        return self.count

    def __str__(self) -> str:
        """Return the short name of this statistic, ``"count"``."""
        return "count"