Source code for nequip.data.dataset.npz_dataset

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import numpy as np
import torch

from .. import AtomicDataDict
from ..dict import from_dict
from .base_datasets import AtomicDataset

from typing import Union, Dict, List, Callable


class NPZDataset(AtomicDataset):
    """:class:`~nequip.data.dataset.AtomicDataset` that loads data from an NPZ file following `sGDML <https://www.sgdml.org/#datasets>`_ conventions.

    It is also compatible with other datasets such as rMD-17, with a change in
    ``key_mapping`` (the default ``key_mapping`` is set to be compatible with
    sGDML datasets).

    The ``NPZDataset`` does not keep the dataset in memory between requests:
    only the number of frames is stored on the instance, and arrays are read
    from the file each time frames are requested.

    Args:
        file_path (str): path to npz file
        transforms (List[Callable]): list of data transforms; ``None`` (the
            default) means no transforms
        key_mapping (Dict[str, str]): mapping of array names in the npz file to
            ``AtomicDataDict`` keys; ``None`` (the default) selects the sGDML
            mapping ``R`` -> positions, ``z`` -> atomic numbers,
            ``E`` -> total energy, ``F`` -> forces

    Raises:
        AssertionError: if no entry of ``key_mapping`` maps to the total
            energy key (the energy array is needed to count frames).
    """

    def __init__(
        self,
        file_path: str,
        transforms: Union[List[Callable], None] = None,
        key_mapping: Union[Dict[str, str], None] = None,
    ):
        # Build the defaults per instance: a mutable default in the signature
        # would be one shared object across every dataset ever constructed.
        if transforms is None:
            transforms = []
        if key_mapping is None:
            key_mapping = {
                "R": AtomicDataDict.POSITIONS_KEY,
                "z": AtomicDataDict.ATOMIC_NUMBERS_KEY,
                "E": AtomicDataDict.TOTAL_ENERGY_KEY,
                "F": AtomicDataDict.FORCE_KEY,
            }

        super().__init__(transforms=transforms)
        self.file_path = file_path
        self.key_mapping = key_mapping

        # use energy array to get num_frames (small to load into memory)
        # If several npz keys map to the total energy, the last one wins.
        E_key = None
        for k, v in key_mapping.items():
            if v == AtomicDataDict.TOTAL_ENERGY_KEY:
                E_key = k
        assert E_key is not None, (
            "No key corresponding to `total_energy` found in npz dataset"
        )
        with np.load(self.file_path, mmap_mode="r") as npz_data:
            self.num_frames = npz_data[E_key].shape[0]

    def __len__(self) -> int:
        """Return the number of frames, taken from the energy array's first axis."""
        return self.num_frames

    def _get_data_list(
        self,
        indices: Union[List[int], torch.Tensor, slice],
    ) -> List[AtomicDataDict.Type]:
        """Read the requested frames from the npz file.

        Args:
            indices: frame indices to load; a ``slice`` is resolved against
                ``self.num_frames``.

        Returns:
            One ``AtomicDataDict`` per requested index, in request order. The
            array mapped to the atomic-numbers key is attached whole to every
            frame; all other arrays are indexed along their first axis.
        """
        if isinstance(indices, slice):
            indices = list(range(*indices.indices(self.num_frames)))
        else:
            # Materialize so the indices can be walked once per mapped array.
            indices = list(indices)
        if not indices:
            return []

        frames = [{} for _ in indices]
        # NOTE(review): NumPy appears to ignore ``mmap_mode`` for .npz archives
        # and to re-read a member in full every time the NpzFile is indexed --
        # confirm against the NumPy version in use. Because of that, each
        # array is fetched once per call (keys in the outer loop) instead of
        # once per frame, and only one array is held at a time.
        with np.load(self.file_path, mmap_mode="r") as npz_data:
            for npz_key, data_key in self.key_mapping.items():
                array = npz_data[npz_key]
                # special case for now, may generalize if needed in the future
                if data_key == AtomicDataDict.ATOMIC_NUMBERS_KEY:
                    for frame in frames:
                        # Copy so frames do not alias one shared buffer.
                        frame[data_key] = array.copy()
                else:
                    for frame, idx in zip(frames, indices):
                        # Copy so the frame does not keep the full array alive
                        # and repeated indices stay independent.
                        frame[data_key] = array[idx].copy()

        return [from_dict(frame) for frame in frames]