Source code for nequip.data.datamodule.coll_datamodule
from nequip.data.datamodule import ASEDataModule
from nequip.utils.file_utils import download_url
from nequip.utils.logger import RankedLogger
import os
from typing import List, Callable
# Module-level logger, constructed with rank_zero_only=True
# (presumably so only rank zero emits messages in multi-process runs -- confirm against RankedLogger).
logger = RankedLogger(__name__, rank_zero_only=True)
# Figshare download URLs for the training / validation / test files that
# COLLDataModule.prepare_data fetches when they are not already on disk.
_URL_TRAIN = "https://figshare.com/ndownloader/files/25605734"
_URL_VAL = "https://figshare.com/ndownloader/files/25605737"
_URL_TEST = "https://figshare.com/ndownloader/files/25605740"
# A least-squares solve for per-atom energies yields
# C: -1035.412048 (0.036)
# H: -16.834627 (0.023)
# O: -2046.033121 (0.041)
[docs]
class COLLDataModule(ASEDataModule):
    """LightningDataModule for the COLL dataset from `<https://arxiv.org/abs/2011.14115>`_.

    The train, validation and test splits are read from the files
    ``coll_v1.2_AE_train.xyz``, ``coll_v1.2_AE_val.xyz`` and
    ``coll_v1.2_AE_test.xyz`` inside ``data_source_dir``. Any of these files
    that is missing is downloaded by :meth:`prepare_data`.

    Args:
        seed (int): data seed for reproducibility
        transforms (List[Callable]): list of data transforms
        data_source_dir (str): directory where dataset files will be downloaded to if not already present
        **kwargs: additional keyword arguments forwarded unchanged to ``ASEDataModule``
    """

    def __init__(
        self,
        seed: int,
        transforms: List[Callable],
        data_source_dir: str,
        **kwargs,
    ):
        # Stored before the parent constructor runs; `prepare_data` reads it later.
        self.data_source_dir = data_source_dir

        # All three split files live side by side in `data_source_dir`.
        train_xyz, val_xyz, test_xyz = (
            os.path.join(data_source_dir, file_name)
            for file_name in (
                "coll_v1.2_AE_train.xyz",
                "coll_v1.2_AE_val.xyz",
                "coll_v1.2_AE_test.xyz",
            )
        )

        super().__init__(
            seed=seed,
            train_file_path=train_xyz,
            val_file_path=val_xyz,
            test_file_path=test_xyz,
            transforms=transforms,
            **kwargs,
        )

        # (Re)assigned after the parent constructor so that `prepare_data`
        # sees these exact path strings.
        self.train_file_path = train_xyz
        self.val_file_path = val_xyz
        self.test_file_path = test_xyz

    def prepare_data(self):
        """Create ``data_source_dir`` if needed and download every split file that is not already there.

        A split whose file already exists is left untouched; only a log
        message is emitted for it.
        """
        os.makedirs(self.data_source_dir, exist_ok=True)

        splits = (
            ("training", self.train_file_path, _URL_TRAIN),
            ("validation", self.val_file_path, _URL_VAL),
            ("test", self.test_file_path, _URL_TEST),
        )
        for dataset_type, file_path, source_url in splits:
            if os.path.isfile(file_path):
                logger.info(f"Using existing {dataset_type} data file `{file_path}`")
                continue

            logger.info(f"Downloading {dataset_type} dataset to `{file_path}`")
            # Save under the expected file name rather than one derived from the URL.
            download_url(
                source_url,
                self.data_source_dir,
                filename=os.path.basename(file_path),
            )