# Source code for nequip.data._nl

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from dataclasses import dataclass
from typing import Callable, Dict, Final, List, Optional, Union, Tuple

import numpy as np
import packaging.version

import torch
from . import AtomicDataDict

import ase.neighborlist
from matscipy.neighbours import neighbour_list as matscipy_nl


# `vesin` is an optional dependency: if it cannot be imported, `vesin_nl` is left
# undefined and the failure is deliberately ignored here, so the other
# neighborlist backends remain usable without it.
try:
    from vesin import NeighborList as vesin_nl
except ImportError:
    pass

# module-level cache for the `nvalchemiops` batched cell-list function;
# populated on first use by `_load_alchemiops_nl`
alchemiops_nl = None


def _load_alchemiops_nl() -> Callable:
    """Return the `nvalchemiops` ``batch_cell_list`` function, importing it lazily.

    The resolved function is stored in the module-level ``alchemiops_nl`` so the
    version lookup and import only happen on the first call.

    Returns:
        the ``batch_cell_list`` callable provided by the installed `nvalchemiops`.

    Raises:
        ImportError: if `nvalchemiops` is not installed.
    """
    global alchemiops_nl

    if alchemiops_nl is None:
        from nequip.utils.versions.version_utils import get_version_safe

        installed = get_version_safe("nvalchemiops")
        if installed is None:
            raise ImportError(
                "`nvalchemiops` is not installed. Install it with: pip install nvalchemiops"
            )

        # the import path of `batch_cell_list` differs before and from 0.3.0
        parse = packaging.version.parse
        if parse(installed) < parse("0.3.0"):
            from nvalchemiops.neighborlist import batch_cell_list as resolved
        else:
            from nvalchemiops.torch.neighbors import batch_cell_list as resolved

        alchemiops_nl = resolved

    return alchemiops_nl


# Names of the built-in neighborlist backends; "matscipy" is used as the default.
# NOTE:
# - vesin and matscipy do not support self-interaction
# - vesin does not allow for mixed pbcs (different periodicity on different axes)
NEIGHBORLIST_BACKEND_ASE: Final[str] = "ase"
NEIGHBORLIST_BACKEND_MATSCIPY: Final[str] = "matscipy"
NEIGHBORLIST_BACKEND_VESIN: Final[str] = "vesin"
NEIGHBORLIST_BACKEND_ALCHEMIOPS: Final[str] = "alchemiops"
DEFAULT_NEIGHBORLIST_BACKEND: Final[str] = NEIGHBORLIST_BACKEND_MATSCIPY


@dataclass(frozen=True)
class NeighborlistBackendSpec:
    """Immutable description of one neighborlist backend: its callable and device flags."""

    # backend callable, invoked as ``fn(data, r_max)`` and returning the data dict
    fn: Callable[[AtomicDataDict.Type, float], AtomicDataDict.Type]
    # whether the backend is declared to support CPU execution
    supports_cpu: bool = True
    # whether the backend is declared to support CUDA execution
    supports_cuda: bool = False


