Source code for nequip.data.datamodule.sgdml_datamodule

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from ._base_datamodule import NequIPDataModule
from nequip.utils import download_url, extract_zip
from nequip.utils.logger import RankedLogger

import os
from typing import Union, Sequence, Optional, List, Callable

logger = RankedLogger(__name__, rank_zero_only=True)


class sGDML_CCSD_DataModule(NequIPDataModule):
    """Lightning Data Module responsible for processing sGDML CCSD datasets (including downloading).

    The sGDML datasets can be found at http://www.sgdml.org/#datasets. This class handles the
    CCSD and CCSD(T) datasets, including ``aspirin``, ``benzene``, ``malonaldehyde``,
    ``toluene``, and ``ethanol``.

    Each dataset is expected as a pair of files, ``<name>-train.npz`` and ``<name>-test.npz``,
    inside ``data_source_dir``, where ``<name>`` is the corresponding ``DATASET_MAP`` entry.
    The train file feeds the train/validation split and the test file feeds the test dataset.

    Args:
        dataset (str): ``aspirin``, ``benzene``, ``malonaldehyde``, ``toluene``, or ``ethanol``
        data_source_dir (str): directory to download sGDML CCSD data to, or where the npz files
            are present if already downloaded and unzipped
        transforms (List[Callable]): list of data transforms
        seed (int): data seed for reproducibility
        train_val_split (List[float]/List[int]): train-validation split either in fractions
            ``[f, 1-f]`` or integers ``[N_train, N_val]``
        trainval_test_subset (List[int]): Subset of ``[N_train + N_val, N_test]`` to use from
            the full dataset (the intended use is for minimal tests)
        **kwargs: forwarded unchanged to ``NequIPDataModule.__init__``

    Raises:
        AssertionError: if ``dataset`` is not a key of ``DATASET_MAP``, or if
            ``trainval_test_subset`` is given and does not have exactly two entries
    """

    # user-facing dataset name -> file stem of the `-train.npz` / `-test.npz` files
    DATASET_MAP = {
        "aspirin": "aspirin_ccsd",
        "benzene": "benzene_ccsd_t",
        "malonaldehyde": "malonaldehyde_ccsd_t",
        "toluene": "toluene_ccsd_t",
        "ethanol": "ethanol_ccsd_t",
    }

    # user-facing dataset name -> download URL of the zipped npz files
    URL_DICT = {
        "aspirin": "http://www.quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip",
        "benzene": "http://www.quantum-machine.org/gdml/data/npz/benzene_ccsd_t.zip",
        "malonaldehyde": "http://www.quantum-machine.org/gdml/data/npz/malonaldehyde_ccsd_t.zip",
        "toluene": "http://www.quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip",
        "ethanol": "http://www.quantum-machine.org/gdml/data/npz/ethanol_ccsd_t.zip",
    }

    def __init__(
        self,
        dataset: str,
        data_source_dir: str,
        transforms: List[Callable],
        seed: int,
        train_val_split: Sequence[Union[int, float]],
        trainval_test_subset: Optional[List[int]] = None,
        **kwargs,
    ):
        assert dataset in self.DATASET_MAP, (
            f"`dataset={dataset}` not supported, `dataset` can be any of {list(self.DATASET_MAP.keys())}"
        )

        file_stem = self.DATASET_MAP[dataset]
        train_file_path = "/".join([data_source_dir, file_stem + "-train.npz"])
        test_file_path = "/".join([data_source_dir, file_stem + "-test.npz"])

        # both configs are identical except for the file they point to
        trainval_config = {
            "_target_": "nequip.data.dataset.NPZDataset",
            "file_path": train_file_path,
            "transforms": transforms,
        }
        test_config = trainval_config.copy()
        test_config["file_path"] = test_file_path

        if trainval_test_subset is not None:
            assert len(trainval_test_subset) == 2, (
                "`trainval_test_subset` must have exactly two entries "
                f"`[N_train + N_val, N_test]`, but got {trainval_test_subset}"
            )
            n_trainval, n_test = trainval_test_subset
            # wrap each dataset config in a seeded random-slice subset of the requested length
            trainval_config = {
                "_target_": "nequip.data.dataset.SubsetByRandomSlice",
                "dataset": trainval_config,
                "start": 0,
                "length": n_trainval,
                "seed": seed,
            }
            test_config = {
                "_target_": "nequip.data.dataset.SubsetByRandomSlice",
                "dataset": test_config,
                "start": 0,
                "length": n_test,
                "seed": seed,
            }

        super().__init__(
            seed=seed,
            test_dataset=[test_config],
            split_dataset=[
                {
                    "dataset": trainval_config,
                    "train": train_val_split[0],
                    "val": train_val_split[1],
                }
            ],
            **kwargs,
        )

        # kept for `prepare_data`, which decides whether a download is needed
        self.dataset = dataset
        self.data_source_dir = data_source_dir
        self.train_file_path = train_file_path
        self.test_file_path = test_file_path

    def prepare_data(self):
        """Download and unzip the dataset unless both npz files already exist.

        If either ``self.train_file_path`` or ``self.test_file_path`` is missing, the archive
        listed in ``URL_DICT`` for ``self.dataset`` is downloaded into ``self.data_source_dir``
        and, when the downloaded path ends in ``.zip``, extracted into that same directory.
        Otherwise a message naming the existing files is logged and nothing is downloaded.
        """
        files_present = os.path.isfile(self.train_file_path) and os.path.isfile(
            self.test_file_path
        )
        if files_present:
            logger.info(
                f"Using existing data files `{self.train_file_path}` and `{self.test_file_path}`"
            )
            return

        # download and unzip
        download_path = download_url(self.URL_DICT[self.dataset], self.data_source_dir)
        if download_path.endswith(".zip"):
            extract_zip(download_path, self.data_source_dir)