Source code for nequip.data.datamodule.md22_datamodule

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from ._base_datamodule import NequIPDataModule
from nequip.data import AtomicDataDict
from nequip.utils import download_url, extract_zip
from nequip.utils.logger import RankedLogger

import os
from typing import Union, Sequence, List, Callable

# Multiplicative factor converting kcal/mol to eV (1 kcal/mol = 4.184 kJ/mol ~= 0.0433641 eV).
KCALMOL_TO_EV = 0.0433641

# Module logger; ``rank_zero_only=True`` presumably limits output to rank 0 in
# distributed runs -- see ``nequip.utils.logger.RankedLogger``.
logger = RankedLogger(__name__, rank_zero_only=True)


def _kcalmol_to_ev(data: AtomicDataDict.Type) -> AtomicDataDict.Type:
    """Rescale the total energy and forces of ``data`` from kcal/mol-based units to eV.

    Both the total energy and the forces are multiplied by ``KCALMOL_TO_EV``; the
    length part of the force unit is left as-is. The entries of ``data`` are
    replaced in place, and the same dict is returned so this can be used as a
    dataset transform.

    Args:
        data: atomic data dict that must contain the total-energy and force keys.

    Returns:
        The same ``data`` object, with the energy and force entries rescaled.
    """
    for unit_key in (AtomicDataDict.TOTAL_ENERGY_KEY, AtomicDataDict.FORCE_KEY):
        data[unit_key] = data[unit_key] * KCALMOL_TO_EV
    return data


