Source code for nequip.data.stats_manager

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from torchmetrics import Metric
from . import AtomicDataDict

from .modifier import BaseModifier, PerAtomModifier, NumNeighbors
from .stats import Mean, RootMeanSquare, StandardDeviation
from typing import List, Dict, Union, Callable, Iterable, Optional, Any

from nequip.utils.logger import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=True)


class DataStatisticsManager(torch.nn.ModuleList):
    """Manages ``nequip`` metrics that can be applied to ``AtomicDataDict`` to compute dataset statistics.

    The main input argument ``metrics`` is a list of dictionaries, where each dictionary contains the following keys:

    **Mandatory keys:**

      - ``field`` refers to the quantity of interest for metric computation. It has two formats:

        - a ``str`` for a ``nequip`` defined field (e.g. ``total_energy``, ``forces``, ``stress``), or
        - a ``Callable`` that performs some additional operations before returning a ``torch.Tensor`` for metric computation (e.g. :class:`~nequip.data.PerAtomModifier`).

      - ``metric`` is a ``nequip`` data metric object (a subclass of :class:`torchmetrics.Metric`).

    **Optional keys:**

      - ``per_type`` is a ``bool`` (defaults to ``False`` if not provided). If ``True``, node fields (such as ``forces``) will have their metrics computed separately for each node type based on the ``type_names`` argument, and edge fields will have their metrics computed separately for each ordered (center, neighbor) type pair.
      - ``ignore_nan`` is a ``bool`` (defaults to ``False`` if not provided). This should be set to ``True`` if one expects the underlying ``target`` data to contain ``NaN`` entries.
      - ``name`` is the name that the metric is logged as. Default names are used if not provided, but it is recommended for users to set custom names for clarity and control.

    Args:
        metrics (list): list of dictionaries with keys ``field``, ``metric``, ``per_type``, ``ignore_nan``, and ``name``
        dataloader_kwargs (dict): arguments of :class:`torch.utils.data.DataLoader` for dataset statistics computation (ideally, the ``batch_size`` should be as large as possible without triggering OOM)
        type_names (list): required for ``per_type`` metrics (this must match the ``type_names`` argument of the model, it is advisable to use variable interpolation in the config file to make sure they are consistent)
    """

    def __init__(
        self,
        metrics: List[
            Dict[str, Union[float, str, Dict[str, Union[str, Callable]], Metric]]
        ],
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        type_names: Optional[List[str]] = None,
    ):
        super().__init__()
        assert len(metrics) != 0
        dataloader_kwargs = {} if dataloader_kwargs is None else dataloader_kwargs
        # these DataLoader arguments may not be user-provided
        # (presumably they are set by whatever builds the DataLoader -- confirm)
        assert all(
            key not in dataloader_kwargs
            for key in ["dataset", "generator", "collate_fn"]
        )
        self.dataloader_kwargs = dataloader_kwargs
        self.num_metrics = len(metrics)

        # === MANDATORY dict keys ===
        # plain string fields are wrapped so every field is callable on the data dict
        self.fields = [
            (
                BaseModifier(metric["field"])
                if isinstance(metric["field"], str)
                else metric["field"]
            )
            for metric in metrics
        ]
        for metric in metrics:
            self.append(metric["metric"])

        # === OPTIONAL dict keys and logic based on dict.get(key, None) ===
        # == ignore NaN ==
        self.ignore_nans = [metric.get("ignore_nan", False) for metric in metrics]
        assert all(isinstance(item, bool) for item in self.ignore_nans)

        # == process names ==
        self.names = []
        for idx in range(self.num_metrics):
            name = metrics[idx].get("name", None)
            if name is None:
                name = "_".join([str(self.fields[idx]), str(self[idx])])
            self.names.append(name)
        assert len(self.names) == len(set(self.names)), (
            f"Repeated names found ({self.names}) -- names must be unique. It is recommended to give custom names instead of relying on the automatic naming."
        )

        # === per_type metrics ===
        self.per_type = [metric.get("per_type", False) for metric in metrics]
        if any(self.per_type):
            assert type_names is not None, (
                "`type_names` must be provided if any `per_type=True`"
            )
            self.type_names = type_names
            for idx in range(self.num_metrics):
                if not self.per_type[idx]:
                    continue
                field_type = self.fields[idx].type
                assert field_type in [
                    "node",
                    "edge",
                ], (
                    f"`per_type` metrics only apply to node or edge fields, but {field_type} field found for {self.names[idx]}."
                )
                # replace the base Metric with a ModuleList holding one independent
                # clone per node type (or per ordered type pair for edge fields);
                # forward()/compute()/reset() address these clones by flat type index
                self[idx] = torch.nn.ModuleList(
                    [
                        self[idx].clone()
                        for _ in range(self._num_per_type_metrics(idx))
                    ]
                )
        self.stats_dict = {}

    def _num_per_type_metrics(self, idx: int) -> int:
        """Number of per-type sub-metrics for metric ``idx``: one per type for node fields, one per ordered type pair for edge fields."""
        num_types = len(self.type_names)
        if self.fields[idx].type == "node":
            return num_types
        return num_types * num_types

    def _type_pair_index(self, center_type_idx, neighbor_type_idx):
        """Flatten a (center, neighbor) type pair into a single index.

        Works on ints and on integer tensors. ``forward`` (which stores) and
        ``compute`` (which reads back) must both use this, otherwise the
        per-pair statistics get reported under the transposed pair name.
        """
        return center_type_idx * len(self.type_names) + neighbor_type_idx

    @staticmethod
    def _drop_nans(tensor: torch.Tensor) -> torch.Tensor:
        """Return the non-NaN entries of ``tensor`` (flattened by ``torch.masked_select``)."""
        return torch.masked_select(tensor, ~torch.isnan(tensor))

    def forward(
        self,
        data: AtomicDataDict.Type,
    ):
        """Accumulate every metric over one batch ``data``; read results with :meth:`compute`."""
        for idx in range(self.num_metrics):
            data_tensor = self.fields[idx](data)

            if not self.per_type[idx]:
                if self.ignore_nans[idx]:
                    data_tensor = self._drop_nans(data_tensor)
                _ = self[idx](data_tensor)
                continue

            # flat type index of every item of `data_tensor`
            if self.fields[idx].type == "node":
                item_types = data[AtomicDataDict.ATOM_TYPE_KEY]
            else:  # "edge" (guaranteed by the assertion in __init__)
                # atom types of both endpoints of each edge, shape (2, num_edges);
                # row 0 is treated as the center atom and row 1 as the neighbor
                edge_type = torch.index_select(
                    data[AtomicDataDict.ATOM_TYPE_KEY].reshape(-1),
                    0,
                    data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1),
                ).view(2, -1)
                item_types = self._type_pair_index(edge_type[0], edge_type[1])

            # index out each type (or type pair) and feed it to its own metric
            for type_idx, type_metric in enumerate(self[idx]):
                selector = torch.eq(item_types, type_idx)
                per_type_data_tensor = data_tensor[selector]
                if self.ignore_nans[idx]:
                    per_type_data_tensor = self._drop_nans(per_type_data_tensor)
                _ = type_metric(per_type_data_tensor)

    def compute(self) -> Dict[str, Any]:
        """Compute, log, and return all accumulated statistics.

        Returns:
            dict: maps each metric name to a Python float. A ``per_type`` metric's
            name maps to a dict of per-type (or ``"center_neighbor"`` per-pair) floats,
            and each per-type value is also stored under its own flattened key
            (as the tensor returned by the metric).
        """
        logger.info("Computed data statistics:")
        self.stats_dict = {}
        for idx in range(self.num_metrics):
            metric_name = self.names[idx]

            if not self.per_type[idx]:
                stat = self[idx].compute()
                self.stats_dict.update({metric_name: stat.item()})
                logger.info(f"{metric_name}: {stat}")
                continue

            pt_stats = {}
            if self.fields[idx].type == "node":
                for type_idx, type_name in enumerate(self.type_names):
                    pt_stat = self[idx][type_idx].compute()
                    pt_stats[type_name] = pt_stat.item()
                    pt_stat_name = "_".join([metric_name, type_name])
                    self.stats_dict.update({pt_stat_name: pt_stat})
                    logger.info(f"{pt_stat_name}: {pt_stat}")
            else:  # "edge"
                for center_idx, center_type in enumerate(self.type_names):
                    for neigh_idx, neigh_type in enumerate(self.type_names):
                        # same flattening as forward(), so labels match the data
                        type_pair_idx = self._type_pair_index(center_idx, neigh_idx)
                        pt_stat = self[idx][type_pair_idx].compute()
                        pt_stats["_".join([center_type, neigh_type])] = (
                            pt_stat.item()
                        )
                        pt_stat_name = "_".join(
                            [metric_name, center_type + neigh_type]
                        )
                        self.stats_dict.update({pt_stat_name: pt_stat})
                        logger.info(f"{pt_stat_name}: {pt_stat}")
            self.stats_dict.update({metric_name: pt_stats})
        return self.stats_dict

    def reset(self):
        """Resets accumulated statistics."""
        for idx in range(self.num_metrics):
            if self.per_type[idx]:
                # per-type metrics are a ModuleList of clones, one per type / type pair
                for type_metric in self[idx]:
                    type_metric.reset()
            else:
                self[idx].reset()

    def get_statistics(self, data_source: Iterable[AtomicDataDict.Type]):
        """Accumulate statistics over every batch in ``data_source`` and return :meth:`compute`'s result.

        This does not reset first: results add to anything already accumulated,
        so call :meth:`reset` beforehand if a fresh computation is wanted.

        Args:
            data_source (Iterable[AtomicDataDict]): iterable data source
        """
        for data in data_source:
            self(data)
        return self.compute()
