# Source code for nequip.data.modifier

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
"""
Data statistics and metrics managers work with `BaseModifier` and its subclasses (or classes that mimic their behavior) under the hood.
A modifier maps:
 - `AtomicDataDict` -> `torch.Tensor`  (for data statistics), or
 - `AtomicDataDict`, `AtomicDataDict` -> `torch.Tensor`, `torch.Tensor`  (for metrics)
A modifier should implement
  - `__str__()` for automatic naming,
  - the `type` property for data processing logic, and
  - `__call__()` for its action, or
  - alternatively just `_func()` if the same action is to be applied to both data dicts (when used in metrics)
"""

import torch
from . import AtomicDataDict, _key_registry
from nequip.nn.utils import with_edge_vectors_
from typing import Optional, Union, List


class BaseModifier:
    """Read a single field out of one or two ``AtomicDataDict`` objects.

    Args:
        field (str): key of the field to look up in the data dict(s)
    """

    def __init__(self, field: str) -> None:
        self.field = field

    def _func(self, data: AtomicDataDict.Type) -> torch.Tensor:
        """Return the entry of ``data`` stored under ``self.field``."""
        return data[self.field]

    def __call__(
        self, data1: AtomicDataDict.Type, data2: Optional[AtomicDataDict.Type] = None
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Apply ``_func`` to ``data1``, and to ``data2`` as well when it is given.

        Returns:
            the result for ``data1`` alone when ``data2`` is ``None``,
            otherwise the pair ``(_func(data1), _func(data2))``
        """
        first = self._func(data1)
        if data2 is None:
            return first
        return first, self._func(data2)

    def __str__(self) -> str:
        """Name of the modifier: the registered abbreviation of the field, else the field itself."""
        key = self.field
        return _key_registry.ABBREV.get(key, key)

    @property
    def type(self) -> str:
        """Field type of ``self.field`` as reported by the key registry."""
        return _key_registry.get_field_type(self.field)


class PerAtomModifier(BaseModifier):
    """Normalizes a graph field by the number of atoms (nodes) in the graph.

    Args:
        field (str): graph field to be normalized (e.g. ``total_energy``)
        factor (float): optional factor to scale the field by (e.g. for unit conversions, etc)
    """

    def __init__(self, field: str, factor: Optional[float] = None) -> None:
        # only fields registered as graph fields are accepted
        assert field in _key_registry._GRAPH_FIELDS
        super().__init__(field)
        self._factor = factor

    def _func(self, data: AtomicDataDict.Type) -> torch.Tensor:
        """Return ``data[self.field]`` divided by the node count, scaled by ``factor`` if set."""
        # reciprocal of the node counts, flattened to 1D -- (N_graph,)
        inv_num_atoms = data[AtomicDataDict.NUM_NODES_KEY].reciprocal().reshape(-1)
        # multiply along the leading dimension, leaving any trailing dimensions untouched
        per_atom = torch.einsum("n..., n -> n...", data[self.field], inv_num_atoms)
        if self._factor is None:
            return per_atom
        return self._factor * per_atom

    def __str__(self) -> str:
        """Name of the modifier: ``per_atom_`` followed by the field's abbreviation (or the field)."""
        abbrev = _key_registry.ABBREV.get(self.field, self.field)
        return "per_atom_" + abbrev
class MappedFieldModifier(BaseModifier):
    """Get predictions and targets from different fields.

    Args:
        pred_field (str): field read from the first data dict
        target_field (str): field read from the second data dict (if one is given)
    """

    def __init__(self, pred_field: str, target_field: str) -> None:
        super().__init__(pred_field)
        self.pred_field = pred_field
        self.target_field = target_field

        # both fields must be of the same registered field type
        pred_type = _key_registry.get_field_type(self.pred_field)
        target_type = _key_registry.get_field_type(self.target_field)
        assert pred_type == target_type, (
            f"`pred_field` ({self.pred_field}) and `target_field` ({self.target_field}) "
            f"must have the same field type, but got `{pred_type}` and `{target_type}`"
        )
        self._type = pred_type

    def __call__(
        self, data1: AtomicDataDict.Type, data2: Optional[AtomicDataDict.Type] = None
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Return ``data1[pred_field]``, paired with ``data2[target_field]`` when ``data2`` is given."""
        prediction = data1[self.pred_field]
        if data2 is not None:
            return prediction, data2[self.target_field]
        return prediction

    def __str__(self) -> str:
        """Name of the modifier, built from the abbreviations (or names) of both fields."""
        abbreviations = _key_registry.ABBREV
        pred_name = abbreviations.get(self.pred_field, self.pred_field)
        target_name = abbreviations.get(self.target_field, self.target_field)
        return f"pred_{pred_name}_label_{target_name}"

    @property
    def type(self) -> str:
        """Field type shared by the prediction and target fields."""
        return self._type
class EdgeLengths(BaseModifier):
    """Get edge lengths from an ``AtomicDataDict``."""

    def __init__(self) -> None:
        super().__init__(AtomicDataDict.EDGE_INDEX_KEY)

    def _func(self, data: AtomicDataDict.Type) -> torch.Tensor:
        """Return the edge-length entry of the dict produced by ``with_edge_vectors_``."""
        # request lengths from the helper, then read them from the dict it returns
        return with_edge_vectors_(data, with_lengths=True)[
            AtomicDataDict.EDGE_LENGTH_KEY
        ]

    def __str__(self) -> str:
        return "edge_lengths"

    @property
    def type(self) -> str:
        return "edge"
class NumNeighbors(BaseModifier):
    """Get number of neighbors from an ``AtomicDataDict``.

    The returned tensor has one entry per node (its length is that of the
    positions entry); entry ``i`` is the number of times node ``i`` appears in
    the first row of the edge index. Nodes that never appear there get ``0``.
    """

    def __init__(self) -> None:
        super().__init__(AtomicDataDict.EDGE_INDEX_KEY)

    def _func(self, data: AtomicDataDict.Type) -> torch.Tensor:
        """Count, for every node, its occurrences in the first row of the edge index.

        Returns:
            integer tensor of length ``len(data[AtomicDataDict.POSITIONS_KEY])``
        """
        num_nodes = len(data[AtomicDataDict.POSITIONS_KEY])
        nodes, counts = torch.unique(
            data[AtomicDataDict.EDGE_INDEX_KEY][0],
            sorted=True,
            return_counts=True,
        )
        # `torch.unique` only reports nodes that have at least one edge, so
        # `counts` can be shorter than `num_nodes` (e.g. with a small cutoff).
        # Write each count at its own node index rather than zero-padding the
        # tail: tail padding would shift the counts of every node that follows
        # an isolated node, misaligning the result with the node ordering.
        num_neighbors = counts.new_zeros(num_nodes)
        num_neighbors[nodes] = counts
        return num_neighbors

    def __str__(self) -> str:
        return "num_neighbors"

    @property
    def type(self) -> str:
        return "node"