# Source code for nequip.data.datamodule.samd23_datamodule
from ._ase_datamodule import ASEDataModule
from nequip.utils.file_utils import extract_tar
from nequip.utils.logger import RankedLogger
import os
from typing import List, Callable
# Module-level logger shared by everything in this file. It is constructed with
# ``rank_zero_only=True`` -- presumably so messages are emitted only by the
# rank-zero process in distributed runs (confirm against ``RankedLogger``).
logger = RankedLogger(__name__, rank_zero_only=True)
# Google Drive download links for the SAMD23 data, keyed by system name.
# The keys double as the set of valid ``system`` values and as the name of the
# folder expected inside ``data_source_dir``.
_URLS = dict(
    HfO="https://drive.google.com/uc?id=1-DVMGyXjvNYaBtaAkWu8uQVgvz8pEgMZ",
    SiN="https://drive.google.com/uc?id=1l9nsie40Bpm8CNW4sx94yAuvmMkUfM3b",
)
# [docs]
class SAMD23DataModule(ASEDataModule):
    """LightningDataModule for the `Samsung SAMD23 dataset <https://proceedings.neurips.cc/paper_files/paper/2023/hash/a1859debfb3b59d094f3504d5ebb6c25-Abstract-Datasets_and_Benchmarks.html>`_.

    This datamodule can be used for ``train``, ``validate``, and ``test`` runs.

    It automatically downloads the dataset from Google Drive using ``gdown``, extracts it into ``data_source_dir``,
    and loads ASE-compatible datasets from the pre-split ``Trainset.xyz``, ``Validset.xyz``, and ``Testset.xyz`` files.

    If ``include_ood=True``, the datamodule also looks for an ``OOD.xyz`` file in the same folder.
    If found, this file is included as a second test set during evaluation. ``Testset.xyz`` remains the main in-distribution test set.
    This setting does not affect training or validation — only test evaluation.

    Users may also download and extract the data manually.
    In that case, the extracted folder (``HfO/`` or ``SiN/``) should be placed inside ``data_source_dir``, and the expected filenames must be preserved.

    .. note::
        Automatic downloading requires the optional ``gdown`` package. Install with ``pip install gdown``.

    Args:
        seed (int): data seed for reproducibility
        transforms (List[Callable]): list of NequIP data transforms to apply
        data_source_dir (str): directory to store and/or locate the dataset
        system (str): ``HfO`` or ``SiN`` (default ``HfO``)
        include_ood (bool): whether to include ``OOD.xyz`` as a second test set. If True, the test split will contain both ``Testset.xyz`` and ``OOD.xyz``, evaluated as separate test sets. (default ``True``)
        **kwargs: forwarded unchanged to ``ASEDataModule.__init__``

    Raises:
        AssertionError: if ``system`` (after stripping whitespace) is not a key of ``_URLS``
    """

    def __init__(
        self,
        seed: int,
        transforms: List[Callable],
        data_source_dir: str,
        system: str = "HfO",
        include_ood: bool = True,
        **kwargs,
    ):
        system = system.strip()
        # Explicit raise instead of an `assert` statement so the validation is
        # not stripped under `python -O`. AssertionError is kept (rather than
        # ValueError) so the exception type seen by existing callers is unchanged.
        if system not in _URLS:
            raise AssertionError(
                f"Unknown system `{system}`; must be one of {list(_URLS)}"
            )

        self.system = system
        self.data_source_dir = data_source_dir
        # the archive is expected to unpack into a folder named after the system
        self.dataset_dir = os.path.join(data_source_dir, system)
        self.include_ood = include_ood
        self.ood_path = os.path.join(self.dataset_dir, "OOD.xyz")
        self.train_file_path = os.path.join(self.dataset_dir, "Trainset.xyz")
        self.val_file_path = os.path.join(self.dataset_dir, "Validset.xyz")

        # `Testset.xyz` always comes first; `OOD.xyz` is appended as a second
        # test set only when requested.
        # NOTE(review): the OOD path is passed on even if the file turns out to
        # be missing; `prepare_data` only logs a warning in that case.
        test_file_paths = [os.path.join(self.dataset_dir, "Testset.xyz")]
        if include_ood:
            test_file_paths.append(self.ood_path)

        super().__init__(
            seed=seed,
            train_file_path=self.train_file_path,
            val_file_path=self.val_file_path,
            test_file_path=test_file_paths,
            transforms=transforms,
            **kwargs,
        )

    def prepare_data(self):
        """Make sure the train/val/test files exist locally, downloading and extracting them if needed.

        If ``Trainset.xyz``, ``Validset.xyz`` and ``Testset.xyz`` are all present in
        ``dataset_dir`` nothing is downloaded. Otherwise the archive
        ``<data_source_dir>/<system>.tar`` is downloaded with ``gdown`` (unless it
        already exists) and extracted into ``data_source_dir``. Finally, if
        ``include_ood`` is set, the presence or absence of ``OOD.xyz`` is logged.

        Raises:
            ImportError: if a download is needed but ``gdown`` is not installed
            RuntimeError: if the download finishes without producing the archive file
            FileNotFoundError: if required files are still missing after extraction
        """
        required_files = [
            self.train_file_path,
            self.val_file_path,
            os.path.join(self.dataset_dir, "Testset.xyz"),
        ]

        if all(os.path.isfile(f) for f in required_files):
            logger.info(f"Using existing data files in `{self.dataset_dir}`")
        else:
            logger.info(
                f"Dataset files for {self.system} not found locally. Downloading from Google Drive..."
            )
            archive_path = os.path.join(self.data_source_dir, f"{self.system}.tar")

            if os.path.isfile(archive_path):
                logger.info(f"Archive already exists at: {archive_path}")
            else:
                drive_url = _URLS[self.system]
                logger.info(f"Downloading {self.system} dataset from: {drive_url}")
                try:
                    import gdown
                except ImportError as e:
                    raise ImportError(
                        "Downloading the SAMD23 dataset requires the optional 'gdown' package. "
                        "Please install it with `pip install gdown` and try again."
                    ) from e

                # The target directory may not exist yet on a fresh setup.
                # An empty `data_source_dir` means the current directory, which
                # `os.makedirs` cannot be called on.
                if self.data_source_dir:
                    os.makedirs(self.data_source_dir, exist_ok=True)

                gdown.download(drive_url, archive_path, quiet=False)

                # Fail here with a clear message rather than inside the tar
                # extraction if the download did not actually produce a file.
                if not os.path.isfile(archive_path):
                    raise RuntimeError(
                        f"Download of the {self.system} dataset from {drive_url} "
                        f"did not produce an archive at: {archive_path}"
                    )

            extract_tar(path=archive_path, folder=self.data_source_dir, mode="r:")

            # Verify the archive had the expected layout instead of letting the
            # dataset loaders fail later on a missing file.
            missing_files = [f for f in required_files if not os.path.isfile(f)]
            if missing_files:
                raise FileNotFoundError(
                    f"Extracted `{archive_path}` into `{self.data_source_dir}`, "
                    f"but the following required files are still missing: {missing_files}"
                )

        # Log OOD file status after extraction
        if self.include_ood:
            if os.path.isfile(self.ood_path):
                logger.info(f"Confirmed OOD test set exists at: {self.ood_path}")
            else:
                logger.warning(
                    f"OOD test set requested but not found at: {self.ood_path}"
                )