nequip.data.dataset

class nequip.data.dataset.AtomicDataset(transforms: List[Callable] = [])

Base class for all NequIP datasets.

This class provides a standard interface for loading atomic structure data and applying transforms.

Subclasses must implement:
  • __len__() - Return the total number of data samples

  • _get_data_list(indices) - Return raw data for the given indices

Alternatively, subclasses may directly override __getitem__ and __getitems__ for custom indexing behavior.

Parameters:

transforms (List[Callable], optional) – List of data transforms to apply to each data sample. Transforms are applied in order. Defaults to an empty list.
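
The subclassing contract above can be illustrated with a simplified, pure-Python stand-in. MiniAtomicDataset, InMemoryDataset, and the two transforms are hypothetical names, not part of nequip. The real base class works with AtomicDataDict objects and torch tensors, and its indexing internals may differ. The sketch only shows the pattern: a subclass supplies __len__ and _get_data_list, and the base class applies transforms to each sample in list order.

```python
from typing import Any, Callable, Dict, List, Sequence

Sample = Dict[str, Any]


class MiniAtomicDataset:
    """Hypothetical, simplified stand-in for the AtomicDataset contract."""

    def __init__(self, transforms: Sequence[Callable[[Sample], Sample]] = ()):
        self.transforms = list(transforms)

    def __len__(self) -> int:
        raise NotImplementedError  # subclasses must implement

    def _get_data_list(self, indices: List[int]) -> List[Sample]:
        raise NotImplementedError  # subclasses must implement

    def __getitems__(self, indices: List[int]) -> List[Sample]:
        samples = self._get_data_list(indices)
        for t in self.transforms:  # transforms run one after another, in list order
            samples = [t(s) for s in samples]
        return samples

    def __getitem__(self, index: int) -> Sample:
        return self.__getitems__([index])[0]


class InMemoryDataset(MiniAtomicDataset):
    """Hypothetical subclass implementing only __len__ and _get_data_list."""

    def __init__(self, frames: List[Sample], transforms=()):
        super().__init__(transforms)
        self.frames = frames

    def __len__(self) -> int:
        return len(self.frames)

    def _get_data_list(self, indices: List[int]) -> List[Sample]:
        # Shallow copies, so transforms that reassign keys leave self.frames intact.
        return [dict(self.frames[i]) for i in indices]


def subtract_reference(data: Sample) -> Sample:
    # Hypothetical transform: subtract a fixed 1.0 eV reference energy.
    data["total_energy"] = data["total_energy"] - 1.0
    return data


def ev_to_mev(data: Sample) -> Sample:
    # Hypothetical transform: convert the energy from eV to meV.
    data["total_energy"] = data["total_energy"] * 1000.0
    return data


ds = InMemoryDataset([{"total_energy": 3.0}], transforms=[subtract_reference, ev_to_mev])
print(ds[0]["total_energy"])  # (3.0 - 1.0) * 1000.0 -> 2000.0
```

Because the transforms run in list order, swapping them changes the result (3.0 * 1000.0 - 1.0 = 2999.0 instead of 2000.0).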

num_atoms(indices: List[int] | Tensor | slice) → List[int]

Return the number of atoms in each sample selected by indices. Subclasses may override this.
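
A subclass that already knows each sample's size (for example, from metadata stored alongside the data) might override num_atoms to avoid loading full samples. The sketch below is hypothetical: CountsOnlyDataset and normalize_indices are illustrative names, the class does not subclass nequip's AtomicDataset so that it runs without nequip installed, and tensor indices are handled by duck typing on .tolist(), assuming a 1-D index tensor.

```python
from typing import List, Sequence


def normalize_indices(indices, length: int) -> List[int]:
    """Turn a list of ints, a slice, or a 1-D tensor-like index into a plain list."""
    if isinstance(indices, slice):
        # slice.indices() resolves None and negative bounds against the length.
        return list(range(*indices.indices(length)))
    if hasattr(indices, "tolist"):
        # torch.Tensor and numpy.ndarray both provide .tolist().
        indices = indices.tolist()
    return [int(i) for i in indices]


class CountsOnlyDataset:
    """Hypothetical dataset that stores per-sample atom counts as metadata."""

    def __init__(self, counts: Sequence[int]):
        self._counts = list(counts)

    def __len__(self) -> int:
        return len(self._counts)

    def num_atoms(self, indices) -> List[int]:
        # Answered from metadata alone: no sample is loaded or transformed.
        return [self._counts[i] for i in normalize_indices(indices, len(self))]
```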

class nequip.data.dataset.NequIPLMDBDataset(file_path: str, transforms: List[Callable] = [], exclude_keys: List[str] = [])

AtomicDataset for LMDB data.

The NequIPLMDBDataset is the recommended solution for managing large datasets within the NequIP software ecosystem. Existing datasets can be converted to the LMDB format with this class's helper methods, such as save_from_iterator().

As a Dataset object, this class assumes each entry in the LMDB data is a NequIP AtomicDataDict.

Parameters:
  • file_path (str) – path to LMDB file

  • transforms (List[Callable]) – list of data transforms

  • exclude_keys (List[str]) – list of data keys to ignore

classmethod save_from_iterator(file_path: str, iterator: Iterable[Dict[str, Tensor]], map_size: int = 53687091200, write_frequency: int = 1000, extra_metadata: List[LMDBMetadataSpec] = []) → None

Uses an iterator of AtomicDataDict objects to construct an LMDB dataset.

Parameters:
  • file_path (str) – path to save the LMDB data

  • iterator (Iterable) – iterator of atomic data dicts

  • map_size (int) – maximum size, in bytes, that the database may grow to (defaults to 53687091200 bytes, i.e. 50 GiB); an exception is raised if the database grows larger than map_size

  • write_frequency (int) – number of entries written between transaction commits (defaults to 1000); larger values are faster.

  • extra_metadata (List[LMDBMetadataSpec]) – optional list of extra metadata specifications, beyond _BASE_METADATA, to be written to the database. Defaults to an empty list.
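
As a sketch of converting an existing dataset, the code below reads an ASE-readable file with ASEDataset and streams its samples into save_from_iterator. It uses only the constructor and classmethod parameters documented on this page. The file paths, the extxyz format, and the choice to write samples without transforms are assumptions. gib_to_bytes and iterate_samples are hypothetical helpers; gib_to_bytes(50) equals the default map_size of 53687091200 bytes.

```python
from typing import Any, Dict, Iterator

GIB = 1024**3  # bytes per gibibyte


def gib_to_bytes(gib: float) -> int:
    """Convert a size in GiB to the integer byte count expected by map_size."""
    return int(gib * GIB)


def iterate_samples(dataset) -> Iterator[Dict[str, Any]]:
    """Yield dataset[0], dataset[1], ... one sample at a time."""
    for i in range(len(dataset)):
        yield dataset[i]


def convert_ase_to_lmdb(src: str, dst: str, map_size_gib: float = 50.0) -> None:
    # Imported lazily so the helpers above work without nequip installed.
    from nequip.data.dataset import ASEDataset, NequIPLMDBDataset

    # Assumption: an extxyz source file; ase_args is forwarded to ase.io.iread().
    source = ASEDataset(file_path=src, ase_args={"format": "extxyz"})
    NequIPLMDBDataset.save_from_iterator(
        file_path=dst,
        iterator=iterate_samples(source),
        map_size=gib_to_bytes(map_size_gib),
        write_frequency=1000,
    )
```

Streaming through a generator keeps only one sample in memory at a time on the reading side, which matters for the large datasets this class targets.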

class nequip.data.dataset.ASEDataset(file_path: str, transforms: List[Callable] = [], ase_args: Dict[str, Any] = {}, include_keys: List[str] | None = [], exclude_keys: List[str] | None = [], key_mapping: Dict[str, str] | None = {})

AtomicDataset for ASE-readable file formats.

Parameters:
  • file_path (str) – path to ASE-readable file

  • transforms (List[Callable]) – list of data transforms

  • ase_args (Dict[str, Any]) – arguments for ase.io.iread()

  • include_keys (List[str]) – keys to parse into the dataset in addition to the standard keys (see Data Fields). Data stored in ase.atoms.Atoms.arrays has the lowest priority and is overridden by data in ase.atoms.Atoms.info and ase.atoms.Atoms.calc.results.

  • exclude_keys (List[str]) – list of keys that may be present in the ASE-readable file but the user wishes to exclude

  • key_mapping (Dict[str, str]) – mapping of ASE keys to AtomicDataDict keys
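
The precedence rule for include_keys can be pictured as a dictionary merge in which higher-priority sources overwrite lower-priority ones. merge_ase_fields below is a hypothetical, simplified model operating on plain dicts that stand in for Atoms.arrays, Atoms.info, and Atoms.calc.results; it is not nequip's parser and ignores the standard keys. The documentation above only fixes that arrays has the lowest priority, so merging calc.results after info is an assumption, as is applying key_mapping after the merge.

```python
from typing import Any, Dict, Iterable, Optional


def merge_ase_fields(
    arrays: Dict[str, Any],
    info: Dict[str, Any],
    calc_results: Dict[str, Any],
    include_keys: Iterable[str],
    key_mapping: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Simplified model of the include_keys precedence (not nequip's code)."""
    wanted = set(include_keys)
    merged: Dict[str, Any] = {}
    # Lowest priority first: each later source overwrites matching keys.
    # The relative order of info and calc_results is an assumption.
    for source in (arrays, info, calc_results):
        merged.update({k: v for k, v in source.items() if k in wanted})
    mapping = key_mapping or {}
    # Rename ASE-side keys to AtomicDataDict-side keys; unmapped keys are kept.
    return {mapping.get(k, k): v for k, v in merged.items()}
```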

class nequip.data.dataset.HDF5Dataset(file_name: str, transforms: List[Callable] = [], key_mapping: Dict[str, str] = {'atomic_numbers': 'atomic_numbers', 'energy': 'total_energy', 'forces': 'forces', 'pos': 'pos', 'types': 'atom_types'})

AtomicDataset that loads data from an HDF5 file.

This class is useful for very large datasets that do not fit in memory: it loads data from disk on demand, so the full dataset never has to be held in memory at once.

To use this class, set file_name to the path of the HDF5 file, or to a semicolon-separated list of multiple files. Each group in a file contains samples that all have the same number of atoms. Typically there is one group for each unique number of atoms, but that is not required. Each group should contain one array per type of data (positions, forces, etc.), and the length of each array (its first dimension) should equal the number of samples in that group. The array names can be specified with key_mapping.

Parameters:
  • file_name (str) – path to an HDF5 file, or a semicolon-separated list of HDF5 file paths

  • transforms (List[Callable]) – list of data transforms

  • key_mapping (Dict[str, str]) – mapping of array names in the HDF5 file to AtomicDataDict keys
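
The group layout described above can be produced from a list of variable-size frames by bucketing them on atom count and stacking each field. In this sketch, group_by_num_atoms and write_hdf5 are hypothetical helpers, the group names (the atom count as a string) are an arbitrary choice, and the array names assume the default key_mapping (pos, forces, energy, atomic_numbers, with atomic_numbers stored per sample). Writing uses the h5py package, which is imported only inside write_hdf5.

```python
from collections import defaultdict
from typing import Dict, List

import numpy as np


def group_by_num_atoms(frames: List[Dict[str, np.ndarray]]) -> Dict[str, Dict[str, np.ndarray]]:
    """Stack frames into one group per atom count, one array per data field."""
    buckets: Dict[int, List[Dict[str, np.ndarray]]] = defaultdict(list)
    for frame in frames:
        buckets[len(frame["pos"])].append(frame)
    groups = {}
    for n_atoms, members in sorted(buckets.items()):
        # First dimension of every stacked array = number of samples in the group.
        groups[str(n_atoms)] = {
            key: np.stack([np.asarray(m[key]) for m in members]) for key in members[0]
        }
    return groups


def write_hdf5(path: str, groups: Dict[str, Dict[str, np.ndarray]]) -> None:
    import h5py  # needed only for writing

    with h5py.File(path, "w") as f:
        for name, arrays in groups.items():
            grp = f.create_group(name)
            for key, value in arrays.items():
                grp.create_dataset(key, data=value)
```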

class nequip.data.dataset.NPZDataset(file_path: str, transforms: List[Callable] = [], key_mapping: Dict[str, str] = {'E': 'total_energy', 'F': 'forces', 'R': 'pos', 'z': 'atomic_numbers'})

AtomicDataset that loads data from an NPZ file following the sGDML conventions; the default key_mapping is set accordingly. Other NPZ datasets, such as rMD-17, can also be loaded by changing key_mapping.

The NPZDataset avoids loading the whole dataset into memory.

Parameters:
  • file_path (str) – path to npz file

  • transforms (List[Callable]) – list of data transforms

  • key_mapping (Dict[str, str]) – mapping of array names in the npz file to AtomicDataDict keys
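
For rMD-17, the change in key_mapping amounts to renaming that dataset's arrays to AtomicDataDict keys. The array names in RMD17_KEY_MAPPING below (nuclear_charges, coords, energies, forces) are assumed and should be checked against the actual files. The data written here is random toy data, write_toy_rmd17 and mapped_shapes are hypothetical helpers, and only the commented-out NPZDataset call uses nequip.

```python
import os
import tempfile
from typing import Dict, Tuple

import numpy as np

# Assumed rMD-17 array names mapped to AtomicDataDict keys (verify against your files).
RMD17_KEY_MAPPING = {
    "nuclear_charges": "atomic_numbers",
    "coords": "pos",
    "energies": "total_energy",
    "forces": "forces",
}


def write_toy_rmd17(path: str, n_frames: int = 5, n_atoms: int = 4, seed: int = 0) -> None:
    """Write a small random NPZ file laid out like an rMD-17 molecule file."""
    rng = np.random.default_rng(seed)
    np.savez(
        path,
        nuclear_charges=np.array([6] + [1] * (n_atoms - 1)),  # same atoms in every frame
        coords=rng.normal(size=(n_frames, n_atoms, 3)),
        energies=rng.normal(size=n_frames),
        forces=rng.normal(size=(n_frames, n_atoms, 3)),
    )


def mapped_shapes(path: str, key_mapping: Dict[str, str]) -> Dict[str, Tuple[int, ...]]:
    """Report array shapes under their AtomicDataDict names."""
    # np.load on an .npz returns a lazy NpzFile: each array is read only when accessed.
    with np.load(path) as npz:
        return {key_mapping[k]: npz[k].shape for k in npz.files if k in key_mapping}


path = os.path.join(tempfile.mkdtemp(), "toy_rmd17.npz")
write_toy_rmd17(path)
print(mapped_shapes(path, RMD17_KEY_MAPPING))
# With nequip installed, the same mapping would be passed as:
# NPZDataset(file_path=path, key_mapping=RMD17_KEY_MAPPING)
```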

class nequip.data.dataset.EMTTestDataset(transforms: List[Callable] = [], supercell: Tuple[int, int, int] = (4, 4, 4), sigma: float = 0.1, element: str = 'Cu', num_frames: int = 10, seed: int = 123456)

Test dataset with PBC, based on the toy EMT potential included in ASE.

Randomly generates (in a reproducible manner) a bulk crystal with Gaussian noise added to the equilibrium atomic positions. Uses an orthorhombic cell construction for safer testing.

All quantities are in ASE units (eV, Å, eV/Å).

Parameters:
  • transforms (List[Callable]) – list of data transforms

  • supercell (Tuple[int, int, int]) – number of repetitions of the unit cell along each lattice vector direction

  • sigma (float) – standard deviation of Gaussian noise

  • element (str) – element supported by ASE’s EMT calculator (supported elements: Cu, Pd, Au, Pt, Al, Ni, Ag)

  • num_frames (int) – number of structures to be generated in the dataset

  • seed (int) – seed for the random Gaussian noise
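
The generation procedure described above can be sketched with ASE's bulk builder and EMT calculator. This is an illustration under assumptions, not the library's implementation: make_emt_frames and the keys of the returned dicts are hypothetical, the noise is drawn from a single NumPy generator seeded once, and the resulting structures will not match EMTTestDataset's frames exactly. ASE is imported only inside make_emt_frames, so gaussian_displacements runs without it.

```python
from typing import List, Tuple

import numpy as np


def gaussian_displacements(n_atoms: int, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Isotropic Gaussian noise (standard deviation sigma, in Å) for each atom."""
    return rng.normal(loc=0.0, scale=sigma, size=(n_atoms, 3))


def make_emt_frames(
    element: str = "Cu",
    supercell: Tuple[int, int, int] = (4, 4, 4),
    sigma: float = 0.1,
    num_frames: int = 10,
    seed: int = 123456,
) -> List[dict]:
    from ase.build import bulk
    from ase.calculators.emt import EMT

    rng = np.random.default_rng(seed)  # one seeded generator -> reproducible frames
    base = bulk(element, orthorhombic=True).repeat(supercell)  # periodic, orthorhombic cell
    frames = []
    for _ in range(num_frames):
        atoms = base.copy()
        atoms.set_positions(atoms.get_positions() + gaussian_displacements(len(atoms), sigma, rng))
        atoms.calc = EMT()
        frames.append(
            {
                "pos": atoms.get_positions(),  # Å
                "cell": atoms.cell.array,  # Å
                "total_energy": atoms.get_potential_energy(),  # eV
                "forces": atoms.get_forces(),  # eV/Å
            }
        )
    return frames
```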

class nequip.data.dataset.SubsetByRandomSlice(dataset: Dataset, start: int, length: int, seed: int)

Subset of a dataset, obtained by slicing a random permutation of the dataset's indices.

Parameters:
  • dataset (Dataset) – the torch.utils.data.Dataset to take a subset of

  • start (int) – starting index for the slice

  • length (int) – number of samples to slice from start

  • seed (int) – seed for reproducibility of the random permutation of indices
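
A minimal sketch of the slicing scheme, assuming a NumPy permutation; the library's own random number generator may differ, so these indices will not match nequip's. random_slice_indices and RandomSliceSubset are hypothetical names.

```python
from typing import List, Sequence

import numpy as np


def random_slice_indices(n: int, start: int, length: int, seed: int) -> List[int]:
    """Indices in positions [start, start + length) of a seeded permutation of range(n)."""
    if start < 0 or length < 0 or start + length > n:
        raise ValueError(f"slice [{start}, {start + length}) out of range for {n} samples")
    perm = np.random.default_rng(seed).permutation(n)
    return perm[start : start + length].tolist()


class RandomSliceSubset:
    """Hypothetical subset view that indexes the parent dataset lazily."""

    def __init__(self, dataset: Sequence, start: int, length: int, seed: int):
        self.dataset = dataset
        self.indices = random_slice_indices(len(dataset), start, length, seed)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        return self.dataset[self.indices[i]]
```

Because every call with the same seed sees the same permutation, non-overlapping slices such as [0, 7) and [7, 10) yield disjoint subsets that together cover the dataset.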

nequip.data.dataset.RandomSplitAndIndexDataset(dataset: Dataset, split_dict: Dict[str, int | float], dataset_key: str, seed: int) → Dataset

Randomly splits dataset into the subsets named in split_dict and returns the subset named dataset_key.

Parameters:
  • dataset (Dataset) – the base dataset that is to be split

  • split_dict (Dict) – dictionary of the form {name_of_subset: num_data or frac_data}, where the integer counts num_data must sum to the size of the given dataset, or the fractions frac_data must sum to 1

  • dataset_key (str) – name of the data subset to return

  • seed (int) – seed for reproducible splits
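
A sketch of how such a split could work, with its assumptions stated: fractions are floored and the remainder goes to the last subset, subsets take consecutive slices of one seeded NumPy permutation in split_dict order, and indices are returned instead of a Dataset. split_sizes and random_split_indices are hypothetical names; nequip's rounding and random number generator may differ.

```python
import math
from typing import Dict, List, Union

import numpy as np


def split_sizes(n: int, split_dict: Dict[str, Union[int, float]]) -> Dict[str, int]:
    """Resolve counts or fractions into integer subset sizes summing to n."""
    values = list(split_dict.values())
    if all(isinstance(v, int) for v in values):
        if sum(values) != n:
            raise ValueError(f"counts sum to {sum(values)}, expected {n}")
        return dict(split_dict)
    if not math.isclose(sum(values), 1.0):
        raise ValueError(f"fractions sum to {sum(values)}, expected 1")
    # Assumed rounding rule: floor each fraction, give the remainder to the last subset.
    names = list(split_dict)
    sizes = {name: math.floor(split_dict[name] * n) for name in names}
    sizes[names[-1]] += n - sum(sizes.values())
    return sizes


def random_split_indices(
    n: int, split_dict: Dict[str, Union[int, float]], dataset_key: str, seed: int
) -> List[int]:
    """Indices of subset `dataset_key` after a seeded random split of range(n)."""
    sizes = split_sizes(n, split_dict)
    perm = np.random.default_rng(seed).permutation(n).tolist()
    start = 0
    for name, size in sizes.items():  # consecutive slices in split_dict order
        if name == dataset_key:
            return perm[start : start + size]
        start += size
    raise KeyError(dataset_key)
```

Calling the function once per dataset_key with the same seed gives disjoint subsets, which is what keeps the train, validation, and test splits consistent with one another.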