# Source code for nequip.integrations.ase
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from typing import Callable, Dict, List, Optional, Union

import torch
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress

from nequip.data import AtomicDataDict, from_ase

from .mixins import _IntegrationLoaderMixin
# [docs]
class NequIPCalculator(_IntegrationLoaderMixin, Calculator):
    """NequIP framework ASE Calculator.

    This ASE Calculator is compatible with models from the NequIP framework, including NequIP and Allegro models.

    The recommended way to use this Calculator is with a compiled model, i.e. ``nequip-compile`` the model and load it into the Calculator with ``NequIPCalculator.from_compiled_model(...)``. If one uses ``--mode aotinductor`` during ``nequip-compile``, it is important to use the flag ``--target ase`` for the compiled model file to work with this ASE Calculator.

    .. warning::

        If you are running MD with custom species, please make sure to set the correct masses for ASE.

    Args:
        model: a model in the NequIP framework; it must already be in eval mode (call ``model.eval()`` first)
        device (str/torch.device): device for model to evaluate on, e.g. ``cpu`` or ``cuda``
        energy_units_to_eV (float): energy conversion factor (default ``1.0``)
        length_units_to_A (float): length units conversion factor (default ``1.0``)
        transforms (List[Callable]): data transforms applied, in order, to the data built from each ``Atoms``
            before model evaluation (default: no transforms)
    """

    implemented_properties = ["energy", "energies", "forces", "stress", "free_energy"]

    @classmethod
    def _get_aoti_compile_target(cls) -> Dict:
        """Return the ``COMPILE_TARGET_DICT`` entry for the AOTInductor ASE target."""
        from nequip.scripts._compile_utils import COMPILE_TARGET_DICT, AOTI_ASE_TARGET

        return COMPILE_TARGET_DICT[AOTI_ASE_TARGET]

    def __init__(
        self,
        model: torch.nn.Module,
        device: Union[str, torch.device],
        energy_units_to_eV: float = 1.0,
        length_units_to_A: float = 1.0,
        transforms: Optional[List[Callable]] = None,
        **kwargs,
    ):
        Calculator.__init__(self, **kwargs)
        self.results = {}

        # === handle model ===
        assert not model.training, (
            "make sure to call .eval() on model before building NequIPCalculator"
        )

        # === handle device ===
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        self.model = model.to(self.device)

        # === data details ===
        self.energy_units_to_eV = energy_units_to_eV
        self.length_units_to_A = length_units_to_A
        # build a fresh list: avoids a shared mutable default and never aliases the caller's list
        transforms = [] if transforms is None else list(transforms)
        self.transforms = transforms

        # logic to handle when the CPU -> GPU transfer happens
        # i.e. if neighborlist is CPU-only, we should just do the data transforms before moving data to cuda for model execution (if self.device is cuda)
        # this logic is just to avoid unnecessary CPU -> GPU or GPU -> CPU transfers wherever possible
        from nequip.data.transforms import NeighborListTransform
        from nequip.data._nl import NEIGHBORLIST_BACKEND_OPTIONS

        neighborlist_transforms = [
            t for t in transforms if isinstance(t, NeighborListTransform)
        ]
        num_neighborlist_transforms = len(neighborlist_transforms)
        assert num_neighborlist_transforms <= 1, (
            "Expected at most one NeighborListTransform in calculator transforms."
        )
        if num_neighborlist_transforms == 1:
            # if the neighborlist supports cuda, and we're running the model on cuda,
            # we should just move the data to cuda first before the transforms so that neighborlist happens on cuda
            nl_backend = neighborlist_transforms[0].backend
            assert nl_backend in NEIGHBORLIST_BACKEND_OPTIONS
            self._move_to_device_before_transforms = (
                self.device.type == "cuda"
                and NEIGHBORLIST_BACKEND_OPTIONS[nl_backend].supports_cuda
            )
        else:
            # no neighborlist -> move to device first if device is cuda
            self._move_to_device_before_transforms = self.device.type == "cuda"

        # move transforms to the device where they execute (each transform must support `.to(device)`)
        transform_device = (
            self.device
            if self._move_to_device_before_transforms
            else torch.device("cpu")
        )
        self.transforms = [t.to(transform_device) for t in transforms]

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Evaluate the model on ``atoms`` and store the outputs in ``self.results``.

        Every output the model produces is stored, independent of ``properties``:
        ``energy``/``free_energy``, ``energies``, ``forces`` and ``stress`` (Voigt order),
        all converted to eV / Angstrom units. Outputs the model did not compute are
        left out of ``self.results``.

        Args:
            atoms: the ``ase.Atoms`` to evaluate; if ``None``, the Atoms stored on the
                calculator by a previous call (``self.atoms``) are used
            properties: accepted for ASE API compatibility; not used to select outputs
            system_changes: accepted for ASE API compatibility; not used here
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
        if atoms is None:
            # standard ASE fallback: reuse the Atoms attached by an earlier call
            atoms = self.atoms

        # clear before evaluating so a failed model call cannot leave results
        # from a previous configuration paired with the new atoms
        self.results = {}

        data = self.atoms_to_data(atoms)
        out = self.call_model(data)

        # only store results the model actually computed to avoid KeyErrors
        if AtomicDataDict.TOTAL_ENERGY_KEY in out:
            self.results["energy"] = self.energy_units_to_eV * (
                out[AtomicDataDict.TOTAL_ENERGY_KEY]
                .detach()
                .cpu()
                .numpy()
                .reshape(tuple())
            )
            # "force consistent" energy
            self.results["free_energy"] = self.results["energy"]
        if AtomicDataDict.PER_ATOM_ENERGY_KEY in out:
            self.results["energies"] = self.energy_units_to_eV * (
                out[AtomicDataDict.PER_ATOM_ENERGY_KEY]
                .detach()
                .squeeze(-1)
                .cpu()
                .numpy()
            )
        if AtomicDataDict.FORCE_KEY in out:
            # force has units eng / len:
            self.results["forces"] = (
                self.energy_units_to_eV / self.length_units_to_A
            ) * out[AtomicDataDict.FORCE_KEY].detach().cpu().numpy()
        if AtomicDataDict.STRESS_KEY in out:
            stress = out[AtomicDataDict.STRESS_KEY].detach().cpu().numpy()
            # stress has units eng / len^3
            stress = stress.reshape(3, 3) * (
                self.energy_units_to_eV / self.length_units_to_A**3
            )
            # ase wants voigt format
            self.results["stress"] = full_3x3_to_voigt_6_stress(stress)

        self.save_extra_outputs(out)

    def atoms_to_data(self, atoms: Atoms) -> AtomicDataDict.Type:
        """Convert ``atoms`` to an ``AtomicDataDict`` on ``self.device`` with all transforms applied.

        Data is moved to the device either before or after the transforms,
        as decided in ``__init__`` to minimise host/device transfers.
        """
        data = from_ase(atoms)
        if self._move_to_device_before_transforms:
            data = AtomicDataDict.to_(data, self.device)
        for t in self.transforms:
            data = t(data)
        if not self._move_to_device_before_transforms:
            data = AtomicDataDict.to_(data, self.device)
        return data

    def call_model(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Run the wrapped model on ``data`` and return its output dict."""
        return self.model(data)

    def save_extra_outputs(self, out: AtomicDataDict.Type):
        """Hook for subclasses to process additional model outputs; does nothing by default."""
        # subclasses can implement this method to process extra outputs without code duplication
        pass