nequip.data.datamodule

For usage examples and configuration guidance, see the Data Configuration guide.

class nequip.data.datamodule.NequIPDataModule(seed: int, train_dataset: Dict | List | None = [], val_dataset: Dict | List | None = [], test_dataset: Dict | List | None = [], predict_dataset: Dict | List | None = [], split_dataset: Dict | List | None = [], train_dataloader: Dict = {}, val_dataloader: Dict = {}, test_dataloader: Dict = {}, predict_dataloader: Dict = {}, stats_manager: Dict | None = None)

Sanity checking is only performed at runtime – ensure that the correct datasets are provided for the intended runs, which can be train, val, test, and/or predict.

  • train runs require train_dataset and val_dataset

  • val runs require val_dataset

  • test runs require test_dataset

  • predict runs require predict_dataset
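The requirements above can be summarized as a small lookup. The following is a minimal sketch and not part of NequIP: the names REQUIRED_ROLES and missing_roles are hypothetical, and "available" is assumed to list every dataset role that exists, whether it was given explicitly or obtained through split_dataset.

```python
# Hypothetical helper (not NequIP code): dataset roles needed by each run type.
REQUIRED_ROLES = {
    "train": {"train", "val"},
    "val": {"val"},
    "test": {"test"},
    "predict": {"predict"},
}


def missing_roles(run_types, available):
    """Return the sorted dataset roles required by run_types but not available."""
    needed = set()
    for run in run_types:
        needed |= REQUIRED_ROLES[run]
    return sorted(needed - set(available))


# A train + test run with no validation data is missing the "val" role.
print(missing_roles(["train", "test"], ["train", "test"]))  # ['val']
```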

One can explicitly specify which train, val, test, and predict datasets to use, or randomly split a dataset for any of those tasks with the split_dataset argument. These options are not mutually exclusive. For example, if a single test_dataset is provided and split_dataset is used to obtain another test set, there will be two test sets (indexed 0 and 1) used for testing. More generally, if test_dataset is a list of n test sets and split_dataset contributes m further test sets, there will be m+n test sets in total: those from test_dataset take indices 0, 1, …, n - 1, and those from split_dataset take indices n, n+1, …, n+m-1.
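The index bookkeeping described above amounts to concatenating the explicitly given test sets with the split-derived ones. A minimal sketch, using placeholder labels rather than real dataset objects (the function name and labels are hypothetical, not NequIP API):

```python
def index_test_sets(explicit, from_split):
    """Map index -> test set: explicit sets first (0..n-1), split-derived sets next (n..n+m-1)."""
    combined = list(explicit) + list(from_split)
    return dict(enumerate(combined))


# n = 2 explicit test sets and m = 1 test set obtained from split_dataset.
print(index_test_sets(["test_a", "test_b"], ["split_test"]))
# {0: 'test_a', 1: 'test_b', 2: 'split_test'}
```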

Parameters:
  • seed (int) – data seed for reproducibility

  • train_dataset (Dict/List[Dict]) – training dataset

  • val_dataset (Dict/List[Dict]) – validation dataset(s) (can provide multiple datasets in a list)

  • test_dataset (Dict/List[Dict]) – test dataset(s) (can provide multiple datasets in a list)

  • predict_dataset (Dict/List[Dict]) – prediction dataset(s) (can provide multiple datasets in a list)

  • split_dataset (Dict/List[Dict]) – dictionary with a dataset key, which defines the dataset, and the keys train, val, test, predict, which name the subsets to split the dataset into. The subset values are either ints that sum to the size of the dataset or floats that sum to 1. If this option is used, at least two (but not necessarily all) of train, val, test, predict must be provided.

  • train_dataloader (Dict) – training DataLoader configuration dictionary

  • val_dataloader (Dict) – validation DataLoader configuration dictionary

  • test_dataloader (Dict) – testing DataLoader configuration dictionary

  • predict_dataloader (Dict) – prediction DataLoader configuration dictionary

  • stats_manager (Dict) – dictionary that can be instantiated into a DataStatisticsManager object
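The rules for the subset sizes in split_dataset can be checked before launching a run. The following is a minimal sketch of such a check, not NequIP's own validation: check_split is a hypothetical helper, and rejecting a mix of ints and floats is an assumption based on the "either ints or floats" wording above.

```python
import math

SUBSETS = ("train", "val", "test", "predict")


def check_split(split, dataset_size):
    """Validate the subset entries of a split specification and return them."""
    parts = {k: v for k, v in split.items() if k in SUBSETS}
    if len(parts) < 2:
        raise ValueError("at least two of train, val, test, predict are required")
    values = list(parts.values())
    if all(isinstance(v, int) for v in values):
        # Integer sizes must account for every frame in the dataset.
        if sum(values) != dataset_size:
            raise ValueError("integer sizes must sum to the dataset size")
    elif all(isinstance(v, float) for v in values):
        # Fractions must sum to 1 (up to floating-point tolerance).
        if not math.isclose(sum(values), 1.0):
            raise ValueError("fractions must sum to 1")
    else:
        raise ValueError("use either all ints or all floats")
    return parts


print(check_split({"dataset": {}, "train": 0.8, "val": 0.2}, dataset_size=100))
```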

get_statistics(dataset: str = 'train', dataset_idx: int = 0)

Compute statistics of the dataset.

Parameters:
  • dataset (str) – train, val, test, or predict

  • dataset_idx (int) – dataset index (there can be multiple val, test, predict datasets)

class nequip.data.datamodule.ASEDataModule(seed: int, train_file_path: str | List[str] | None = [], val_file_path: str | List[str] | None = [], test_file_path: str | List[str] | None = [], predict_file_path: str | List[str] | None = [], split_dataset: Dict | List[Dict] | None = [], transforms: List[Callable] = [], ase_args: dict = {}, include_keys: List[str] | None = [], exclude_keys: List[str] | None = [], key_mapping: Dict[str, str] | None = {}, **kwargs)

LightningDataModule for ASE-readable datasets.

Interface similar to NequIPDataModule, except that all the datasets are given in terms of paths to relevant ASE-readable files.

Parameters:
  • seed (int) – data seed for reproducibility

  • train_file_path (str/List[str]) – path to training dataset file

  • val_file_path (str/List[str]) – path(s) to validation dataset file

  • test_file_path (str/List[str]) – path(s) to test dataset file

  • predict_file_path (str/List[str]) – path(s) to prediction dataset file

  • split_dataset (Dict/List[Dict]) – dictionary or list of dictionaries with a file_path key, which is the path to the ASE-readable dataset file, and the keys train, val, test, predict, which name the subsets to split the dataset into. The subset values are either ints that sum to the size of the dataset or floats that sum to 1. If this option is used, at least two (but not necessarily all) of train, val, test, predict must be provided.

  • transforms (List[Callable]) – list of data transforms

  • ase_args (Dict[str, Any]) – arguments for ase.io.iread (see the ASE documentation)

  • include_keys (List[str]) – keys to parse in addition to forces and energy; data stored in ase.atoms.Atoms.arrays has the lowest priority and is overridden by data in ase.atoms.Atoms.info and ase.atoms.Atoms.calc.results

  • exclude_keys (List[str]) – list of keys that may be present in the ASE-readable file but the user wishes to exclude

  • key_mapping (Dict[str, str]) – mapping of ase keys to AtomicDataDict keys
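The priority rule stated for include_keys can be pictured as a layered dictionary merge. This is a minimal sketch that uses plain dictionaries as stand-ins for the per-atom array data, Atoms.info, and Atoms.calc.results; resolve_fields is a hypothetical name, and the relative order of info and calc.results is an assumption, since the description above only says that both override the array data.

```python
def resolve_fields(arrays, info, calc_results):
    """Merge per-frame data so that later sources override earlier ones."""
    merged = dict(arrays)        # lowest priority
    merged.update(info)          # overrides the array data
    merged.update(calc_results)  # assumed to override both (not stated in the docs)
    return merged


# "charges" appears in both places; the value from info wins over the array value.
fields = resolve_fields(
    arrays={"charges": "from_arrays", "tags": "from_arrays"},
    info={"charges": "from_info"},
    calc_results={"stress": "from_calc"},
)
print(fields["charges"])  # from_info
```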

class nequip.data.datamodule.sGDML_CCSD_DataModule(dataset: str, data_source_dir: str, transforms: List[Callable], seed: int, train_val_split: Sequence[int | float], trainval_test_subset: List[int] | None = None, **kwargs)

Lightning Data Module responsible for processing sGDML CCSD datasets (including downloading).

The sGDML datasets can be found at http://www.sgdml.org/#datasets. This class handles the CCSD and CCSD(T) datasets, including aspirin, benzene, malonaldehyde, toluene, and ethanol.

Parameters:
  • dataset (str) – aspirin, benzene, malonaldehyde, toluene, or ethanol

  • data_source_dir (str) – directory to download sGDML CCSD data to, or where the npz files are present if already downloaded and unzipped

  • transforms (List[Callable]) – list of data transforms

  • seed (int) – data seed for reproducibility

  • train_val_split (List[float]/List[int]) – train-validation split either in fractions [f, 1-f] or integers [N_train, N_val]

  • trainval_test_subset (List[int]) – Subset of [N_train + N_val, N_test] to use from the full dataset (the intended use is for minimal tests)

class nequip.data.datamodule.rMD17DataModule(dataset: str, data_source_dir: str, transforms: List[Callable], seed: int, train_val_test_split: Sequence[int | float], subset_len: int | None = None, **kwargs)

Lightning Data Module responsible for processing rMD17 datasets (including downloading).

The revised MD-17 datasets can be found at this link. This class handles all datasets included in the file: aspirin, azobenzene, benzene, ethanol, malonaldehyde, naphthalene, paracetamol, salicylic, toluene, and uracil. Each dataset contains 100,000 samples, with the exception of azobenzene, which contains 99,988 samples. The datasets are not pre-split into training, validation, and testing sets; the user has to specify the split using the train_val_test_split argument.

Note

If only a subset of the dataset is meant to be used (e.g. for testing), the subset_len argument can be used to specify the number of samples to use. In this case, train_val_test_split has to be set either as fractions or as a list of integers that sum up to subset_len. If subset_len is not set, the full dataset is used.
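The interaction between subset_len and train_val_test_split can be sketched as follows. This is an illustration under stated assumptions, not the datamodule's implementation: resolve_rmd17_split is a hypothetical helper, and the way fractions are rounded to integer sizes (round the first two, give the remainder to the test set) is an assumption.

```python
import math


def resolve_rmd17_split(split, total=100_000, subset_len=None):
    """Turn a train/val/test split into integer sizes [N_train, N_val, N_test]."""
    # The split refers to subset_len if it is set, else to the full dataset.
    n = total if subset_len is None else subset_len
    if all(isinstance(s, int) for s in split):
        if sum(split) != n:
            raise ValueError(f"integer split must sum to {n}")
        return list(split)
    if not math.isclose(sum(split), 1.0):
        raise ValueError("fractional split must sum to 1")
    n_train = round(split[0] * n)
    n_val = round(split[1] * n)
    return [n_train, n_val, n - n_train - n_val]


# A minimal test run on 1,000 samples.
print(resolve_rmd17_split([0.8, 0.1, 0.1], subset_len=1000))  # [800, 100, 100]
```

For azobenzene, total would be 99,988 rather than the default 100,000.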

Parameters:
  • dataset (str) – aspirin, azobenzene, benzene, ethanol, malonaldehyde, naphthalene, paracetamol, salicylic, toluene or uracil.

  • data_source_dir (str) – directory to download the data to, or where the npz files are present if already downloaded and unzipped

  • transforms (List[Callable]) – list of data transforms

  • seed (int) – data seed for reproducibility

  • train_val_test_split (List[float]/List[int]) – train-validation-test split either in fractions [a, b, c] (a+b+c=1) or integers [N_train, N_val, N_test]. If using integers, they have to sum up to either the total number of samples in the dataset, or to the subset_len if it is set.

  • subset_len (int) – Subset of N_train + N_val + N_test to use from the full dataset (the intended use is for minimal tests).

class nequip.data.datamodule.MD22DataModule(dataset: str, data_source_dir: str, transforms: List[Callable], seed: int, train_val_split: Sequence[int | float], **kwargs)

Lightning Data Module responsible for processing sGDML MD22 datasets (including downloading).

This class handles the MD22 datasets, including tetrapeptide (CHNO), dha (CHO), stachyose (CHO), dna_atat (CHNO), dna_atat_cgcg (CHNO), buckyball_catcher (CH), and double_walled_nanotube (CH). See Science Advances 9.2 (2023): eadf0873 for more details.

This datamodule will automatically use the training set sizes from the paper, that is, tetrapeptide (6,000/85,109), dha (8,000/69,753), stachyose (8,000/27,272), dna_atat (3,000/20,001), dna_atat_cgcg (2,000/10,153), buckyball_catcher (600/6,102), and double_walled_nanotube (800/5,032). The “training set” will then be partitioned into train and validation datasets based on train_val_split. The remainder is used as the test dataset.
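The resulting dataset sizes can be worked out from the numbers above. The following is a minimal sketch, not the datamodule's code: md22_partition is a hypothetical helper, the second number in each pair is assumed to be the total number of frames, and the rounding of fractional splits is an assumption.

```python
# (N_trainval, N_total) pairs as listed above; reading the second number as the
# total frame count is an assumption.
MD22_SIZES = {
    "tetrapeptide": (6000, 85109),
    "dha": (8000, 69753),
    "stachyose": (8000, 27272),
    "dna_atat": (3000, 20001),
    "dna_atat_cgcg": (2000, 10153),
    "buckyball_catcher": (600, 6102),
    "double_walled_nanotube": (800, 5032),
}


def md22_partition(dataset, train_val_split):
    """Return (N_train, N_val, N_test) for a dataset and a train-validation split."""
    n_trainval, n_total = MD22_SIZES[dataset]
    first, second = train_val_split
    if isinstance(first, int) and isinstance(second, int):
        n_train, n_val = first, second
    else:
        n_train = round(first * n_trainval)
        n_val = n_trainval - n_train
    # Everything outside the paper's "training set" is used for testing.
    return n_train, n_val, n_total - n_trainval


print(md22_partition("dha", [0.9, 0.1]))  # (7200, 800, 61753)
```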

Parameters:
  • dataset (str) – tetrapeptide, dha, stachyose, dna_atat, dna_atat_cgcg, buckyball_catcher, or double_walled_nanotube

  • data_source_dir (str) – directory to download sGDML MD22 data to, or where the npz files are present if already downloaded

  • transforms (List[Callable]) – list of data transforms

  • seed (int) – data seed for reproducibility

  • train_val_split (List[float] or List[int]) – train-validation split either in fractions [f, 1-f] or integers [N_train, N_val]

class nequip.data.datamodule.NequIP3BPADataModule(seed: int, transforms: List[Callable], train_val_split: Sequence[int | float], data_source_dir: str, train_set: str = '300K', test_sets: List[str] = ['300K', '600K', '1200K', 'dih_beta120', 'dih_beta150', 'dih_beta180'], **kwargs)

LightningDataModule for the 3BPA dataset.

This datamodule can be used for train, validate, and test runs.

This datamodule can automatically download the dataset to data_source_dir. Users can also manually download the 3BPA zipfile and unzip it (data_source_dir should then be the directory containing the dataset_3BPA directory). Users must not tamper with the contents of the dataset_3BPA directory produced upon unzipping as this datamodule assumes the default filenames in the directory.

The 3BPA dataset has two possible training sets, one at 300K and one with mixed temperatures. The 300K training set is used by default, but users can select the other with the train_set argument, which accepts either 300K or mixedT.

The train_val_split argument is required to split the train_set chosen into separate training and validation datasets, as in NequIPDataModule.

There are several test datasets to choose from: 300K, 600K, 1200K, dih_beta120, dih_beta150, and dih_beta180. By default, all are included as test datasets in that order, but one can override this by passing a list of test sets as the test_sets argument. An empty list results in no test sets.

It is recommended to set the isolated atom energies in the model’s per_type_energy_shifts. The following information can be found in iso_atoms.xyz in the 3BPA data zip, but is reproduced here in the format of the config arguments:

model:
  type_names: [C, H, N, O]
  per_type_energy_shifts:
    C: -1029.4889999855063
    H: -13.587222780835477
    N: -1484.9814568572233
    O: -2041.9816003861047

Parameters:
  • seed (int) – data seed for reproducibility

  • transforms (List[Callable]) – list of data transforms

  • train_val_split (List[float]/List[int]) – train-validation split either in fractions [f, 1-f] or integers [N_train, N_val]

  • data_source_dir (str) – directory to download 3BPA dataset to, or where the dataset_3BPA directory is located if already downloaded and unzipped

  • train_set (str) – either 300K or mixedT

  • test_sets (List[str]) – list that can contain 300K, 600K, 1200K, dih_beta120, dih_beta150, and/or dih_beta180

class nequip.data.datamodule.COLLDataModule(seed: int, transforms: List[Callable], data_source_dir: str, **kwargs)

LightningDataModule for the COLL dataset from https://arxiv.org/abs/2011.14115.

Parameters:
  • seed (int) – data seed for reproducibility

  • transforms (List[Callable]) – list of data transforms

  • data_source_dir (str) – directory where dataset files will be downloaded to if not already present

class nequip.data.datamodule.TM23DataModule(seed: int, data_source_dir: str, element: str, transforms: List[Callable], train_val_split: Sequence[int | float], **kwargs)

LightningDataModule for the TM23 dataset.

This datamodule can be used for train, validate, and test runs.

This datamodule can automatically download the TM23 dataset from https://archive.materialscloud.org/record/2024.48 and unzip it in data_source_dir if not already downloaded. Otherwise, one can download and unzip the dataset as is and set data_source_dir to the directory that contains benchmarking_master_collection.

The combined dataset containing cold, warm, and melt frames is used as the train and test datasets. element can be any TM23 element, including Ag, Au, Cd, Co, Cr, Cu, Fe, Hf, Hg, Ir, Mn, Mo, Nb, Ni, Os, Pd, Pt, Re, Rh, Ru, Ta, Tc, Ti, V, W, Zn, and Zr.

The train_val_split argument is required to split the training dataset into separate training and validation datasets.

Parameters:
  • seed (int) – data seed for reproducibility

  • data_source_dir (str) – directory containing the TM23 dataset if present, else directory where TM23 dataset will be downloaded to

  • element (str) – element from TM23 dataset to use

  • transforms (List[Callable]) – list of data transforms

  • train_val_split (List[float] or List[int]) – train-validation split either in fractions [f, 1-f] or integers [N_train, N_val]

class nequip.data.datamodule.SAMD23DataModule(seed: int, transforms: List[Callable], data_source_dir: str, system: str = 'HfO', include_ood: bool = True, **kwargs)

LightningDataModule for the Samsung SAMD23 dataset.

This datamodule can be used for train, validate, and test runs.

It automatically downloads the dataset from Google Drive using gdown, extracts it into data_source_dir, and loads ASE-compatible datasets from the pre-split Trainset.xyz, Validset.xyz, and Testset.xyz files.

If include_ood=True, the datamodule also looks for an OOD.xyz file in the same folder. If found, this file is included as a second test set during evaluation. Testset.xyz remains the main in-distribution test set. This setting does not affect training or validation — only test evaluation.

Users may also download and extract the data manually. In that case, the extracted folder (HfO/ or SiN/) should be placed inside data_source_dir, and the expected filenames must be preserved.
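The test-set selection described above can be sketched with a small path helper. This is an illustration only: samd23_test_files is a hypothetical function, not part of NequIP, and it assumes the layout data_source_dir/<system>/ with the default filenames.

```python
import tempfile
from pathlib import Path


def samd23_test_files(data_source_dir, system="HfO", include_ood=True):
    """List the test files in evaluation order (hypothetical helper, not NequIP code)."""
    folder = Path(data_source_dir) / system
    files = [folder / "Testset.xyz"]  # main in-distribution test set
    ood = folder / "OOD.xyz"
    if include_ood and ood.exists():  # OOD.xyz is only added when it is present
        files.append(ood)
    return files


# Usage with a throwaway directory that mimics the extracted layout.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "HfO").mkdir()
    (Path(tmp) / "HfO" / "Testset.xyz").touch()
    (Path(tmp) / "HfO" / "OOD.xyz").touch()
    print([p.name for p in samd23_test_files(tmp)])  # ['Testset.xyz', 'OOD.xyz']
    print([p.name for p in samd23_test_files(tmp, include_ood=False)])  # ['Testset.xyz']
```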

Note

Automatic downloading requires the optional gdown package. Install with pip install gdown.

Parameters:
  • seed (int) – data seed for reproducibility

  • transforms (List[Callable]) – list of NequIP data transforms to apply

  • data_source_dir (str) – directory to store and/or locate the dataset

  • system (str) – HfO or SiN (default HfO)

  • include_ood (bool) – whether to include OOD.xyz as a second test set. If True, the test split will contain both Testset.xyz and OOD.xyz, evaluated as separate test sets. (default True)

class nequip.data.datamodule.WaterDataModule(seed: int, transforms: List[Callable], data_source_dir: str, train_val_test_split: Sequence[int | float], **kwargs)

LightningDataModule for the water dataset from Cheng, Bingqing, et al. “Ab initio thermodynamics of liquid and solid water.” Proceedings of the National Academy of Sciences 116.4 (2019): 1110-1115.

Parameters:
  • seed (int) – data seed for reproducibility

  • transforms (List[Callable]) – list of data transforms

  • data_source_dir (str) – directory that contains dataset_1593_eVAng.xyz if present, else the directory that dataset_1593_eVAng.xyz will be downloaded to

  • train_val_test_split (Sequence[Union[int, float]]) – [train, val, test] split ratio