nequip.data.transforms

Data transforms convert the raw data from the AtomicDataset so that it includes the information the model needs to make predictions and to train. For example, datasets do not usually come with neighborlists, so the NeighborListTransform is required to augment raw data containing only positions and energy (and force) labels with the neighborlist the model needs to make predictions.
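Conceptually, a transform takes one frame's data dictionary and returns an augmented one, and several transforms are applied in sequence. The sketch below illustrates only that idea in plain Python. `compose`, `add_num_atoms`, `scale_energy`, and the dictionary keys are hypothetical and are not part of the NequIP API.

```python
from typing import Any, Callable, Dict, List

# In this sketch a transform is modeled as a function from a data dict to a data dict.
Data = Dict[str, Any]
Transform = Callable[[Data], Data]


def compose(transforms: List[Transform]) -> Transform:
    """Return a single transform that applies `transforms` in order."""

    def apply(data: Data) -> Data:
        for transform in transforms:
            data = transform(data)
        return data

    return apply


def add_num_atoms(data: Data) -> Data:
    # Hypothetical transform: derive a new field from the existing raw data.
    out = dict(data)  # copy so the raw entry is left untouched
    out["num_atoms"] = len(out["pos"])
    return out


def scale_energy(data: Data) -> Data:
    # Hypothetical transform: rescale a label (e.g. eV to meV).
    out = dict(data)
    out["total_energy"] = out["total_energy"] * 1000.0
    return out


pipeline = compose([add_num_atoms, scale_energy])
raw = {"pos": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.96]], "total_energy": -1.5}
processed = pipeline(raw)
```

In this sketch each transform copies the dictionary before modifying it, so the raw entry stays unchanged and can be reused with a different pipeline.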

class nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper(model_type_names: List[str] | None = None, chemical_species_to_atom_type_map: Dict[str, str] | None = None, chemical_symbols: List[str] | None = None)

Maps atomic numbers to atom types and adds the atom types to the AtomicDataDict.

The model operates on abstract atom type indices rather than chemical species, so this transform bridges the gap by mapping atomic numbers to the model’s type indices. In the common case, model type names correspond directly to chemical symbols (e.g. ["H", "C", "O"]) and the mapping is an identity. Custom type names (e.g. "my_Cu" instead of "Cu") require an explicit map.

Parameters:
  • model_type_names (List[str]) – list of atom type names known by the model, e.g. ["H", "C", "O"]

  • chemical_species_to_atom_type_map (Dict[str, str]) – mapping from chemical species to model atom type names, e.g. {"H": "H", "C": "C", "O": "O"} or {"Cu": "my_Cu"} for custom type names. Not all model_type_names need to be present in the map, which is useful for models trained on the full periodic table but used to simulate only a subset of elements. If None, defaults to an identity mapping, which requires that model_type_names correspond exactly to chemical species (e.g. ["H", "C", "O"]).
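The mapping logic described above can be sketched in plain Python as a lookup from atomic number to type index. This is an illustration under stated assumptions, not the class's implementation: `build_type_lookup`, `map_atom_types`, and the four-element `ATOMIC_NUMBERS` table are hypothetical, and the real transform operates on tensors in an AtomicDataDict.

```python
from typing import Dict, List, Optional

# Hypothetical atomic-number table for this sketch only (not a full periodic table).
ATOMIC_NUMBERS = {"H": 1, "C": 6, "O": 8, "Cu": 29}


def build_type_lookup(
    model_type_names: List[str],
    chemical_species_to_atom_type_map: Optional[Dict[str, str]] = None,
) -> Dict[int, int]:
    """Build a map from atomic number to model type index."""
    if chemical_species_to_atom_type_map is None:
        # Identity mapping: every type name must itself be a chemical symbol,
        # otherwise the ATOMIC_NUMBERS lookup below raises KeyError.
        chemical_species_to_atom_type_map = {name: name for name in model_type_names}
    lookup: Dict[int, int] = {}
    for species, type_name in chemical_species_to_atom_type_map.items():
        # Type names absent from the map are simply never assigned.
        lookup[ATOMIC_NUMBERS[species]] = model_type_names.index(type_name)
    return lookup


def map_atom_types(atomic_numbers: List[int], lookup: Dict[int, int]) -> List[int]:
    # A KeyError here means the structure contains a species the map does not cover.
    return [lookup[z] for z in atomic_numbers]


identity = build_type_lookup(["H", "C", "O"])
water_types = map_atom_types([8, 1, 1], identity)
custom = build_type_lookup(["my_Cu", "O"], {"Cu": "my_Cu"})
```

In the custom case, "O" is a model type name that the map leaves out, mirroring a model that knows more types than the system being simulated.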

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, call the Module instance (e.g. transform(data)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently ignores them.

class nequip.data.transforms.NeighborListTransform(r_max: float, per_edge_type_cutoff: Dict[str, float | Dict[str, float]] | None = None, type_names: List[str] | None = None, backend: str = 'matscipy')

Constructs a neighborlist and adds it to the AtomicDataDict.

Parameters:
  • r_max (float) – cutoff radius used for nearest neighbors

  • backend (str) – neighbor list backend (“ase”, “matscipy”, or “vesin”)

  • per_edge_type_cutoff (Dict) – optional per-edge-type cutoffs (must be <= r_max)

  • type_names (List[str]) – list of atom type names
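To show what a neighborlist contains, here is a minimal non-periodic, brute-force O(N^2) sketch. It is not how the ase, matscipy, or vesin backends listed above work (those handle periodic cells and avoid the quadratic loop). The `pair_cutoff` dictionary keyed by (type_i, type_j) is a hypothetical simplification, not the per_edge_type_cutoff format, and treating a distance exactly equal to the cutoff as excluded is a choice made for this sketch.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Vec3 = Sequence[float]


def brute_force_neighbor_list(
    pos: List[Vec3],
    r_max: float,
    types: Optional[List[str]] = None,
    pair_cutoff: Optional[Dict[Tuple[str, str], float]] = None,
) -> List[Tuple[int, int]]:
    """Return directed edges (i, j), i != j, closer than the applicable cutoff."""
    if pair_cutoff is not None:
        if types is None:
            raise ValueError("pair_cutoff requires types")
        for cutoff in pair_cutoff.values():
            if cutoff > r_max:
                # Pruning can only remove pairs already found within r_max,
                # so a larger per-pair cutoff would be silently ineffective.
                raise ValueError("per-pair cutoffs must be <= r_max")

    edges = []
    for i, p_i in enumerate(pos):
        for j, p_j in enumerate(pos):
            if i == j:
                continue
            d = math.dist(p_i, p_j)
            if d >= r_max:
                continue  # outside the global cutoff
            if pair_cutoff is not None:
                # Fall back to r_max for type pairs without an explicit cutoff.
                if d >= pair_cutoff.get((types[i], types[j]), r_max):
                    continue
            edges.append((i, j))
    return edges


pos = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
edges = brute_force_neighbor_list(pos, 2.5)
```

Each undirected pair appears twice, once in each direction, since message passing typically treats (i, j) and (j, i) as separate edges.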

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.VirialToStressTransform

Converts virials to stress and adds the stress to the AtomicDataDict.

Specifically, it implements

\[\sigma_{ij} = - \frac{\tau_{ij}}{\Omega}\]

where \(\tau_{ij}\) is a virial component, \(\sigma_{ij}\) is a stress component, and \(\Omega\) is the volume of the cell.
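As a worked example of the relation between virial and stress, the sketch below divides a virial (energy units) by the cell volume Ω and flips the sign, giving σ_ij = -τ_ij / Ω in energy per volume. It assumes plain 3x3 nested lists with lattice vectors stored as the rows of `cell`; `det3` and `virial_to_stress` are hypothetical names, and this is not NequIP's code.

```python
from typing import List

Matrix3 = List[List[float]]


def det3(m: Matrix3) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion along the first row."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def virial_to_stress(virial: Matrix3, cell: Matrix3) -> Matrix3:
    """sigma_ij = -tau_ij / Omega, with Omega = |det(cell)|."""
    volume = abs(det3(cell))  # abs() guards against left-handed cell vectors
    return [[-v / volume for v in row] for row in virial]


cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]  # cubic cell, volume 8
virial = [[8.0, 0.0, 0.0], [0.0, -4.0, 0.0], [0.0, 0.0, 0.0]]
stress = virial_to_stress(virial, cell)
```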

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.StressSignFlipTransform

Flips the sign of stress in the AtomicDataDict.

In the NequIP convention, positive diagonal components of the stress tensor imply that the system is under tensile strain and wants to compress, while negative values imply that it is under compressive strain and wants to expand. This transform can be applied to datasets that follow the opposite sign convention: the sign flip then happens on-the-fly during training, so users need not generate a copy of the dataset in NequIP's stress sign convention.
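The on-the-fly idea can be illustrated with a dataset wrapper that applies the sign flip whenever a frame is accessed, leaving the stored frames unmodified. This plain-Python sketch assumes a hypothetical "stress" key holding a 3x3 nested list; `stress_sign_flip` and `TransformedDataset` are illustrative names, not NequIP classes.

```python
from typing import Any, Callable, Dict, Sequence

Data = Dict[str, Any]


def stress_sign_flip(data: Data) -> Data:
    """Return a copy of `data` with every stress component negated."""
    out = dict(data)
    out["stress"] = [[-s for s in row] for row in data["stress"]]
    return out


class TransformedDataset:
    """Applies a transform on access, so the stored frames are never rewritten."""

    def __init__(self, frames: Sequence[Data], transform: Callable[[Data], Data]):
        self.frames = frames
        self.transform = transform

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, idx: int) -> Data:
        return self.transform(self.frames[idx])


frames = [{"stress": [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, -3.0]]}]
dataset = TransformedDataset(frames, stress_sign_flip)
flipped = dataset[0]["stress"]
```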

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.AddNaNStressTransform

Adds NaN stress tensors for structures without stress data.

Useful for datasets where stresses are not available for all structures. The NaN values can be ignored during loss computation and metrics calculation by using the ignore_nan flag in loss functions and metrics.
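A minimal sketch of the pattern: frames lacking stress get a 3x3 NaN placeholder, and a loss that skips NaN targets then scores only frames with real stress labels. The "stress" key, `add_nan_stress`, and `mse_ignore_nan` are hypothetical, and this does not reproduce how NequIP's loss functions and metrics implement ignore_nan.

```python
import math
from typing import Any, Dict, List

Data = Dict[str, Any]


def add_nan_stress(data: Data) -> Data:
    """Return `data` unchanged if it has stress, else a copy with a NaN 3x3 stress."""
    if "stress" in data:
        return data
    out = dict(data)
    out["stress"] = [[math.nan] * 3 for _ in range(3)]  # fresh rows, no aliasing
    return out


def mse_ignore_nan(pred: List[float], target: List[float]) -> float:
    """Mean squared error over entries whose target is not NaN."""
    pairs = [(p, t) for p, t in zip(pred, target) if not math.isnan(t)]
    if not pairs:
        return 0.0  # nothing to compare; a choice made for this sketch
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


with_stress = {"stress": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]}
without_stress = {"total_energy": -3.0}
targets = [add_nan_stress(f) for f in (with_stress, without_stress)]
flat_targets = [s for f in targets for row in f["stress"] for s in row]
loss = mse_ignore_nan([0.0] * len(flat_targets), flat_targets)
```

Only the nine components from the frame with real stress contribute to the loss; the NaN placeholders are skipped rather than propagating NaN into the result.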

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]
