Data Handling

Extension packages can implement custom data handling by subclassing NequIP’s base classes and registering custom data fields.

Data Key Registration

NequIP requires every data field to be registered by type so that it knows how to process the field, for example how to combine it across frames when batching.

Field Types

  • Graph fields: per-frame (e.g., total_energy, stress)

  • Node fields: per-atom (e.g., forces, positions)

  • Edge fields: per-edge (e.g., edge_vectors)

  • Long fields: integer dtype (e.g., atomic_numbers)

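  • Cartesian tensors: fields that transform as Cartesian tensors under rotation (e.g., stress). Categories can overlap: stress is both a graph field and a Cartesian tensor.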
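For the per-frame, per-atom, and per-edge categories, the practical difference is which count indexes a field's first dimension once frames are batched together. The sketch below is a conceptual aid in plain Python that assumes this layout; the LEADING_DIM table and the expected_leading_dim helper are hypothetical and not part of NequIP.

```python
from typing import Dict

# Conceptual sketch, not NequIP code: which count indexes dimension 0 of a field,
# depending on whether it is per-frame, per-atom or per-edge. Long and Cartesian
# tensor fields are left out because those categories describe dtype and
# rotational behavior rather than what a field is indexed by.
LEADING_DIM: Dict[str, str] = {
    "graph": "num_frames",  # e.g. total_energy: one row per frame
    "node": "num_atoms",    # e.g. forces: one row per atom
    "edge": "num_edges",    # e.g. edge_vectors: one row per edge
}


def expected_leading_dim(kind: str, num_frames: int, num_atoms: int, num_edges: int) -> int:
    """Hypothetical helper: expected size of dimension 0 for a field of this kind."""
    counts = {"num_frames": num_frames, "num_atoms": num_atoms, "num_edges": num_edges}
    if kind not in LEADING_DIM:
        raise ValueError(f"unknown field kind: {kind!r}")
    return counts[LEADING_DIM[kind]]


# A batch of 2 frames containing 5 atoms and 12 edges in total:
for kind in ("graph", "node", "edge"):
    print(kind, expected_leading_dim(kind, num_frames=2, num_atoms=5, num_edges=12))
# graph 2
# node 5
# edge 12
```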

Registration Example

from nequip.data import register_fields

register_fields(
    graph_fields=["custom_energy"],
    node_fields=["custom_forces"],
    edge_fields=["custom_edge_attr"],
    long_fields=["custom_indices"],
    cartesian_tensor_fields={"custom_tensor": "ij=ji"}
)
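The formula string uses e3nn's Cartesian tensor notation: "ij=ji" declares a rank-2 tensor equal to its own transpose, i.e. a symmetric 3×3 tensor such as stress. The plain-Python check below only spells out that index symmetry on a nested list; is_symmetric_ij is a made-up illustrative helper, not something NequIP or e3nn provide, and real data would be tensors.

```python
from typing import Sequence


def is_symmetric_ij(t: Sequence[Sequence[float]], tol: float = 1e-12) -> bool:
    """Check the index symmetry expressed by the formula "ij=ji": t[i][j] == t[j][i]."""
    n = len(t)
    if any(len(row) != n for row in t):
        return False  # not square, so it cannot be an "ij=ji" tensor
    return all(abs(t[i][j] - t[j][i]) <= tol for i in range(n) for j in range(n))


stress_like = [[1.0, 0.2, 0.0],
               [0.2, 2.0, 0.1],
               [0.0, 0.1, 3.0]]
print(is_symmetric_ij(stress_like))  # True
```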

Register custom fields in your package’s __init__.py, so that registration runs when the package is imported and before the fields are used in:

  • Dataset loading

  • Loss functions (MetricsManager)

  • Model outputs

Built-in fields are pre-registered. See the API reference below for the complete details.

API Reference

nequip.data.register_fields(graph_fields: Sequence[str] | None = None, node_fields: Sequence[str] | None = None, edge_fields: Sequence[str] | None = None, long_fields: Sequence[str] | None = None, cartesian_tensor_fields: Dict[str, str] | None = None) -> None

Register custom fields as being per-frame, per-atom, per-edge, long dtype and/or Cartesian tensors.

Parameters:
  • graph_fields (Sequence[str]) – per-frame fields

  • node_fields (Sequence[str]) – per-atom fields

  • edge_fields (Sequence[str]) – per-edge fields

  • long_fields (Sequence[str]) – long dtype fields

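  • cartesian_tensor_fields (Dict[str, str]) – Cartesian tensor fields, given as a mapping from each field name to its e3nn Cartesian tensor formula (e.g. "ij=ji" for a symmetric rank-2 tensor); both the name and the formula must be provided. See the e3nn documentation for the formula syntax.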

nequip.data.deregister_fields(*fields: Sequence[str]) -> None

Deregister fields registered with register_fields().

Silently ignores fields that were never registered to begin with.

Parameters:
  • *fields (Sequence[str]) – fields to deregister.
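Because deregister_fields() silently ignores names that were never registered, it pairs naturally with register_fields() for temporary registrations, for example in an extension package's tests. The sketch below is a hypothetical helper (temporarily_registered is not part of NequIP); it takes the register and deregister functions as arguments so that it runs standalone.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Sequence, Union

FieldGroup = Union[Sequence[str], Dict[str, str]]


@contextmanager
def temporarily_registered(
    register: Callable[..., None],
    deregister: Callable[..., None],
    **fields: FieldGroup,
) -> Iterator[None]:
    """Hypothetical helper: register fields on entry, deregister them on exit.

    fields uses the keyword names of register_fields(), e.g. graph_fields=[...]
    or cartesian_tensor_fields={...}. Only pass custom field names, since every
    name registered here is deregistered again on exit.
    """
    # iterating a dict (cartesian_tensor_fields) yields its keys, i.e. the names
    names = [name for group in fields.values() for name in group]
    register(**fields)
    try:
        yield
    finally:
        # runs even if the body raises; deregister_fields() takes names positionally
        deregister(*names)
```

With NequIP installed, a test could import register_fields and deregister_fields from nequip.data and wrap its body in with temporarily_registered(register_fields, deregister_fields, graph_fields=["custom_energy"]): so that the custom field does not leak into other tests.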

Custom Datasets

Extension packages can implement custom datasets by subclassing NequIP’s base dataset classes to handle custom data formats and sources.

See AtomicDataset for the base dataset class and the dataset API documentation for examples.

Custom DataModules

Extension packages can create custom DataModules to handle specific benchmark datasets or complex data workflows. A DataModule manages train/val/test splits and dataset downloading and preprocessing, and it coordinates datasets, transforms, dataloaders, and statistics.

See NequIPDataModule for the base datamodule class and the datamodule API documentation for examples of dataset-specific implementations.

Data Transforms

Extension packages can implement custom data transforms to preprocess data during loading. Transforms are classes that implement a __call__ method to modify AtomicDataDict objects.

See the transforms API documentation for available transform classes and their patterns.
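A transform in this sense is a callable object that receives the data dictionary (field names mapped to tensors) and returns it. The sketch below, ScaleFieldTransform, is a hypothetical class rather than one of NequIP's transforms: it rescales one field by a constant, such as converting energies between units. The usage line passes a plain float instead of a tensor so that the snippet runs without PyTorch.

```python
from typing import Any, Dict


class ScaleFieldTransform:
    """Hypothetical transform: multiply one field of a data dict by a constant."""

    def __init__(self, field: str, factor: float) -> None:
        self.field = field
        self.factor = factor

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # frames that lack the field pass through unchanged
        if self.field in data:
            data[self.field] = data[self.field] * self.factor
        return data


# e.g. convert an energy from eV to meV (plain float standing in for a tensor)
transform = ScaleFieldTransform("total_energy", 1000.0)
print(transform({"total_energy": -1.5}))  # {'total_energy': -1500.0}
```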

Custom Neighborlist Backends

Neighborlist backends are registered by name with register_neighborlist_backend() and can then be selected from NeighborListTransform via backend=....

The registration schema is:

  • backend: string key for the backend.

  • fn: callable with signature fn(data, r_max) -> data.

  • supports_cpu / supports_cuda: backend device capability flags.

Backend function contract:

  • batched input must return batched output; unbatched input must return unbatched output.

  • output tensors must stay on the same device as input tensors.

  • existing tensors in data must be preserved; mutation is limited to adding or updating neighborlist outputs.

Example:

import torch
from nequip.data import (
    AtomicDataDict,
    register_neighborlist_backend,
)

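def my_backend(data: AtomicDataDict.Type, r_max: float) -> AtomicDataDict.Type:
    # compute edge_index and edge_cell_shift here; this placeholder returns an empty neighborlist
    pos = data[AtomicDataDict.POSITIONS_KEY]
    # create outputs on the same device as the input, as the backend contract requires
    data[AtomicDataDict.EDGE_INDEX_KEY] = torch.empty((2, 0), dtype=torch.long, device=pos.device)
    data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = torch.empty((0, 3), dtype=pos.dtype, device=pos.device)
    return data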

register_neighborlist_backend(
    backend="my_backend",
    fn=my_backend,
    supports_cpu=True,
    supports_cuda=False,
)
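The my_backend example above returns an empty neighborlist. As a sketch of what the computation could look like, the function below does an O(N²) all-pairs search for a single unbatched, non-periodic frame. It uses plain lists instead of tensors so that it runs without PyTorch, and it assumes the string keys "pos", "edge_index" and "edge_cell_shift" behind AtomicDataDict.POSITIONS_KEY, EDGE_INDEX_KEY and EDGE_CELL_SHIFT_KEY. A real backend would work on tensors on the input's device and handle periodic cells and batched input.

```python
import math
from typing import Any, Dict, List, Tuple


def brute_force_backend(data: Dict[str, Any], r_max: float) -> Dict[str, Any]:
    """Hypothetical all-pairs neighborlist for one non-periodic frame.

    Adds an edge in both directions for every pair of distinct atoms strictly
    closer than r_max and leaves all other entries of data untouched.
    """
    pos: List[List[float]] = data["pos"]
    n = len(pos)
    pairs: List[Tuple[int, int]] = [
        (i, j)
        for i in range(n)
        for j in range(n)
        if i != j and math.dist(pos[i], pos[j]) < r_max
    ]
    # edge_index has shape (2, num_edges): first row and second row of each pair
    data["edge_index"] = [[i for i, _ in pairs], [j for _, j in pairs]]
    # no periodic images are considered, so every edge has a zero cell shift
    data["edge_cell_shift"] = [[0.0, 0.0, 0.0] for _ in pairs]
    return data


# three atoms on a line at x = 0, 1, 3; only atoms 0 and 1 are within 1.5
out = brute_force_backend({"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]}, 1.5)
print(out["edge_index"])  # [[0, 1], [1, 0]]
```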

API Reference

nequip.data.register_neighborlist_backend(backend: str, fn: Callable[[Dict[str, Tensor], float], Dict[str, Tensor]], supports_cpu: bool = True, supports_cuda: bool = False, overwrite: bool = False) -> None

Register a neighborlist backend callable.

Parameters:
  • backend (str) – name for the backend.

  • fn (Callable) – backend function with signature fn(data, r_max) -> data.

  • supports_cpu (bool) – whether the backend supports CPU execution.

  • supports_cuda (bool) – whether the backend supports CUDA execution.

  • overwrite (bool) – whether to replace an existing backend with the same name.

Notes

Backend function contract:

  • batched input must return batched output; unbatched input must return unbatched output.

  • output tensors must stay on the same device as input tensors.

  • existing tensors in data must be preserved; mutation is limited to neighborlist outputs.
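The last rule of the contract can be checked mechanically, for example in an extension package's test suite. The sketch below is a hypothetical helper (check_preserves_inputs is not a NequIP function): it runs a backend on a shallow copy of the input and verifies that no non-neighborlist entry was removed or replaced and that no unexpected keys were added. The default neighborlist output names are assumed to match AtomicDataDict.EDGE_INDEX_KEY and EDGE_CELL_SHIFT_KEY, and the check compares object identity only, so it does not detect in-place edits to a tensor's contents.

```python
from typing import Any, Callable, Dict, Iterable

NeighborlistFn = Callable[[Dict[str, Any], float], Dict[str, Any]]


def check_preserves_inputs(
    fn: NeighborlistFn,
    data: Dict[str, Any],
    r_max: float,
    nl_keys: Iterable[str] = ("edge_index", "edge_cell_shift"),
) -> None:
    """Raise AssertionError if fn drops, replaces, or adds non-neighborlist entries."""
    allowed = set(nl_keys)
    before = dict(data)  # remembers which object each key held before the call
    out = fn(dict(data), r_max)  # shallow copy, so the caller's dict is not mutated
    for key, value in before.items():
        if key in allowed:
            continue  # backends may overwrite existing neighborlist outputs
        if key not in out:
            raise AssertionError(f"backend removed {key!r}")
        if out[key] is not value:
            raise AssertionError(f"backend replaced {key!r}")
    extra = set(out) - set(before) - allowed
    if extra:
        raise AssertionError(f"backend added unexpected keys: {sorted(extra)}")
```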