class MD22DataModule(NequIPDataModule):
    """Lightning Data Module responsible for processing sGDML MD22 datasets (including downloading).

    This class handles the MD22 datasets, including ``tetrapeptide`` (CHNO), ``dha`` (CHO),
    ``stachyose`` (CHO), ``dna_atat`` (CHNO), ``dna_atat_cgcg`` (CHNO), ``buckyball_catcher`` (CH),
    and ``double_walled_nanotube`` (CH).
    See `Science Advances 9.2 (2023): eadf0873 <https://www.science.org/doi/10.1126/sciadv.adf0873>`_
    for more details.

    This datamodule will automatically use the training set sizes from the paper, that is,
    ``tetrapeptide`` (6,000/85,109), ``dha`` (8,000/69,753), ``stachyose`` (8,000/27,272),
    ``dna_atat`` (3,000/20,001), ``dna_atat_cgcg`` (2,000/10,153), ``buckyball_catcher`` (600/6,102),
    and ``double_walled_nanotube`` (800/5,032).
    The "training set" will then be partitioned into train and validation datasets based on
    ``train_val_split``. The remainder is used as the test dataset.

    Energies and forces read from the npz file are multiplied by ``KCALMOL_TO_EV``
    (kcal/mol -> eV); this conversion is applied after the user-provided ``transforms``.

    Args:
        dataset (str): ``tetrapeptide``, ``dha``, ``stachyose``, ``dna_atat``, ``dna_atat_cgcg``,
            ``buckyball_catcher``, or ``double_walled_nanotube``
        data_source_dir (str): directory to download sGDML MD22 data to, or where the npz files
            are present if already downloaded
        transforms (List[Callable]): list of data transforms
        seed (int): data seed for reproducibility
        train_val_split (List[float] or List[int]): train-validation split either in fractions
            ``[f, 1-f]`` or integers ``[N_train, N_val]``; only the first two entries are used

    Raises:
        ValueError: if ``dataset`` is not one of the supported MD22 datasets.
    """

    # dataset: [file_name, num_trainval, num_data]
    dataset_map = {
        "tetrapeptide": ["md22_Ac-Ala3-NHMe.npz", 6000, 85109],
        "dha": ["md22_DHA.npz", 8000, 69753],
        "stachyose": ["md22_stachyose.npz", 8000, 27272],
        "dna_atat": ["md22_AT-AT.npz", 3000, 20001],
        "dna_atat_cgcg": ["md22_AT-AT-CG-CG.npz", 2000, 10153],
        "buckyball_catcher": ["md22_buckyball-catcher.npz", 600, 6102],
        "double_walled_nanotube": ["md22_double-walled_nanotube.npz", 800, 5032],
    }

    # dataset: download URL (the URL's file name matches ``dataset_map[dataset][0]``)
    url_dict = {
        "tetrapeptide": "http://www.quantum-machine.org/gdml/repo/datasets/md22_Ac-Ala3-NHMe.npz",
        "dha": "http://www.quantum-machine.org/gdml/repo/datasets/md22_DHA.npz",
        "stachyose": "http://www.quantum-machine.org/gdml/repo/datasets/md22_stachyose.npz",
        "dna_atat": "http://www.quantum-machine.org/gdml/repo/datasets/md22_AT-AT.npz",
        "dna_atat_cgcg": "http://www.quantum-machine.org/gdml/repo/datasets/md22_AT-AT-CG-CG.npz",
        "buckyball_catcher": "http://www.quantum-machine.org/gdml/repo/datasets/md22_buckyball-catcher.npz",
        "double_walled_nanotube": "http://www.quantum-machine.org/gdml/repo/datasets/md22_double-walled_nanotube.npz",
    }

    def __init__(
        self,
        dataset: str,
        data_source_dir: str,
        transforms: List[Callable],
        seed: int,
        train_val_split: Sequence[Union[int, float]],
        **kwargs,
    ):
        # Explicit raise rather than ``assert`` so validation is not stripped under ``python -O``.
        if dataset not in self.dataset_map:
            raise ValueError(
                f"`dataset={dataset}` not supported, `dataset` can be any of {list(self.dataset_map.keys())}"
            )

        file_name, num_trainval, num_data = self.dataset_map[dataset]
        # ``os.path.join`` is portable and does not turn an empty directory into the filesystem root.
        data_file_path = os.path.join(data_source_dir, file_name)

        base_config = {
            "_target_": "nequip.data.dataset.NPZDataset",
            "file_path": data_file_path,
            # The unit conversion is appended, so it runs after the user-supplied transforms.
            "transforms": list(transforms) + [_kcalmol_to_ev],
            "key_mapping": {
                "R": AtomicDataDict.POSITIONS_KEY,
                "z": AtomicDataDict.ATOMIC_NUMBERS_KEY,
                "E": AtomicDataDict.TOTAL_ENERGY_KEY,
                "F": AtomicDataDict.FORCE_KEY,
            },
        }

        # Frames [0, num_trainval) form the train+val pool and [num_trainval, num_data) the test set.
        # Both subsets share ``seed``; presumably ``SubsetByRandomSlice`` slices a seed-determined
        # shuffle of the dataset, which keeps the two subsets disjoint -- verify against its implementation.
        train_end_idx = num_trainval
        test_length = num_data - train_end_idx
        train_config = {
            "_target_": "nequip.data.dataset.SubsetByRandomSlice",
            "dataset": base_config,
            "start": 0,
            "length": train_end_idx,
            "seed": seed,
        }
        test_config = {
            "_target_": "nequip.data.dataset.SubsetByRandomSlice",
            "dataset": base_config,
            "start": train_end_idx,
            "length": test_length,
            "seed": seed,
        }

        super().__init__(
            seed=seed,
            split_dataset={
                "dataset": train_config,
                "train": train_val_split[0],
                "val": train_val_split[1],
            },
            test_dataset=test_config,
            **kwargs,
        )
        self.dataset = dataset
        self.data_source_dir = data_source_dir
        self.data_file_path = data_file_path

    def prepare_data(self):
        """Download the selected MD22 npz file into ``data_source_dir`` if it is not already present.

        If ``self.data_file_path`` already exists, nothing is downloaded and a log message is emitted.
        A downloaded ``.zip`` archive would be extracted into ``data_source_dir``; the MD22 URLs point
        at ``.npz`` files, so no extraction normally happens.
        """
        if not os.path.isfile(self.data_file_path):
            # download and unzip
            download_path = download_url(
                self.url_dict[self.dataset], self.data_source_dir
            )
            if download_path.endswith(".zip"):
                extract_zip(download_path, self.data_source_dir)
        else:
            logger.info(f"Using existing data file `{self.data_file_path}`")