Data Handling¶
Extension packages can implement custom data handling by subclassing NequIP’s base classes and registering custom data fields.
Data Key Registration¶
NequIP requires all data fields to be registered by type for proper processing.
Field Types¶
Graph fields: per-frame (e.g., total_energy, stress)
Node fields: per-atom (e.g., forces, positions)
Edge fields: per-edge (e.g., edge_vectors)
Long fields: integer dtype (e.g., atomic_numbers)
Cartesian tensors: physical tensors (e.g., stress)
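As a rough illustration of what these categories imply for array shapes, the sketch below assumes that the first axis of a field follows its category: number of frames for graph fields, number of atoms for node fields, number of edges for edge fields. The `FIELD_CATEGORY` table and the `expected_leading_dim` helper are hypothetical and not part of NequIP; only the example field names are taken from the list above.

```python
# Hypothetical lookup table for illustration; NequIP keeps its own registry internally.
FIELD_CATEGORY = {
    "total_energy": "graph",
    "stress": "graph",
    "forces": "node",
    "positions": "node",
    "edge_vectors": "edge",
}


def expected_leading_dim(field: str, num_frames: int, num_atoms: int, num_edges: int) -> int:
    """Return the assumed size of the first axis for a field of a known category."""
    category = FIELD_CATEGORY[field]  # raises KeyError for fields missing from the table
    sizes = {"graph": num_frames, "node": num_atoms, "edge": num_edges}
    return sizes[category]


# A batch of 2 frames holding 10 atoms and 48 edges in total.
print(expected_leading_dim("total_energy", 2, 10, 48))
print(expected_leading_dim("forces", 2, 10, 48))
print(expected_leading_dim("edge_vectors", 2, 10, 48))
```

A field that is absent from the table raises `KeyError` here, which mirrors the reason registration matters: without a declared category there is no way to know how a field should be batched.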
Registration Example¶
from nequip.data import register_fields

register_fields(
    graph_fields=["custom_energy"],
    node_fields=["custom_forces"],
    edge_fields=["custom_edge_attr"],
    long_fields=["custom_indices"],
    cartesian_tensor_fields={"custom_tensor": "ij=ji"},
)
Register custom fields in your package’s __init__.py before they are used in:
Dataset loading
Loss functions (MetricsManager)
Model outputs
Built-in fields are pre-registered. See the API reference below for the complete details.
API Reference¶
- nequip.data.register_fields(graph_fields: Sequence[str] | None = None, node_fields: Sequence[str] | None = None, edge_fields: Sequence[str] | None = None, long_fields: Sequence[str] | None = None, cartesian_tensor_fields: Dict[str, str] | None = None) -> None
Register custom fields as being per-frame, per-atom, per-edge, long dtype and/or Cartesian tensors.
- Parameters:
graph_fields (Sequence[str]) – per-frame fields
node_fields (Sequence[str]) – per-atom fields
edge_fields (Sequence[str]) – per-edge fields
long_fields (Sequence[str]) – long dtype fields
cartesian_tensor_fields (Dict[str, str]) – Cartesian tensor fields (both the name and the formula must be provided, e.g. "ij=ji"; see the e3nn docs)
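To make the formula concrete: under the e3nn convention, "ij=ji" declares a rank-2 tensor that is unchanged when its two indices are swapped, that is, a symmetric tensor such as stress. The check below is a plain-Python sketch of that constraint. `satisfies_ij_eq_ji` is a hypothetical helper, and nested lists stand in for tensors so that the example runs without e3nn or NequIP.

```python
from typing import Sequence


def satisfies_ij_eq_ji(t: Sequence[Sequence[float]], tol: float = 1e-8) -> bool:
    """Check T[i][j] == T[j][i] (within tol) for every off-diagonal index pair."""
    n = len(t)
    return all(
        abs(t[i][j] - t[j][i]) <= tol
        for i in range(n)
        for j in range(i + 1, n)
    )


symmetric = [
    [1.0, 0.2, 0.0],
    [0.2, 2.0, 0.5],
    [0.0, 0.5, 3.0],
]
not_symmetric = [
    [1.0, 0.2, 0.0],
    [0.9, 2.0, 0.5],
    [0.0, 0.5, 3.0],
]

print(satisfies_ij_eq_ji(symmetric))
print(satisfies_ij_eq_ji(not_symmetric))
```

A 3x3 tensor that satisfies this constraint has six independent components, three on the diagonal and three above it.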
Custom Datasets¶
Extension packages can implement custom datasets by subclassing NequIP’s base dataset classes to handle custom data formats and sources.
See AtomicDataset for the base dataset class and the dataset API documentation for examples.
Custom DataModules¶
Extension packages can create custom DataModules to handle specific benchmark datasets or complex data workflows. DataModules manage train/val/test splits, handle dataset downloading and preprocessing, and coordinate datasets, transforms, dataloaders, and statistics.
See NequIPDataModule for the base datamodule class and the datamodule API documentation for examples of dataset-specific implementations.
Data Transforms¶
Extension packages can implement custom data transforms to preprocess data during loading. Transforms are classes that implement a __call__ method to modify AtomicDataDict objects.
See the transforms API documentation for available transform classes and their patterns.
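The callable-class pattern described above can be sketched as follows. `ScaleFieldTransform` is a hypothetical example, not a NequIP class, and a plain dictionary with float values stands in for an AtomicDataDict so that the sketch runs without torch or NequIP. The sketch updates the dictionary in place and returns it, which is an assumption about the expected behavior.

```python
from typing import Any, Dict


class ScaleFieldTransform:
    """Multiply one field of a data dictionary by a constant factor."""

    def __init__(self, field: str, factor: float) -> None:
        self.field = field
        self.factor = factor

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Only the target field is rewritten; every other entry passes through unchanged.
        data[self.field] = data[self.field] * self.factor
        return data


# Floats stand in for tensors; the same multiplication works on any type that supports `*`.
scale_energy = ScaleFieldTransform(field="total_energy", factor=1000.0)
sample = {"total_energy": 1.5, "num_atoms": 3}
print(scale_energy(sample))
```

Keeping the configuration in `__init__` and the per-sample work in `__call__` lets the same transform instance be reused for every sample a dataset yields.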
Custom Neighborlist Backends¶
Neighborlist backends are registered by name with register_neighborlist_backend() and can then be selected from NeighborListTransform via backend=....
The registration schema is:
backend: string key for the backend.
fn: callable with signature fn(data, r_max) -> data.
supports_cpu / supports_cuda: backend device capability flags.
Backend function contract:
batched input must return batched output; unbatched input must return unbatched output.
output tensors must stay on the same device as input tensors.
mutation should be limited to adding/updating neighborlist outputs.
Example:
import torch
from nequip.data import (
    AtomicDataDict,
    register_neighborlist_backend,
)


def my_backend(data: AtomicDataDict.Type, r_max: float) -> AtomicDataDict.Type:
    # compute edge_index and edge_cell_shift here; the empty tensors below are
    # placeholders, allocated on the input's device as the contract requires
    device = data[AtomicDataDict.POSITIONS_KEY].device
    data[AtomicDataDict.EDGE_INDEX_KEY] = torch.empty((2, 0), dtype=torch.long, device=device)
    data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = torch.empty((0, 3), device=device)
    return data


register_neighborlist_backend(
    backend="my_backend",
    fn=my_backend,
    supports_cpu=True,
    supports_cuda=False,
)
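For a sense of what the body of such a backend might compute, here is a minimal brute-force sketch for a single, non-periodic, unbatched frame. It is an illustration under stated assumptions, not NequIP's implementation: plain Python lists stand in for tensors, the key strings "pos", "edge_index" and "edge_cell_shift" are assumed, and the inclusive cutoff and the edge ordering are choices made for this example only.

```python
import math
from typing import Any, Dict, List


def brute_force_backend(data: Dict[str, Any], r_max: float) -> Dict[str, Any]:
    """O(N^2) neighbor search for one non-periodic frame."""
    pos: List[List[float]] = data["pos"]  # assumed key for atomic positions
    first: List[int] = []
    second: List[int] = []
    for i in range(len(pos)):
        for j in range(len(pos)):
            # Directed edges in both directions, no self-edges.
            if i != j and math.dist(pos[i], pos[j]) <= r_max:
                first.append(i)
                second.append(j)
    # Only neighborlist outputs are written; existing entries are left as they were.
    data["edge_index"] = [first, second]
    # Without periodic boundaries every edge has a zero cell shift.
    data["edge_cell_shift"] = [[0, 0, 0] for _ in first]
    return data


frame = {"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]}
result = brute_force_backend(frame, r_max=1.5)
print(result["edge_index"])
```

With a cutoff of 1.5, only the first two atoms are neighbors, so the result contains the two directed edges between them and the third atom has none. A real backend would additionally handle batched input and periodic cells.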
API Reference¶
- nequip.data.register_neighborlist_backend(backend: str, fn: Callable[[Dict[str, Tensor], float], Dict[str, Tensor]], supports_cpu: bool = True, supports_cuda: bool = False, overwrite: bool = False) -> None
Register a neighborlist backend callable.
- Parameters:
backend (str) – name for the backend.
fn (Callable) – backend function with signature fn(data, r_max) -> data.
supports_cpu (bool) – whether the backend supports CPU execution.
supports_cuda (bool) – whether the backend supports CUDA execution.
overwrite (bool) – whether to replace an existing backend with the same name.
Notes
Backend function contract:
batched input must return batched output; unbatched input must return unbatched output.
output tensors must stay on the same device as input tensors.
existing tensors in data must be preserved; mutation is limited to neighborlist outputs.
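The preservation rule in these notes lends itself to a small test helper. The sketch below is hypothetical and not part of NequIP: `check_backend_contract` and the assumed output names in `NEIGHBORLIST_KEYS` are introduced only for illustration, and plain dictionaries stand in for tensor dictionaries. It checks that a backend neither drops nor replaces existing entries and adds nothing beyond the neighborlist outputs. It does not check device placement or batching, and an identity check cannot detect in-place modification of an existing value.

```python
from typing import Any, Callable, Dict, Iterable

Data = Dict[str, Any]

# Assumed names of the neighborlist outputs a backend is allowed to write.
NEIGHBORLIST_KEYS = frozenset({"edge_index", "edge_cell_shift"})


def check_backend_contract(
    fn: Callable[[Data, float], Data],
    data: Data,
    r_max: float,
    allowed: Iterable[str] = NEIGHBORLIST_KEYS,
) -> Data:
    """Run a backend and raise ValueError if it touched anything but allowed keys."""
    allowed = set(allowed)
    before = dict(data)  # shallow snapshot: same value objects, independent key set
    out = fn(data, r_max)
    for key, value in before.items():
        if key in allowed:
            continue
        if key not in out:
            raise ValueError(f"backend dropped existing field {key!r}")
        if out[key] is not value:
            raise ValueError(f"backend replaced existing field {key!r}")
    extra = set(out) - set(before) - allowed
    if extra:
        raise ValueError(f"backend added unexpected fields: {sorted(extra)}")
    return out


def well_behaved(data: Data, r_max: float) -> Data:
    data["edge_index"] = [[], []]
    data["edge_cell_shift"] = []
    return data


def misbehaving(data: Data, r_max: float) -> Data:
    data.pop("pos")  # violates the contract by removing an input field
    return well_behaved(data, r_max)


checked = check_backend_contract(well_behaved, {"pos": [[0.0, 0.0, 0.0]]}, 5.0)
print(sorted(checked))
```

Passing `misbehaving` to the same helper raises `ValueError`, which makes a check like this convenient to run in a unit test before registering a new backend.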