# Source code for nequip.data.dataset.ase_dataset

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
import ase
import ase.io

from .. import AtomicDataDict
from ..ase import from_ase
from .base_datasets import AtomicDataset

from typing import Union, Dict, List, Optional, Callable, Any


class ASEDataset(AtomicDataset):
    """:class:`~nequip.data.dataset.AtomicDataset` for `ASE <https://wiki.fysik.dtu.dk/ase/ase/io/io.html>`_-readable file formats.

    The whole file is read eagerly at construction time and every frame is
    converted with ``from_ase`` and kept in memory in ``self.data_list``.

    Args:
        file_path (str): path to ASE-readable file
        transforms (List[Callable]): list of data transforms
        ase_args (Dict[str, Any]): arguments for :func:`ase.io.iread`; must not
            contain ``index`` or ``filename`` (``filename`` is always set from
            ``file_path``)
        include_keys (List[str]): the keys that need to be parsed into the dataset
            in addition to standard keys (see :doc:`../../../api/data_fields`).
            The data stored in :attr:`ase.atoms.Atoms.arrays` has the lowest
            priority, and it will be overridden by data in
            :attr:`ase.atoms.Atoms.info` and :attr:`ase.atoms.Atoms.calc.results`
        exclude_keys (List[str]): list of keys that may be present in the
            ASE-readable file but the user wishes to exclude
        key_mapping (Dict[str, str]): mapping of ``ase`` keys to
            ``AtomicDataDict`` keys

    Raises:
        ValueError: if ``ase_args`` contains ``index`` or ``filename``.
    """

    def __init__(
        self,
        file_path: str,
        transforms: Optional[List[Callable]] = None,
        ase_args: Optional[Dict[str, Any]] = None,
        include_keys: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        key_mapping: Optional[Dict[str, str]] = None,
    ):
        # ``None`` defaults replace the former shared mutable ``[]`` / ``{}``
        # defaults; each instance now receives its own fresh empty container,
        # which is what the old defaults provided on the default path.
        super().__init__(transforms=[] if transforms is None else transforms)
        self.file_path = file_path

        # process ase_args (copied so the caller's dict is never mutated)
        self.ase_args: Dict[str, Any] = dict(ase_args) if ase_args else {}
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        for reserved in ("index", "filename"):
            if reserved in self.ase_args:
                raise ValueError(
                    f"`{reserved}` must not be provided in `ase_args` of `ASEDataset`"
                )
        self.ase_args["filename"] = self.file_path

        include_keys = [] if include_keys is None else include_keys
        exclude_keys = [] if exclude_keys is None else exclude_keys
        key_mapping = {} if key_mapping is None else key_mapping

        # read file and construct list of AtomicDataDicts
        self.data_list: List[AtomicDataDict.Type] = [
            from_ase(
                atoms=atoms,
                key_mapping=key_mapping,
                include_keys=include_keys,
                exclude_keys=exclude_keys,
            )
            for atoms in ase.io.iread(**self.ase_args, parallel=False)
        ]

    def __len__(self) -> int:
        """Return the number of frames read from the file."""
        return len(self.data_list)

    def _get_data_list(
        self,
        indices: Union[List[int], torch.Tensor, slice],
    ) -> List[AtomicDataDict.Type]:
        """Return the stored frames selected by ``indices``.

        A ``slice`` is applied directly to the underlying list; any other
        value is iterated and each element is used as a list index.
        """
        if isinstance(indices, slice):
            return self.data_list[indices]
        return [self.data_list[index] for index in indices]