def CommonDataStatisticsManager(
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
    type_names: Optional[List[str]] = None,
):
    """:class:`~nequip.data.DataStatisticsManager` wrapper that computes commonly used dataset statistics.

    The statistics computed are ``num_neighbors_mean``, ``per_type_num_neighbors_mean``,
    ``per_atom_energy_mean``, ``forces_rms``, and ``per_type_forces_rms``. Two of them are
    ``per_type`` statistics, so ``type_names`` must be provided. These names can be
    interpolated in the ``model`` section of the config file, for example:

    .. code-block:: yaml

        training_module:
          _target_: nequip.train.EMALightningModule

          # other `EMALightningModule` arguments

          model:
            _target_: nequip.model.NequIPGNNModel

            # other model hyperparameters

            avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
            per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
            per_type_energy_scales: ${training_data_stats:forces_rms}
            # or alternatively the per-type forces RMS
            # per_type_energy_scales: ${training_data_stats:per_type_forces_rms}

    Args:
        dataloader_kwargs (dict): passed on to :class:`~nequip.data.DataStatisticsManager`
        type_names (list): passed on to :class:`~nequip.data.DataStatisticsManager`; required by the ``per_type`` statistics
    """
    # average neighbor count, globally and per atom type
    neighbor_stats = [
        {"name": "num_neighbors_mean", "field": NumNeighbors(), "metric": Mean()},
        {
            "name": "per_type_num_neighbors_mean",
            "field": NumNeighbors(),
            "metric": Mean(),
            "per_type": True,
        },
    ]
    # mean of the total energy divided by the number of atoms
    energy_stats = [
        {
            "name": "per_atom_energy_mean",
            "field": PerAtomModifier(AtomicDataDict.TOTAL_ENERGY_KEY),
            "metric": Mean(),
        },
    ]
    # force RMS, globally and per atom type
    force_stats = [
        {
            "name": "forces_rms",
            "field": AtomicDataDict.FORCE_KEY,
            "metric": RootMeanSquare(),
        },
        {
            "name": "per_type_forces_rms",
            "field": AtomicDataDict.FORCE_KEY,
            "metric": RootMeanSquare(),
            "per_type": True,
        },
    ]
    return DataStatisticsManager(
        neighbor_stats + energy_stats + force_stats,
        dataloader_kwargs=dataloader_kwargs,
        type_names=type_names,
    )
