Source code for nequip.data.dataset.test_data

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from typing import Union, Callable, List, Tuple
import copy

import numpy as np
import torch

import ase
import ase.build
from ase.calculators.emt import EMT

from .. import AtomicDataDict
from ..dict import from_dict
from .base_datasets import AtomicDataset
from .lmdb_dataset import NequIPLMDBDataset


[docs] class EMTTestDataset(AtomicDataset): """Test dataset with PBC, based on the toy EMT potential included in ASE. Randomly generates (in a reproducable manner) a basic bulk with added Gaussian noise around equilibrium positions. Uses orthorhombic cell construction for safer testing. In ASE units (eV, Å, eV/Å). Args: transforms (List[Callable]): list of data transforms supercell (Tuple[int, int, int]): supercell in each lattice vector direction sigma (float): standard deviation of Gaussian noise element (str): element supported by ASE's `EMT <https://wiki.fysik.dtu.dk/ase/ase/calculators/emt.html>`_ calculator (supported elements: ``Cu``, ``Pd``, ``Au``, ``Pt``, ``Al``, ``Ni``, ``Ag``) num_frames (int): number of structures to be generated in the dataset seed (int): seed for the random Gaussian noise """ def __init__( self, transforms: List[Callable] = [], supercell: Tuple[int, int, int] = (4, 4, 4), sigma: float = 0.1, element: str = "Cu", num_frames: int = 10, seed: int = 123456, ): super().__init__(transforms=transforms) assert element in ("Cu", "Pd", "Au", "Pt", "Al", "Ni", "Ag") self.element = element self.sigma = sigma self.supercell = tuple(supercell) self.num_frames = num_frames self.seed = seed # generate data # NOTE: orthorhombic cell is safer for tests, e.g. LAMMPS base_atoms = ase.build.bulk(self.element, "fcc", orthorhombic=True).repeat( self.supercell ) base_atoms.calc = EMT() orig_pos = copy.deepcopy(base_atoms.positions) rng = np.random.default_rng(self.seed) self.data_list = [] for idx in range(len(self)): base_atoms.positions[:] = orig_pos base_atoms.positions += rng.normal( loc=0.0, scale=self.sigma, size=base_atoms.positions.shape ) self.data_list.append( from_dict( { "pos": base_atoms.positions, "cell": np.array(base_atoms.get_cell()), "pbc": base_atoms.get_pbc(), "atomic_numbers": base_atoms.get_atomic_numbers(), "forces": base_atoms.get_forces(), "total_energy": base_atoms.get_potential_energy(), "stress": np.expand_dims(base_atoms.get_stress(voigt=False), 0), } ) ) def __len__(self) -> int: return self.num_frames def _get_data_list( self, indices: Union[List[int], torch.Tensor, slice], ) -> List[AtomicDataDict.Type]: if isinstance(indices, slice): return self.data_list[indices] else: return [self.data_list[index] for index in indices]
class LMDBTestDataset(NequIPLMDBDataset): """LMDB wrapper for the `EMTTestDataset`.""" def __init__( self, file_path: str, transforms: List[Callable] = [], supercell: Tuple[int, int, int] = (4, 4, 4), sigma: float = 0.1, element: str = "Cu", num_frames: int = 10, seed: int = 123456, ): # Generate random data test_ds = EMTTestDataset( transforms=[], # transforms are applied in the LMDBDataset supercell=supercell, sigma=sigma, element=element, num_frames=num_frames, seed=seed, ) # Save to LMDB NequIPLMDBDataset.save_from_iterator( file_path=file_path, iterator=test_ds, ) # Initialize LMDB dataset # Note: transforms are not applied here, as they are already applied in the EMTTestDataset super().__init__(file_path=file_path, transforms=transforms)