# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from typing import Union, Optional, List
import torch
from e3nn.o3._irreps import Irreps
from e3nn.util.jit import compile_mode
from nequip.data import AtomicDataDict
from nequip.data.misc import chemical_symbols_to_atomic_numbers_dict
from ._graph_mixin import GraphModuleMixin
from .utils import scatter, with_edge_vectors_
from nequip.utils.compile import conditional_torchscript_jit
from .embedding.cutoffs import PolynomialCutoff
class _LJParam(torch.nn.Module):
    """Select a pair parameter for each index pair and clamp it to be non-negative.

    A 2-D ``param`` is read as a table in which only the upper triangle
    (diagonal included) is meaningful; it is mirrored so that ``(i, j)`` and
    ``(j, i)`` resolve to the same entry.  A ``param`` of any other rank skips
    the lookup and is only clamped.
    """

    def __init__(self):
        super().__init__()

    def forward(self, param, index1, index2):
        """Return the clamped parameter for every ``(index1[k], index2[k])`` pair.

        Args:
            param: parameter tensor; looked up per pair only when it is 2-D
            index1: row indices, in a form accepted by ``torch.index_select``
            index2: column indices, matching ``index1``

        Returns:
            ``relu`` of the selected entries (2-D ``param``) or of ``param`` itself.
        """
        if param.ndim == 2:
            # the row count doubles as the flat stride, i.e. the table is taken to be square
            num_rows = param.shape[0]
            upper = torch.triu(param)
            strict_upper = torch.triu(param, diagonal=1)
            # rebuild the full symmetric table from the upper triangle only
            mirrored = upper + strict_upper.transpose(-1, -2)
            flat_index = index1 * num_rows + index2
            param = torch.index_select(mirrored.view(-1), 0, flat_index)
        # negative values are clamped to zero  # TODO: better way?
        return torch.relu(param)
