# Source code for nequip.data.transforms.neighborlist
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from nequip.data import AtomicDataDict, compute_neighborlist_
from nequip.data._nl import DEFAULT_NEIGHBORLIST_BACKEND
from nequip.data._key_registry import get_field_type
from typing import Optional, Dict, Union, List
class NeighborListPruneTransform(torch.nn.Module):
    """Drop edges from an existing neighborlist using per-edge-type cutoffs.

    An ``EdgeLengthNormalizer`` is built from the constructor arguments and is
    used in :meth:`forward` to decide which edges survive.

    Args:
        r_max (float): global cutoff radius
        per_edge_type_cutoff (Dict): per-edge-type cutoffs (must be <= r_max)
        type_names (List[str]): list of atom type names
    """

    def __init__(
        self,
        r_max: float,
        per_edge_type_cutoff: Dict[str, Union[float, Dict[str, float]]],
        type_names: List[str],
    ):
        super().__init__()

        # imported inside the constructor rather than at module level
        # (presumably to avoid a circular import -- TODO confirm)
        from nequip.nn.embedding import EdgeLengthNormalizer

        # keep the raw configuration around as plain attributes
        self.r_max = r_max
        self.per_edge_type_cutoff = per_edge_type_cutoff
        self.type_names = type_names

        # the normalizer provides the per-edge normalized lengths used to
        # build the pruning mask in `forward`
        self._normalizer = EdgeLengthNormalizer(
            r_max=self.r_max,
            type_names=type_names,
            per_edge_type_cutoff=per_edge_type_cutoff,
        )

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Prune the neighborlist in ``data`` and return the same dict.

        Edges whose normalized length (as produced by the normalizer) is
        ``<= 1.0`` are kept; all others are removed from the edge index and
        from every other field registered with field type ``"edge"``.
        ``data`` is modified in place.

        Raises:
            KeyError: if the atom type field is missing from ``data``.
        """
        atom_type_key = AtomicDataDict.ATOM_TYPE_KEY
        if atom_type_key not in data:
            raise KeyError(
                f"Per-edge-type cutoffs require '{AtomicDataDict.ATOM_TYPE_KEY}' to be present in the data. "
                "This is likely because nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper needs to be applied before "
                "neighborlist pruning. Please check your data transform order."
            )

        # run the normalizer on a shallow copy so that whatever tensors it
        # adds end up in the copy and not in the caller's dict
        normalized = self._normalizer(data.copy())
        norm_lengths = normalized[AtomicDataDict.NORM_LENGTH_KEY].view(-1)
        # presumably a normalized length of at most 1 means "within the
        # cutoff for this edge type" -- verify against EdgeLengthNormalizer
        keep = norm_lengths <= 1.0

        # the edge index is the only edge field indexed along dim 1
        # (shape [2, num_edges]), so it gets its own treatment
        edge_index_key = AtomicDataDict.EDGE_INDEX_KEY
        data[edge_index_key] = data[edge_index_key][:, keep]

        # every remaining field registered as "edge" is masked on dim 0
        other_edge_fields = [
            name
            for name in list(data.keys())
            if name != edge_index_key
            and get_field_type(name, error_on_unregistered=False) == "edge"
        ]
        for name in other_edge_fields:
            data[name] = data[name][keep]

        return data
class SortedNeighborListTransform(NeighborListTransform):
    """Behaves like :class:`NeighborListTransform` but additionally sorts the neighborlist and provides transpose permutation indices."""

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Build the neighborlist, sort it, and attach the transpose permutation.

        After the parent transform has run, edges are reordered by the
        combined key ``edge_index[0] * num_nodes + edge_index[1]``. Cell
        shifts, when present, are reordered the same way. The permutation
        that sorts the reordered edges by
        ``edge_index[1] * num_nodes + edge_index[0]`` is stored under
        ``AtomicDataDict.EDGE_TRANSPOSE_PERM_KEY``.

        If the neighborlist has no edges, the parent's output is returned
        unchanged.
        """
        # let the parent class produce the (unsorted) neighborlist
        data = super().forward(data)

        index_key = AtomicDataDict.EDGE_INDEX_KEY
        shift_key = AtomicDataDict.EDGE_CELL_SHIFT_KEY

        unsorted_index = data[index_key]

        # nothing to sort when there are no edges
        # NOTE(review): the transpose permutation key is not set on this
        # path -- confirm that downstream consumers tolerate its absence
        if unsorted_index.numel() == 0:
            return data

        n_nodes = AtomicDataDict.num_nodes(data)

        # row 0 is the major sort index and row 1 the minor one; edges that
        # share the same index pair have no guaranteed relative order
        order = torch.argsort(unsorted_index[0] * n_nodes + unsorted_index[1])

        sorted_index = unsorted_index.index_select(1, order)
        data[index_key] = sorted_index

        # cell shifts are per-edge along dim 0 and must follow the same order
        if shift_key in data:
            data[shift_key] = data[shift_key].index_select(0, order)

        # transpose permutation: argsort of the sorted edges with the roles
        # of the two rows swapped (row 1 major, row 0 minor)
        transpose_key = sorted_index[1] * n_nodes + sorted_index[0]
        data[AtomicDataDict.EDGE_TRANSPOSE_PERM_KEY] = torch.argsort(transpose_key)

        return data