def EnergyOnlyDataStatisticsManager(
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
    type_names: Optional[List[str]] = None,
):
    """:class:`~nequip.data.DataStatisticsManager` wrapper for datasets that contain energies but no forces.

    The statistics computed are ``num_neighbors_mean``, ``per_type_num_neighbors_mean``,
    ``per_atom_energy_mean``, ``per_atom_energy_std``, and ``total_energy_std``.
    ``per_type_num_neighbors_mean`` is a ``per_type`` statistic, so ``type_names`` must be
    provided. These names can be interpolated in the ``model`` section of the config file,
    for example:

    .. code-block:: yaml

        training_module:
          _target_: nequip.train.EMALightningModule

          # other `EMALightningModule` arguments

          model:
            _target_: nequip.model.NequIPGNNModel
            do_derivatives: false

            # other model hyperparameters

            avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
            per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
            per_type_energy_scales: ${training_data_stats:total_energy_std}

    Args:
        dataloader_kwargs (dict): passed on to :class:`~nequip.data.DataStatisticsManager`
        type_names (list): passed on to :class:`~nequip.data.DataStatisticsManager`; required by the ``per_type`` statistic
    """
    # average neighbor count, globally and per atom type
    neighbor_stats = [
        {"name": "num_neighbors_mean", "field": NumNeighbors(), "metric": Mean()},
        {
            "name": "per_type_num_neighbors_mean",
            "field": NumNeighbors(),
            "metric": Mean(),
            "per_type": True,
        },
    ]
    # mean and standard deviation of the total energy divided by the number of atoms;
    # a fresh field and metric object is built for each entry
    per_atom_energy_stats = [
        {
            "name": stat_name,
            "field": PerAtomModifier(AtomicDataDict.TOTAL_ENERGY_KEY),
            "metric": metric_cls(),
        }
        for stat_name, metric_cls in (
            ("per_atom_energy_mean", Mean),
            ("per_atom_energy_std", StandardDeviation),
        )
    ]
    # spread of the raw total energy
    total_energy_stats = [
        {
            "name": "total_energy_std",
            "field": AtomicDataDict.TOTAL_ENERGY_KEY,
            "metric": StandardDeviation(),
        },
    ]
    return DataStatisticsManager(
        neighbor_stats + per_atom_energy_stats + total_energy_stats,
        dataloader_kwargs=dataloader_kwargs,
        type_names=type_names,
    )