@compile_mode("script")
class LennardJones(GraphModuleMixin, torch.nn.Module):
    """Lennard-Jones and related pair potentials.

    With ``x = (sigma / (r - delta)) ** exponent`` and ``r`` the edge length, each
    directed edge contributes (before the polynomial cutoff is applied):

    * ``lj_style="lj"``: ``2 * epsilon * (x**2 - x)``, i.e. half of the pair
      energy ``4 * epsilon * (x**2 - x)``, so that edges ij and ji together give
      the full pair energy;
    * ``lj_style="lj_repulsive_only"``: the same, shifted up by ``epsilon / 2`` so
      the per-edge minimum sits at zero, and set to zero for
      ``r >= sigma * 2 ** (1 / exponent) + delta``;
    * ``lj_style="repulsive"``: ``0.5 * epsilon * (sigma * (r - delta)) ** -exponent``.

    The per-edge energies are summed onto the edge's center atom and added to any
    per-atom energy already present in the data.

    Args:
        type_names (List[str]): list of type names known by the model
        lj_sigma: scalar, or symmetric ``(num_types, num_types)`` matrix
        lj_delta: scalar, or symmetric ``(num_types, num_types)`` matrix (default ``0``)
        lj_epsilon: scalar, or symmetric ``(num_types, num_types)`` matrix
        lj_sigma_trainable (bool): whether ``sigma`` requires gradients
        lj_delta_trainable (bool): whether ``delta`` requires gradients
        lj_epsilon_trainable (bool): whether ``epsilon`` requires gradients
        lj_exponent (float, optional): exponent of the attractive term (``6.0`` if ``None``)
        lj_per_type (bool): if ``True``, scalars are expanded to one value per pair
            of types; if ``False``, a scalar is kept as a single shared value and
            matrices are rejected
        lj_style (str): one of ``"lj"``, ``"lj_repulsive_only"``, ``"repulsive"``
        polynomial_cutoff_p (float): exponent used for the polynomial cutoff (default ``6``)
        per_atom_energy_field (str): field the per-atom energies are accumulated into
        irreps_in: input irreps forwarded to ``_init_irreps``
    """

    lj_style: str
    exponent: float

    def __init__(
        self,
        type_names: List[str],
        lj_sigma: Union[torch.Tensor, float],
        lj_delta: Union[torch.Tensor, float] = 0,
        lj_epsilon: Optional[Union[torch.Tensor, float]] = None,
        lj_sigma_trainable: bool = False,
        lj_delta_trainable: bool = False,
        lj_epsilon_trainable: bool = False,
        lj_exponent: Optional[float] = None,
        lj_per_type: bool = True,
        lj_style: str = "lj",
        polynomial_cutoff_p: float = 6.0,
        per_atom_energy_field: str = AtomicDataDict.PER_ATOM_ENERGY_KEY,
        irreps_in=None,
    ) -> None:
        super().__init__()
        num_types = len(type_names)
        self.per_atom_energy_field = per_atom_energy_field

        # === irreps registration ===
        self._init_irreps(
            irreps_in=irreps_in,
            required_irreps_in=[AtomicDataDict.NORM_LENGTH_KEY],
            irreps_out={self.per_atom_energy_field: "0e"},
        )
        if self.per_atom_energy_field in self.irreps_in:
            # energies are accumulated into an existing field, so keep its (scalar) irreps
            energy_irreps = Irreps(self.irreps_in[self.per_atom_energy_field])
            assert all(ir.l == 0 for _, ir in energy_irreps), (
                f"{self.per_atom_energy_field} must be scalar irreps, found {energy_irreps}"
            )
            self.irreps_out[self.per_atom_energy_field] = energy_irreps

        assert lj_style in ("lj", "lj_repulsive_only", "repulsive"), (
            f"unknown lj_style {lj_style!r}"
        )
        self.lj_style = lj_style

        # registration order (epsilon, sigma, delta) is kept stable
        for name, value, trainable in (
            ("epsilon", lj_epsilon, lj_epsilon_trainable),
            ("sigma", lj_sigma, lj_sigma_trainable),
            ("delta", lj_delta, lj_delta_trainable),
        ):
            if value is None:
                # empty placeholder so the attribute exists as a Tensor for TorchScript
                # NOTE(review): every lj_style multiplies by epsilon in `forward`, so a
                # missing epsilon leaves an empty tensor there -- confirm whether
                # `lj_epsilon=None` should be rejected up front.
                self.register_buffer(name, torch.Tensor())
                continue
            value = torch.as_tensor(value, dtype=torch.get_default_dtype())
            if value.ndim == 0:
                if lj_per_type:
                    # one scalar for all pair types
                    value = (
                        torch.ones(
                            num_types, num_types, device=value.device, dtype=value.dtype
                        )
                        * value
                    )
                # otherwise keep a single scalar shared by all pairs; `_LJParam`
                # passes non-2D parameters through without the per-pair lookup
            elif value.ndim == 2:
                assert lj_per_type, (
                    f"lj_{name} was given per pair of types but lj_per_type=False"
                )
                # one per pair type
                assert value.shape == (num_types, num_types), (
                    f"lj_{name} must have shape ({num_types}, {num_types}), "
                    f"found {tuple(value.shape)}"
                )
                # per-species square, make sure symmetric
                assert torch.equal(value, value.T), f"lj_{name} must be symmetric"
                # only the upper triangle is stored; `_LJParam` mirrors it back
                value = torch.triu(value)
            else:
                raise ValueError(
                    f"lj_{name} must be a scalar or a (num_types, num_types) matrix, "
                    f"found a tensor with {value.ndim} dimensions"
                )
            setattr(self, name, torch.nn.Parameter(value, requires_grad=trainable))

        if lj_exponent is None:
            lj_exponent = 6.0
        self.exponent = lj_exponent

        self.cutoff = conditional_torchscript_jit(PolynomialCutoff(polynomial_cutoff_p))
        self.model_dtype = torch.get_default_dtype()
        self._param = conditional_torchscript_jit(_LJParam())

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Add the pair energies to ``data[self.per_atom_energy_field]`` and return ``data``."""
        data = with_edge_vectors_(data, with_lengths=True)
        edge_index = data[AtomicDataDict.EDGE_INDEX_KEY]
        edge_center = edge_index[0]
        atom_types = data[AtomicDataDict.ATOM_TYPE_KEY]
        # The per-edge parameters below are 1-D (one entry per edge), so the edge
        # lengths are flattened to match.  Combining them with an unsqueezed
        # length tensor would broadcast to a (num_edges, num_edges) matrix.
        edge_len = data[AtomicDataDict.EDGE_LENGTH_KEY].reshape(-1)
        edge_types = torch.index_select(atom_types, 0, edge_index.reshape(-1)).view(
            2, -1
        )
        index1 = edge_types[0]
        index2 = edge_types[1]

        sigma = self._param(self.sigma, index1, index2)
        delta = self._param(self.delta, index1, index2)
        epsilon = self._param(self.epsilon, index1, index2)

        if self.lj_style == "repulsive":
            # 0.5 to assign half and half the energy to each side of the interaction
            # NOTE(review): sigma multiplies (r - delta) here, whereas the other
            # styles use sigma / (r - delta) -- confirm this is intended.
            lj_eng = 0.5 * epsilon * ((sigma * (edge_len - delta)) ** -self.exponent)
        else:
            lj_eng = (sigma / (edge_len - delta)) ** self.exponent
            lj_eng = torch.neg(lj_eng)
            lj_eng = lj_eng + lj_eng.square()
            # 2.0 because we do the slightly symmetric thing and let
            # ij and ji each contribute half of the LJ energy of the pair
            # this avoids indexing out certain edges in the general case where
            # the edges are not ordered.
            lj_eng = (2.0 * epsilon) * lj_eng

            if self.lj_style == "lj_repulsive_only":
                # The per-edge energy 2 * eps * (x**2 - x) has its minimum at
                # x = 1/2 with value -eps / 2, so shifting by eps / 2 puts the
                # minimum of each directed edge (and hence of the pair) at zero.
                lj_eng = lj_eng + 0.5 * epsilon
                # x = 1/2 corresponds to r = sigma * 2 ** (1 / exponent) + delta;
                # the energy is continuous there and everything beyond is masked out
                # TODO: this is probably broken with NaNs at delta
                r_min = sigma * 2 ** (1.0 / self.exponent) + delta
                lj_eng = lj_eng * (edge_len < r_min)

        # back to one trailing feature dimension, like the cutoff values
        lj_eng = lj_eng.unsqueeze(-1)

        # apply polynomial cutoff from this module's own normalized edge lengths
        lj_edge_cutoff = self.cutoff(data[AtomicDataDict.NORM_LENGTH_KEY]).to(
            self.model_dtype
        )
        lj_eng = lj_eng.to(self.model_dtype) * lj_edge_cutoff

        # sum edge LJ energies onto atoms
        atomic_eng = scatter(
            lj_eng,
            edge_center,
            dim=0,
            dim_size=AtomicDataDict.num_nodes(data),
        )
        if self.per_atom_energy_field in data:
            atomic_eng = atomic_eng + data[self.per_atom_energy_field]
        data[self.per_atom_energy_field] = atomic_eng
        return data

    def __repr__(self) -> str:
        def _f(e):
            e = e.data
            if e.ndim == 0:
                return f"{e:.6f}"
            # matrices, and the empty placeholder of an unset parameter
            return f"{e}"

        return f"PairPotential(lj_style={self.lj_style} | σ={_f(self.sigma)} δ={_f(self.delta)} ε={_f(self.epsilon)} exp={self.exponent:.1f})"
@compile_mode("script")
class SimpleLennardJones(GraphModuleMixin, torch.nn.Module):
    """Simple Lennard-Jones.

    Uses one ``sigma`` and one ``epsilon`` for all pairs.  Each edge of length
    ``r`` contributes ``2 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)``
    times the polynomial cutoff value; the edge energies are summed onto the
    atom given by the first row of the edge index.

    Args:
        lj_sigma (float): the ``sigma`` entering ``sigma / r``
        lj_epsilon (float): the ``epsilon`` energy prefactor
        polynomial_cutoff_p (float): exponent used for the polynomial cutoff (default ``6``)
        irreps_in: input irreps forwarded to ``_init_irreps``
    """

    lj_sigma: float
    lj_epsilon: float

    def __init__(
        self,
        lj_sigma: float,
        lj_epsilon: float,
        polynomial_cutoff_p: float = 6.0,
        irreps_in=None,
    ) -> None:
        super().__init__()
        self._init_irreps(
            irreps_in=irreps_in,
            required_irreps_in=[AtomicDataDict.NORM_LENGTH_KEY],
            irreps_out={AtomicDataDict.PER_ATOM_ENERGY_KEY: "0e"},
        )
        self.lj_sigma = lj_sigma
        self.lj_epsilon = lj_epsilon
        self.cutoff = conditional_torchscript_jit(PolynomialCutoff(polynomial_cutoff_p))
        self.model_dtype = torch.get_default_dtype()

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Accumulate the per-atom LJ energies into the per-atom energy field of ``data``."""
        data = with_edge_vectors_(data, with_lengths=True)
        centers = data[AtomicDataDict.EDGE_INDEX_KEY][0]
        distances = data[AtomicDataDict.EDGE_LENGTH_KEY].unsqueeze(-1)

        # (sigma / r) ** 6; its square is the 12th-power repulsive term
        sixth_power = torch.pow(self.lj_sigma / distances, 6.0)
        well = sixth_power.square() - sixth_power
        prefactor = 2 * self.lj_epsilon
        edge_energy = prefactor * well

        # apply polynomial cutoff from this module's own normalized edge lengths
        envelope = self.cutoff(data[AtomicDataDict.NORM_LENGTH_KEY]).to(
            self.model_dtype
        )
        edge_energy = edge_energy.to(self.model_dtype) * envelope

        # sum edge LJ energies onto atoms
        per_atom = scatter(
            edge_energy,
            centers,
            dim=0,
            dim_size=AtomicDataDict.num_nodes(data),
        )
        # add to, rather than overwrite, an energy that is already there
        if AtomicDataDict.PER_ATOM_ENERGY_KEY in data:
            per_atom = per_atom + data[AtomicDataDict.PER_ATOM_ENERGY_KEY]
        data[AtomicDataDict.PER_ATOM_ENERGY_KEY] = per_atom
        return data
