Source code for nequip.data.datamodule.water_datamodule

from nequip.data import AtomicDataDict
from nequip.data.datamodule import ASEDataModule
from nequip.utils.file_utils import download_url
from nequip.utils.logger import RankedLogger

import os
from typing import Union, Sequence, List, Callable

# Module-level logger; ``rank_zero_only=True`` presumably restricts output to
# the rank-0 process in distributed runs (see ``RankedLogger``) -- confirm.
logger = RankedLogger(__name__, rank_zero_only=True)


# Download location of ``dataset_1593_eVAng.xyz``, fetched by
# ``WaterDataModule.prepare_data`` when the file is not already present in
# ``data_source_dir``.
_URL_WATER = "https://github.com/BingqingCheng/Mapping-the-space-of-materials-and-molecules/raw/refs/heads/master/mlp-water/dataset_1593_eVAng.xyz"


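# Least-squares solve for per-atom energies (presumably in eV, per the
# ``eVAng`` dataset file name) yields
#    H: -187.6044
#    O: -93.8022
# NOTE: assuming every configuration is pure water (H:O = 2:1), this fit is
# rank-deficient and only determines the combination 2*E_H + E_O = -469.0110
# per H2O unit. The values above satisfy E_H = 2*E_O exactly, i.e. they look
# like the minimum-norm solution, so the individual H and O values are not
# independently meaningful -- only that combination is.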


class WaterDataModule(ASEDataModule):
    """LightningDataModule for the water dataset from `Cheng, Bingqing, et al. "Ab initio thermodynamics of liquid and solid water." Proceedings of the National Academy of Sciences 116.4 (2019): 1110-1115. <https://www.pnas.org/doi/full/10.1073/pnas.1815117116>`_.

    If ``dataset_1593_eVAng.xyz`` is not found in ``data_source_dir``, it is
    downloaded there by :meth:`prepare_data`.

    Args:
        seed (int): data seed for reproducibility
        transforms (List[Callable]): list of data transforms
        data_source_dir (str): directory that contains ``dataset_1593_eVAng.xyz`` if present, else the directory that ``dataset_1593_eVAng.xyz`` will be downloaded to
        train_val_test_split (Sequence[Union[int, float]]): ``[train, val, test]`` split ratio (how ints vs. floats are interpreted is up to ``ASEDataModule``)
        **kwargs: forwarded unchanged to ``ASEDataModule``

    Raises:
        ValueError: if ``train_val_test_split`` does not have exactly three entries
    """

    def __init__(
        self,
        seed: int,
        transforms: List[Callable],
        data_source_dir: str,
        train_val_test_split: Sequence[Union[int, float]],
        **kwargs,
    ):
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would silently skip this validation.
        if len(train_val_test_split) != 3:
            raise ValueError(
                "`train_val_test_split` must have exactly three entries "
                f"([train, val, test]), got {len(train_val_test_split)}: "
                f"{train_val_test_split}"
            )
        # os.path.join avoids a doubled separator for a trailing-slash
        # directory and, unlike ``"" + "/..."``, does not resolve an empty
        # ``data_source_dir`` to the filesystem root.
        file_path = os.path.join(data_source_dir, "dataset_1593_eVAng.xyz")
        super().__init__(
            seed=seed,
            split_dataset={
                "file_path": file_path,
                "train": train_val_test_split[0],
                "val": train_val_test_split[1],
                "test": train_val_test_split[2],
            },
            transforms=transforms,
            # Rename the dataset's ``TotEnergy`` / ``force`` fields to the
            # total-energy and force keys used by nequip.
            key_mapping={
                "TotEnergy": AtomicDataDict.TOTAL_ENERGY_KEY,
                "force": AtomicDataDict.FORCE_KEY,
            },
            **kwargs,
        )
        self.data_source_dir = data_source_dir
        self.file_path = file_path

    def prepare_data(self):
        """Download the water dataset into ``data_source_dir`` if it is missing.

        If ``self.file_path`` already exists, nothing is downloaded and the
        existing file is logged instead.
        """
        if not os.path.isfile(self.file_path):
            # The returned path is not needed; ``self.file_path`` is used later.
            _ = download_url(_URL_WATER, self.data_source_dir)
        else:
            logger.info(f"Using existing data file `{self.file_path}`")