def _compute_neighborlist_single_frame(
    pos: torch.Tensor,
    r_max: float,
    backend: str,
    cell: Optional[torch.Tensor] = None,
    pbc: Optional[Union[bool, Tuple[bool, bool, bool], torch.Tensor]] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Internal function to create neighbor list and cell shifts based on radial cutoff.

    Note: This is a private function. Users should use ``compute_neighborlist_`` instead.

    Edges are given by the following convention:
    - ``edge_index[0]`` is the *source* (convolution center).
    - ``edge_index[1]`` is the *target* (neighbor).

    The neighbor search itself runs on CPU through NumPy; the outputs are moved back
    to the device of ``pos``.

    Args:
        pos (torch.Tensor shape [N, 3]): Positional coordinates.
        r_max (float): Radial cutoff distance for neighbor finding.
        backend (str): Neighborlist backend to use.
        cell (torch.Tensor shape [3, 3] or None): Cell for periodic boundary conditions. Required if any ``pbc`` is True.
        pbc (bool or 3-tuple of bool or torch.Tensor or None): Whether the system is periodic in each of the three cell dimensions. ``None`` is treated as non-periodic.

    Returns:
        edge_index (torch.tensor shape [2, num_edges]): List of edges.
        edge_cell_shift (torch.tensor shape [num_edges, 3]): Relative cell shift vectors.

    Raises:
        ValueError: if ``pbc`` does not describe exactly three axes, if periodicity is
            requested without a cell, if mixed periodicity is requested from the
            vesin backend, or if ``backend`` is unknown.
        ImportError: if the vesin backend is requested but `vesin` is not installed.
    """
    # imported explicitly: `ase.geometry` is otherwise only reachable as a side
    # effect of `import ase.neighborlist`
    from ase.geometry import complete_cell

    # normalize `pbc` to a 3-tuple of Python bools
    if pbc is None:
        # missing PBC information means a non-periodic system
        pbc = (False, False, False)
    elif isinstance(pbc, (bool, np.bool_)):
        pbc = (bool(pbc),) * 3
    elif isinstance(pbc, torch.Tensor):
        # convert tensor to tuple for backends (handles GPU tensors)
        pbc = tuple(bool(p) for p in pbc.detach().cpu().reshape(-1).tolist())
    else:
        pbc = tuple(bool(p) for p in pbc)
    if len(pbc) != 3:
        raise ValueError(f"`pbc` must have exactly 3 entries, got {len(pbc)}")

    # get device and dtype from position tensor
    out_device = pos.device
    out_dtype = pos.dtype

    # convert to numpy for neighborlist backends
    temp_pos = pos.detach().cpu().numpy()

    # get cell and complete with ASE utils
    if cell is not None:
        temp_cell = cell.detach().cpu().numpy()
    else:
        # no cell provided, check that PBC is not requested
        if any(pbc):
            raise ValueError(
                "Periodic boundary conditions requested but no cell was provided."
            )
        temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype)
    temp_cell = complete_cell(temp_cell)

    if backend == NEIGHBORLIST_BACKEND_VESIN:
        # `vesin` is optional; the module-level import leaves `vesin_nl` undefined
        # when the package is missing, so look it up defensively
        vesin_cls = globals().get("vesin_nl")
        if vesin_cls is None:
            raise ImportError(
                "`vesin` is not installed. Install it with: pip install vesin"
            )

        # use same mixed pbc logic as
        # https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py
        if all(pbc):
            periodic = True
        elif not any(pbc):
            periodic = False
        else:
            raise ValueError(
                f"different periodic boundary conditions on different axes are not supported by `{NEIGHBORLIST_BACKEND_VESIN}` neighborlist, use `{NEIGHBORLIST_BACKEND_ASE}` or `{NEIGHBORLIST_BACKEND_MATSCIPY}`"
            )

        first_idex, second_idex, shifts = vesin_cls(
            cutoff=float(r_max), full_list=True
        ).compute(points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS")
        # vesin returns uint64
        first_idex = first_idex.astype(np.int64)
        second_idex = second_idex.astype(np.int64)

    elif backend == NEIGHBORLIST_BACKEND_MATSCIPY:
        first_idex, second_idex, shifts = matscipy_nl(
            "ijS",
            pbc=pbc,
            cell=temp_cell,
            positions=temp_pos,
            cutoff=float(r_max),
        )
    elif backend == NEIGHBORLIST_BACKEND_ASE:
        first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
            "ijS",
            pbc,
            temp_cell,
            temp_pos,
            cutoff=float(r_max),
            self_interaction=False,
            use_scaled_positions=False,
        )
    else:
        raise ValueError(f"Unknown neighborlist backend = `{backend}`")

    # construct output to return
    edge_index = torch.vstack(
        (
            torch.as_tensor(first_idex, dtype=torch.long),
            torch.as_tensor(second_idex, dtype=torch.long),
        )
    ).to(device=out_device)
    shifts = torch.as_tensor(
        shifts,
        dtype=out_dtype,
        device=out_device,
    )
    return edge_index, shifts


def _compute_neighborlist_unbatched_backend(
    data: AtomicDataDict.Type, r_max: float, backend: str
) -> AtomicDataDict.Type:
    """Compute the neighborlist frame by frame with a backend that has no batch support.

    Each frame is extracted, passed to ``_compute_neighborlist_single_frame``, and the
    per-frame results are re-batched if (and only if) the input was batched.

    Args:
        data: input AtomicDataDict, batched or unbatched.
        r_max: cutoff radius.
        backend: name of the single-frame backend to use.

    Returns:
        data with ``edge_index`` and (if the input has a cell) ``edge_cell_shift`` set,
        in the same batched/unbatched state as the input.
    """
    _data_is_batched = AtomicDataDict.BATCH_KEY in data
    # loop-invariant: cell shifts are only stored when the input carries a cell
    has_cell = data.get(AtomicDataDict.CELL_KEY, None) is not None

    to_batch: List[AtomicDataDict.Type] = []
    for idx in range(AtomicDataDict.num_frames(data)):
        # if data is unbatched, `frame_from_batched` should just be no-op
        data_per_frame = AtomicDataDict.frame_from_batched(data, idx)

        cell = data_per_frame.get(AtomicDataDict.CELL_KEY, None)
        if cell is not None:
            cell = cell.view(3, 3)  # remove batch dimension

        pbc = data_per_frame.get(AtomicDataDict.PBC_KEY, None)
        if pbc is None:
            # no PBC information means a non-periodic frame; passing `None` on
            # would fail when the single-frame routine indexes into `pbc`
            pbc = False
        else:
            pbc = pbc.view(3)  # remove batch dimension

        edge_index, edge_cell_shift = _compute_neighborlist_single_frame(
            pos=data_per_frame[AtomicDataDict.POSITIONS_KEY],
            r_max=r_max,
            cell=cell,
            pbc=pbc,
            backend=backend,
        )
        # add neighborlist information
        data_per_frame[AtomicDataDict.EDGE_INDEX_KEY] = edge_index
        if has_cell and edge_cell_shift is not None:
            data_per_frame[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = edge_cell_shift
        to_batch.append(data_per_frame)

    # the following ensures that we preserve the batch state
    # i.e. unbatched input -> unbatched output; batched input -> batched output
    if _data_is_batched:
        # rebatch to make sure neighborlist information is in a similar batched format
        return AtomicDataDict.batched_from_list(to_batch)
    else:
        assert len(to_batch) == 1
        return to_batch[0]


def alchemiops_batch_cell_list(
    data: AtomicDataDict.Type,
    r_max: float,
) -> AtomicDataDict.Type:
    """Compute a neighbor list using Alchemiops cell list algorithm.

    Args:
        data: input AtomicDataDict.
        r_max: cutoff radius.

    Returns:
        data with neighborlist entries added in-place.
        Only ``edge_index`` and (if ``cell`` exists) ``edge_cell_shift`` are modified.

    Raises:
        ImportError: if `nvalchemiops` is not installed.
        ValueError: if periodic boundary conditions are requested but ``data`` has no cell.
    """
    alchemiops_fn = _load_alchemiops_nl()

    positions = data[AtomicDataDict.POSITIONS_KEY]
    device = positions.device

    # handle batching
    if AtomicDataDict.is_batched(data):
        system_idx = data[AtomicDataDict.BATCH_KEY].to(torch.int32)
        n_systems = AtomicDataDict.num_frames(data)
    else:
        system_idx = torch.zeros(1, dtype=torch.int32, device=device).expand(
            positions.shape[0]
        )
        n_systems = 1

    has_cell = AtomicDataDict.CELL_KEY in data
    has_pbc = AtomicDataDict.PBC_KEY in data

    # default to zero cell if cell not present
    if has_cell:
        cell = data[AtomicDataDict.CELL_KEY]
    else:
        cell = torch.zeros((n_systems, 3, 3), dtype=positions.dtype, device=device)

    # default to no PBCs if PBCs not present
    if has_pbc:
        pbc = data[AtomicDataDict.PBC_KEY]
    else:
        pbc = torch.zeros((n_systems, 3), dtype=torch.bool, device=device)

    # a periodic system needs a real cell; fail the same way the single-frame
    # backends do instead of running with the all-zero placeholder cell
    if has_pbc and not has_cell and bool(pbc.any()):
        raise ValueError(
            "Periodic boundary conditions requested but no cell was provided."
        )

    # adapted from https://github.com/TorchSim/torch-sim/blob/main/torch_sim/neighbors/alchemiops.py
    # for non-periodic systems with zero cells, use a nominal identity cell
    # to avoid division by zero in alchemiops warp kernels
    # See https://github.com/NVIDIA/nvalchemi-toolkit-ops/issues/4
    is_non_periodic = ~pbc.any(dim=1)  # [n_systems]
    is_zero_cell = cell.abs().sum(dim=(1, 2)) == 0  # [n_systems]
    needs_nominal_cell = is_non_periodic & is_zero_cell
    if needs_nominal_cell.any():
        identity = torch.eye(3, dtype=cell.dtype, device=cell.device)
        cell = cell.clone()  # Avoid modifying the original
        cell[needs_nominal_cell] = identity

    # call alchemiops cell list
    # nvalchemiops uses `positions.device` to select where neighborlist construction runs
    res = alchemiops_fn(
        positions=positions,
        cutoff=float(r_max),
        batch_idx=system_idx,
        cell=cell,
        pbc=pbc,
        return_neighbor_list=True,
    )

    # parse results: (neighbor_list, neighbor_ptr[, neighbor_list_shifts])
    if len(res) == 3:  # type: ignore[arg-type]
        edge_index, _, edge_cell_shift = res  # type: ignore[misc]
    else:
        edge_index, _ = res  # type: ignore[misc]
        edge_cell_shift = None

    # populate data dict with neighborlist
    data[AtomicDataDict.EDGE_INDEX_KEY] = edge_index.to(dtype=torch.long)
    if has_cell:
        if edge_cell_shift is None:
            # no shifts were returned: every edge stays within the original cell
            edge_cell_shift = torch.zeros(
                (edge_index.shape[1], 3), dtype=cell.dtype, device=device
            )
        data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = edge_cell_shift.to(dtype=cell.dtype)
    return data


def _unbatched_backend_spec(backend: str) -> NeighborlistBackendSpec:
    """Build the CPU-only spec for a backend that is evaluated one frame at a time."""

    def _run(data, r_max):
        return _compute_neighborlist_unbatched_backend(
            data=data, r_max=r_max, backend=backend
        )

    return NeighborlistBackendSpec(fn=_run, supports_cpu=True, supports_cuda=False)


# built-in backends, in the order: ase, matscipy, vesin, alchemiops
_DEFAULT_NEIGHBORLIST_BACKEND_OPTIONS: Final[Dict[str, NeighborlistBackendSpec]] = {
    **{
        name: _unbatched_backend_spec(name)
        for name in (
            NEIGHBORLIST_BACKEND_ASE,
            NEIGHBORLIST_BACKEND_MATSCIPY,
            NEIGHBORLIST_BACKEND_VESIN,
        )
    },
    NEIGHBORLIST_BACKEND_ALCHEMIOPS: NeighborlistBackendSpec(
        fn=alchemiops_batch_cell_list,
        supports_cpu=True,
        supports_cuda=True,
    ),
}

# mutable registry of available backends, seeded with a shallow copy of the defaults
NEIGHBORLIST_BACKEND_OPTIONS: Dict[str, NeighborlistBackendSpec] = {
    **_DEFAULT_NEIGHBORLIST_BACKEND_OPTIONS
}


def register_neighborlist_backend(
    backend: str,
    fn: Callable[[AtomicDataDict.Type, float], AtomicDataDict.Type],
    supports_cpu: bool = True,
    supports_cuda: bool = False,
    overwrite: bool = False,
) -> None:
    """Register a neighborlist backend callable.

    Args:
        backend (str): name for the backend.
        fn (Callable): backend function with signature ``fn(data, r_max) -> data``.
        supports_cpu (bool): whether the backend supports CPU execution.
        supports_cuda (bool): whether the backend supports CUDA execution.
        overwrite (bool): whether to replace an existing backend with the same name.

    Raises:
        ValueError: if ``backend`` is not a non-empty string, or if the name is
            already registered and ``overwrite`` is False.
        TypeError: if ``fn`` is not callable.

    Notes:
        Backend function contract:
        - batched input must return batched output; unbatched input must return unbatched output.
        - output tensors must stay on the same device as input tensors.
        - existing tensors in ``data`` must be preserved; mutation is limited to neighborlist outputs.
    """
    if not (isinstance(backend, str) and backend):
        raise ValueError("`backend` must be a non-empty string")
    if not callable(fn):
        raise TypeError("`fn` must be callable")
    if not overwrite and backend in NEIGHBORLIST_BACKEND_OPTIONS:
        raise ValueError(
            f"Neighborlist backend `{backend}` already registered. Set `overwrite=True` to replace it."
        )

    spec = NeighborlistBackendSpec(
        fn=fn,
        supports_cpu=supports_cpu,
        supports_cuda=supports_cuda,
    )
    NEIGHBORLIST_BACKEND_OPTIONS[backend] = spec


def compute_neighborlist_(
    data: AtomicDataDict.Type,
    r_max: float,
    backend: str = DEFAULT_NEIGHBORLIST_BACKEND,
) -> AtomicDataDict.Type:
    """Add a neighborlist to `data` in-place.

    Looks up ``backend`` in the backend registry and delegates to its callable.

    Contract:
    - batched input -> batched output;
    - unbatched input -> unbatched output;
    - output tensors are on the same device as input tensors.

    Raises:
        ValueError: if ``backend`` is not a registered backend name.
    """
    spec = NEIGHBORLIST_BACKEND_OPTIONS.get(backend)
    if spec is None:
        supported = ", ".join([f"`{b}`" for b in NEIGHBORLIST_BACKEND_OPTIONS])
        raise ValueError(
            f"Unknown neighborlist backend = `{backend}`. Supported backends: {supported}"
        )
    return spec.fn(data, r_max)