class _ZBL(torch.nn.Module):
    """Per-edge ZBL screened nuclear repulsion energy."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        Z: torch.Tensor,
        r: torch.Tensor,
        atom_types: torch.Tensor,
        edge_index: torch.Tensor,
        qqr2exesquare: float,
    ) -> torch.Tensor:
        """Evaluate ``qqr2exesquare * Zi * Zj / r * psi(x)`` for every edge.

        Args:
            Z: nuclear charge for each atom type, indexed by type
            r: length of each edge
            atom_types: type index of each atom
            edge_index: atom indices of both ends of each edge; flattened and
                regrouped as ``(2, num_edges)``
            qqr2exesquare: energy prefactor multiplying ``Zi * Zj / r``

        Returns:
            One energy per edge.
        """
        # from LAMMPS pair_zbl_const.h
        pzbl: float = 0.23
        a0: float = 0.46850
        c1: float = 0.02817
        c2: float = 0.28022
        c3: float = 0.50986
        c4: float = 0.18175
        d1: float = -0.20162
        d2: float = -0.40290
        d3: float = -0.94229
        d4: float = -3.19980

        # type -> charge lookup per atom: (num_atoms,) -> (num_atoms, 1)
        charge_per_atom = torch.nn.functional.embedding(
            atom_types.view(-1), Z.view(-1, 1)
        )
        # atom -> charge lookup for both ends of every edge: (2, num_edges)
        charge_per_end = torch.nn.functional.embedding(
            edge_index.view(-1), charge_per_atom
        ).view(2, -1)
        Zi = charge_per_end[0]
        Zj = charge_per_end[1]
        del charge_per_atom, charge_per_end

        # distance divided by the charge-dependent screening length
        x = ((Zi.pow(pzbl) + Zj.pow(pzbl)) * r) / a0
        # screening function: sum of four decaying exponentials
        psi = (
            c1 * torch.exp(d1 * x)
            + c2 * torch.exp(d2 * x)
            + c3 * torch.exp(d3 * x)
            + c4 * torch.exp(d4 * x)
        )
        coulomb = (Zi * Zj) / r
        return qqr2exesquare * coulomb * psi
@compile_mode("script")
class ZBL(GraphModuleMixin, torch.nn.Module):
    """`ZBL <https://docs.lammps.org/pair_zbl.html>`_ pair potential energy term.

    Useful as a prior for core repulsion to mitigate molecular dynamics failure modes associated with atoms getting too close.

    Args:
        type_names (List[str]): list of type names known by the model, ``[atom1, atom2, atom3]``
        chemical_species (List[str]): list of chemical symbols, e.g. ``[C, H, O]``; must have one entry per type, as the atomic numbers are looked up by atom type index
        units (str): `LAMMPS units <https://docs.lammps.org/units.html>`_ that the data is in; ``metal`` and ``real`` are presently supported -- raise a GitHub issue if more is desired
        polynomial_cutoff_p (float): exponent used for the polynomial cutoff (default ``6``)
        per_atom_energy_field (str): field the per-atom ZBL energies are written to; if the field is already present in the data, the ZBL energies are added to it
        irreps_in: input irreps forwarded to ``_init_irreps``
    """

    def __init__(
        self,
        type_names: List[str],
        chemical_species: List[str],
        units: str,
        polynomial_cutoff_p: float = 6.0,
        per_atom_energy_field: str = AtomicDataDict.PER_ATOM_ENERGY_KEY,
        irreps_in=None,
    ):
        super().__init__()
        num_types = len(type_names)
        self.per_atom_energy_field = per_atom_energy_field

        # === irreps registration ===
        self._init_irreps(
            irreps_in=irreps_in,
            required_irreps_in=[AtomicDataDict.NORM_LENGTH_KEY],
            irreps_out={self.per_atom_energy_field: "0e"},
        )
        if self.per_atom_energy_field in self.irreps_in:
            # energies are accumulated into an existing field, so keep its (scalar) irreps
            energy_irreps = Irreps(self.irreps_in[self.per_atom_energy_field])
            assert all(ir.l == 0 for _, ir in energy_irreps), (
                f"{self.per_atom_energy_field} must be scalar irreps, found {energy_irreps}"
            )
            self.irreps_out[self.per_atom_energy_field] = energy_irreps

        # one chemical symbol per atom type, in type order
        assert len(chemical_species) == num_types
        atomic_numbers: List[int] = [
            chemical_symbols_to_atomic_numbers_dict[symbol]
            for symbol in chemical_species
        ]
        lowest_atomic_number = min(atomic_numbers)
        if lowest_atomic_number < 1:
            raise ValueError(
                f"Your chemical symbols don't seem valid (minimum atomic number is {lowest_atomic_number} < 1); did you try to use fake chemical symbols for arbitrary atom types?"
            )

        # LAMMPS note on units:
        # > The numerical values of the exponential decay constants in the
        # > screening function depend on the unit of distance. In the above
        # > equation they are given for units of Angstroms. LAMMPS will
        # > automatically convert these values to the distance unit of the
        # > specified LAMMPS units setting. The values of Z should always be
        # > given as multiples of a proton’s charge, e.g. 29.0 for copper.
        # So, we store the atomic numbers directly.
        self.register_buffer(
            "atomic_numbers",
            torch.as_tensor(atomic_numbers, dtype=torch.get_default_dtype()),
        )
        # And we have to convert our value of the prefactor into the model's physical units
        # Here, prefactor is (electron charge)^2 / (4 * pi * electrical permittivity of vacuum)
        # we have a value for that in eV and Angstrom
        # See https://github.com/lammps/lammps/blob/c415385ab4b0983fa1c72f9e92a09a8ed7eebe4a/src/update.cpp#L187 for values from LAMMPS
        # LAMMPS uses `force->qqr2e * force->qelectron * force->qelectron`
        # Make it a buffer so rescalings are persistent, it still acts as a scalar Tensor
        self.register_buffer(
            "_qqr2exesquare",
            torch.as_tensor(
                {"metal": 14.399645 * (1.0) ** 2, "real": 332.06371 * (1.0) ** 2}[
                    units
                ],
                dtype=torch.float64,
            )
            * 0.5,  # Put half the energy on each of ij, ji
        )
        self.cutoff = conditional_torchscript_jit(PolynomialCutoff(polynomial_cutoff_p))
        self.model_dtype = torch.get_default_dtype()
        self._zbl = conditional_torchscript_jit(_ZBL())

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Add the per-atom ZBL energies to ``data[self.per_atom_energy_field]`` and return ``data``."""
        data = with_edge_vectors_(data, with_lengths=True)
        edge_center = data[AtomicDataDict.EDGE_INDEX_KEY][0]

        # account for possibility of reduced num nodes in atomic energy in a local-ghost atom context
        if self.per_atom_energy_field in data:
            num_nodes = data[self.per_atom_energy_field].size(0)
        else:
            num_nodes = AtomicDataDict.num_nodes(data)

        # one energy per edge, then a trailing feature dimension for the cutoff product
        zbl_edge_eng = self._zbl(
            Z=self.atomic_numbers,
            r=data[AtomicDataDict.EDGE_LENGTH_KEY].view(-1),
            atom_types=data[AtomicDataDict.ATOM_TYPE_KEY],
            edge_index=data[AtomicDataDict.EDGE_INDEX_KEY],
            qqr2exesquare=self._qqr2exesquare,
        ).unsqueeze(-1)

        # apply cutoff
        zbl_edge_cutoff = self.cutoff(data[AtomicDataDict.NORM_LENGTH_KEY]).to(
            self.model_dtype
        )
        zbl_edge_eng = zbl_edge_eng * zbl_edge_cutoff

        # sum edge energies onto each edge's center atom
        atomic_eng = scatter(
            zbl_edge_eng,
            edge_center,
            dim=0,
            dim_size=num_nodes,
        )
        if self.per_atom_energy_field in data:
            atomic_eng = atomic_eng + data[self.per_atom_energy_field]
        data[self.per_atom_energy_field] = atomic_eng
        return data
# `__all__` entries must be the *names* of the public objects (strings);
# listing the classes themselves makes `from <module> import *` raise TypeError.
__all__ = ["LennardJones", "ZBL"]