Source code for nequip.data.datamodule._ase_datamodule

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from ._base_datamodule import NequIPDataModule

from omegaconf import ListConfig, DictConfig, OmegaConf
from typing import Union, List, Callable, Optional, Dict


[docs] class ASEDataModule(NequIPDataModule): """LightningDataModule for `ASE <https://wiki.fysik.dtu.dk/ase/ase/io/io.html>`_-readable datasets. Interface similar to :class:`~nequip.data.datamodule.NequIPDataModule`, except that all the datasets are given in terms of paths to relevant ASE-readable files. Args: seed (int): data seed for reproducibility train_file_path (str/List[str]): path to training dataset file val_file_path (str/List[str]): path(s) to validation dataset file test_file_path (str/List[str]): path(s) to test dataset file predict_file_path (str/List[str]): path(s) to prediction dataset file split_dataset (Dict/List[Dict]): dictionary or list of dictionaries with a ``file_path`` key, which is the path to the ASE-readable dataset file and the keys ``train``, ``val``, ``test``, ``predict`` which represent the subsets to split the dataset into and are either ``int`` s that sum up to the size of ``dataset`` or ``float`` s that sum up to 1 (at least 2, but not necessarily all of ``train``, ``val``, ``test``, ``predict`` must be provided if this option is used) transforms (List[Callable]): list of data transforms ase_args (Dict[str, Any]): arguments for ``ase.io.iread`` (see `here <https://wiki.fysik.dtu.dk/ase/ase/io/io.html#ase.io.iread>`_) include_keys (List[str]): the keys that needs to be parsed in addition to forces and energy; the data stored in ``ase.atoms.Atoms.array`` has the lowest priority, and it will be overrided by data in ``ase.atoms.Atoms.info`` and ``ase.atoms.Atoms.calc.results`` exclude_keys (List[str]): list of keys that may be present in the ASE-readable file but the user wishes to exclude key_mapping (Dict[str, str]): mapping of ``ase`` keys to ``AtomicDataDict`` keys """ def __init__( self, seed: int, # file paths train_file_path: Optional[Union[str, List[str]]] = [], val_file_path: Optional[Union[str, List[str]]] = [], test_file_path: Optional[Union[str, List[str]]] = [], predict_file_path: Optional[Union[str, List[str]]] = [], split_dataset: Optional[Union[Dict, List[Dict]]] = [], # data transforms transforms: List[Callable] = [], # ase params ase_args: dict = {}, include_keys: Optional[List[str]] = [], exclude_keys: Optional[List[str]] = [], key_mapping: Optional[Dict[str, str]] = {}, **kwargs, ): # == first convert all dataset paths to lists if not already lists == dataset_paths = [] for paths in [ train_file_path, val_file_path, test_file_path, predict_file_path, split_dataset, ]: # convert to primitives as later logic is based on types if isinstance(paths, ListConfig) or isinstance(paths, DictConfig): paths = OmegaConf.to_container(paths, resolve=True) assert ( isinstance(paths, list) or isinstance(paths, str) or isinstance(paths, dict) ) if not isinstance(paths, list): # convert str -> List[str] dataset_paths.append([paths]) else: dataset_paths.append(paths) # == assemble config template == dataset_config_template = { "_target_": "nequip.data.dataset.ASEDataset", "transforms": transforms, "ase_args": ase_args, "include_keys": include_keys, "exclude_keys": exclude_keys, "key_mapping": key_mapping, } # == populate train, val, test predict, split datasets == dataset_configs = [[], [], [], []] for config, paths in zip(dataset_configs, dataset_paths[:-1]): for path in paths: dataset_config = dataset_config_template.copy() dataset_config.update({"file_path": path}) config.append(dataset_config) # == populate split dataset == split_config = [] for path_and_splits in dataset_paths[-1]: assert "file_path" in path_and_splits, ( "`file_path` key must be present in each dict of `split_dataset`" ) dataset_config = dataset_config_template.copy() file_path = path_and_splits.pop("file_path") dataset_config.update({"file_path": file_path}) path_and_splits.update( {"dataset": dataset_config} ) # now actually dataset_and_splits split_config.append(path_and_splits) super().__init__( seed=seed, train_dataset=dataset_configs[0], val_dataset=dataset_configs[1], test_dataset=dataset_configs[2], predict_dataset=dataset_configs[3], split_dataset=split_config, **kwargs, )