nequip.data.transforms

Data transforms augment the raw data from the AtomicDataset with the information the model needs for training and prediction. For example, datasets do not usually come with neighborlists, so the NeighborListTransform is required to convert raw data that contains only positions and energy (and force) labels into data that also includes the neighborlist the model needs to make predictions.

class nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper(model_type_names: List[str] | None = None, chemical_species_to_atom_type_map: Dict[str, str] | None = None, chemical_symbols: List[str] | None = None)

Maps atomic numbers to atom types and adds the atom types to the AtomicDataDict.

This transform accounts for how the atom types seen by the model can be different from the atomic species that one obtains from a conventional dataset. There could be cases where the same chemical species corresponds to multiple atom types, e.g. different charge states.

Parameters:
  • model_type_names (List[str]) – list of atom type names known by the model, e.g. ["H", "C", "O"]

  • chemical_species_to_atom_type_map (Dict[str, str]) – mapping from chemical species to model atom type names, e.g. {"H": "H", "C": "C", "O": "O"} or {"C": "C3+", "C": "C4+"} for charge states. Not every entry of model_type_names needs to appear in the map, which is useful for a model trained on the full periodic table that is used to simulate only a subset of elements. If None, defaults to the identity mapping, which requires that model_type_names correspond exactly to chemical species (e.g. ["H", "C", "O"]).
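The mapping described by these parameters can be illustrated with a minimal, library-independent sketch. The helper name build_species_to_type_index and its behavior are assumptions made for illustration only; this is not how NequIP implements the transform. It only shows how a species-to-type-name map plus the ordered list of model type names could resolve to integer type indices, including the identity default and the case where the map covers only a subset of the model's types.

```python
from typing import Dict, List, Optional


def build_species_to_type_index(
    model_type_names: List[str],
    species_to_type: Optional[Dict[str, str]] = None,
) -> Dict[str, int]:
    """Hypothetical helper: resolve each chemical species to a model type index."""
    if species_to_type is None:
        # Identity mapping: every model type name doubles as a chemical species.
        species_to_type = {name: name for name in model_type_names}
    # Every mapped type name must be one the model knows about.
    unknown = sorted(set(species_to_type.values()) - set(model_type_names))
    if unknown:
        raise ValueError(f"atom types not known to the model: {unknown}")
    return {
        species: model_type_names.index(type_name)
        for species, type_name in species_to_type.items()
    }


# Identity mapping applied to a water molecule given as chemical symbols.
identity = build_species_to_type_index(["H", "C", "O"])
water_types = [identity[symbol] for symbol in ["O", "H", "H"]]

# The model knows four types, but the map only uses two of them.
subset = build_species_to_type_index(["H", "C", "N", "O"], {"H": "H", "O": "O"})

# A single species mapped to a charge-state type name.
charged = build_species_to_type_index(["C3+", "C4+"], {"C": "C4+"})
```

In this sketch the position of a name in model_type_names determines its integer index, so the order of that list matters.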

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class nequip.data.transforms.NeighborListTransform(r_max: float, per_edge_type_cutoff: Dict[str, float | Dict[str, float]] | None = None, type_names: List[str] | None = None, backend: str = 'matscipy')

Constructs a neighborlist and adds it to the AtomicDataDict.

Parameters:
  • r_max (float) – cutoff radius within which atoms are treated as neighbors

  • backend (str) – neighbor list backend (“ase”, “matscipy”, or “vesin”)

  • per_edge_type_cutoff (Dict) – optional per-edge-type cutoffs (must be <= r_max)

  • type_names (List[str]) – list of atom type names
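To make the role of r_max concrete, the following is a brute-force sketch of a neighborlist for a non-periodic system. It is an illustration under stated assumptions, not the algorithm used by the "ase", "matscipy", or "vesin" backends: the function name brute_force_neighbor_list is hypothetical, periodic boundary conditions and per-edge-type cutoffs are ignored, and the strict comparison against r_max is a choice made for this example.

```python
import math
from typing import List, Sequence, Tuple


def brute_force_neighbor_list(
    positions: Sequence[Sequence[float]], r_max: float
) -> List[Tuple[int, int]]:
    """Return directed (center, neighbor) index pairs closer than r_max.

    Hypothetical O(N^2) helper for a non-periodic system.
    """
    edges = []
    for i, pos_i in enumerate(positions):
        for j, pos_j in enumerate(positions):
            # Skip self-pairs; keep both (i, j) and (j, i) as directed edges.
            if i != j and math.dist(pos_i, pos_j) < r_max:
                edges.append((i, j))
    return edges


# Three atoms on a line at x = 0, 1, and 3.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
short_cutoff_edges = brute_force_neighbor_list(positions, r_max=1.5)
long_cutoff_edges = brute_force_neighbor_list(positions, r_max=2.5)
```

With the shorter cutoff only the first two atoms are connected; the longer cutoff also connects the second and third atoms, which are 2.0 apart.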

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.VirialToStressTransform

Converts virials to stress and adds the stress to the AtomicDataDict.

Specifically implements

\[\sigma_{ij} = - \frac{\tau_{ij}}{\Omega}\]

where \(\tau_{ij}\) is a virial component, \(\sigma_{ij}\) is a stress component, and \(\Omega\) is the volume of the cell.
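As a worked example of converting virials to stress, the sketch below divides the negated virial by the cell volume, with the volume taken as the absolute scalar triple product of the cell vectors. The function names cell_volume and virial_to_stress are hypothetical, the cell is assumed to be given as three row vectors, and plain nested lists stand in for tensors; this is not the library's implementation.

```python
from typing import List

Matrix = List[List[float]]


def cell_volume(cell: Matrix) -> float:
    """Volume of the cell spanned by three row vectors: |a . (b x c)|."""
    a, b, c = cell
    b_cross_c = [
        b[1] * c[2] - b[2] * c[1],
        b[2] * c[0] - b[0] * c[2],
        b[0] * c[1] - b[1] * c[0],
    ]
    return abs(sum(a[k] * b_cross_c[k] for k in range(3)))


def virial_to_stress(virial: Matrix, cell: Matrix) -> Matrix:
    """Hypothetical helper: stress = -virial / volume, component by component."""
    volume = cell_volume(cell)
    return [[-virial[i][j] / volume for j in range(3)] for i in range(3)]


# A cubic cell with edge length 2 has volume 8.
cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
virial = [[-16.0, 0.0, 0.0], [0.0, -8.0, 0.0], [0.0, 0.0, -4.0]]
volume = cell_volume(cell)
stress = virial_to_stress(virial, cell)
```

Here a virial component of -16 in energy units becomes a stress component of 2 in units of energy per volume.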

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.StressSignFlipTransform

Flips the sign of stress in the AtomicDataDict.

In the NequIP convention, positive diagonal components of the stress tensor imply that the system is under tensile strain and wants to compress, while negative values imply that the system is under compressive strain and wants to expand. This transform can be applied to datasets that follow the opposite sign convention, so that the sign flip happens on the fly during training and users do not have to generate a copy of the dataset that follows the NequIP stress sign convention.
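The operation itself is a component-wise negation, as in this minimal sketch. The function name flip_stress_sign is hypothetical and nested lists stand in for tensors; the actual transform operates on the AtomicDataDict.

```python
from typing import List

Matrix = List[List[float]]


def flip_stress_sign(stress: Matrix) -> Matrix:
    """Hypothetical helper: negate every component of a 3x3 stress tensor."""
    return [[-component for component in row] for row in stress]


# A stress label stored with the opposite sign convention.
dataset_stress = [[-1.5, 0.25, 0.0], [0.25, -2.0, 0.0], [0.0, 0.0, -0.5]]
flipped_stress = flip_stress_sign(dataset_stress)
```

Applying the flip twice returns the original values, so the transform should be applied exactly once to a dataset that uses the opposite convention.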

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

class nequip.data.transforms.AddNaNStressTransform

Adds NaN stress tensors for structures without stress data.

Useful for datasets where stresses are not available for all structures. The NaN values can be ignored during loss computation and metrics calculation by using the ignore_nan flag in loss functions and metrics.
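The idea of filling in NaN placeholders and then skipping them in the loss can be sketched as follows. Everything here is an assumption for illustration: the dictionary keys "stress" and "total_energy", the helper names add_nan_stress and mse_ignore_nan, and the choice to return 0.0 when no valid targets remain. The library's own ignore_nan handling may differ.

```python
import math
from typing import Dict, List


def add_nan_stress(frame: Dict[str, object]) -> Dict[str, object]:
    """Hypothetical helper: add a 3x3 NaN stress if the frame has no stress."""
    if "stress" in frame:
        return frame
    filled = dict(frame)
    filled["stress"] = [[math.nan] * 3 for _ in range(3)]
    return filled


def mse_ignore_nan(predictions: List[float], targets: List[float]) -> float:
    """Mean squared error over the entries whose target is not NaN."""
    squared_errors = [
        (p - t) ** 2 for p, t in zip(predictions, targets) if not math.isnan(t)
    ]
    # Assumption for this sketch: no valid targets contributes zero loss.
    return sum(squared_errors) / len(squared_errors) if squared_errors else 0.0


# A frame that has an energy label but no stress label.
frame = add_nan_stress({"total_energy": -1.0})
all_nan = all(math.isnan(x) for row in frame["stress"] for x in row)

# The NaN target is skipped; the remaining errors are 0 and 4.
loss = mse_ignore_nan([1.0, 2.0, 3.0], [1.0, math.nan, 5.0])
```

Giving every frame a stress entry of the same shape lets frames with and without stress labels be batched together, while the NaN mask keeps the placeholders from affecting the loss.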

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]
