# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import math
from e3nn import o3
from nequip.data import AtomicDataDict
from nequip.nn import (
GraphModel,
SequentialGraphNetwork,
ScalarMLP,
PerTypeScaleShift,
ConvNetLayer,
ForceStressOutput,
ApplyFactor,
)
from nequip.nn.embedding import (
NodeTypeEmbed,
PolynomialCutoff,
EdgeLengthNormalizer,
BesselEdgeLengthEncoding,
SphericalHarmonicEdgeAttrs,
)
from .utils import model_builder
from .energy_modules import _append_energy_modules
import warnings
from typing import Sequence, Optional, List, Dict, Union, Callable
# Architecture hyperparameters for each named preset size, keyed by the
# upper-case preset name accepted by `PresetNequIPGNNModel`.
# Every `num_features` list has `l_max + 1` entries (one multiplicity per
# rotation order), which is the length `NequIPGNNModel` asserts.
_NEQUIP_GNN_PRESETS = dict(
    S=dict(num_layers=2, l_max=1, num_features=[128, 64]),
    M=dict(num_layers=4, l_max=2, num_features=[128, 64, 32]),
    L=dict(num_layers=6, l_max=3, num_features=[128, 64, 32, 32]),
    XL=dict(num_layers=6, l_max=4, num_features=[320, 96, 64, 32, 32]),
)

# Hyperparameters shared by every preset; a per-preset entry or an explicit
# keyword argument of the same name takes precedence over these.
_NEQUIP_GNN_STANDARD_PRESET = dict(
    parity=False,
    type_embed_num_features=32,
    radial_mlp_depth=1,
    radial_mlp_width=128,
)
def _format_nequip_gnn_preset_docstring() -> str:
    """Render the docstring text for ``PresetNequIPGNNModel``.

    The text is generated from ``_NEQUIP_GNN_STANDARD_PRESET`` and ``_NEQUIP_GNN_PRESETS``, i.e. the same tables the preset builder reads, so the documented defaults cannot drift from the applied ones.

    Returns:
        str: the docstring, listing the valid preset names, the override order, and the ``repr`` of every shared and per-preset default.
    """
    # The template body is indented by four spaces like an ordinary function
    # docstring. List items carry their own eight-space indent because only the
    # first line of a multi-line substitution would inherit the template's.
    shared_defaults = "\n".join(
        f"        - ``{key}``: ``{value!r}``"
        for key, value in _NEQUIP_GNN_STANDARD_PRESET.items()
    )
    preset_defaults = "\n".join(
        f"        - ``{preset}``: ``{defaults!r}``"
        for preset, defaults in _NEQUIP_GNN_PRESETS.items()
    )
    # hoisted out of the template to avoid nesting one f-string inside another
    preset_names = ", ".join(f"``{name}``" for name in _NEQUIP_GNN_PRESETS)
    # Blank lines separate the summary and each section so the text does not
    # collapse into a single paragraph when rendered.
    return f"""Build :func:`NequIPGNNModel` from a named architecture preset.

    This is a wrapper of :func:`NequIPGNNModel` that injects preset hyperparameters based on model sizes of the NequIP foundation potentials.

    All arguments are the same as :func:`NequIPGNNModel`, except this builder also requires ``preset`` and applies preset defaults before ``**kwargs``.
    For full argument documentation, see :func:`NequIPGNNModel`.

    Users can override the preset defaults by providing arguments for the fields to be overridden.

    Preset argument:
        preset (str): one of {preset_names}

    Override order:
        1. shared defaults
        2. per-preset defaults
        3. explicit ``**kwargs`` (highest priority)

    Shared defaults:
{shared_defaults}

    Per-preset defaults:
{preset_defaults}
    """
[docs]
@model_builder
def PresetNequIPGNNModel(
    preset: str,
    **kwargs,
) -> GraphModel:
    # NOTE: the user-facing docstring of this builder is generated from the
    # preset tables and assigned to `__doc__` at module level.

    # preset names are matched case-insensitively
    preset = preset.upper()
    assert preset in _NEQUIP_GNN_PRESETS, (
        f"`preset` must be one of {list(_NEQUIP_GNN_PRESETS.keys())}, but found `{preset}`"
    )

    # later updates win: shared defaults < per-preset defaults < explicit kwargs
    model_kwargs = dict(_NEQUIP_GNN_STANDARD_PRESET)
    model_kwargs.update(_NEQUIP_GNN_PRESETS[preset])
    model_kwargs.update(kwargs)
    return NequIPGNNModel(**model_kwargs)
[docs]
@model_builder
def NequIPGNNModel(
    num_layers: int = 4,
    l_max: int = 1,
    parity: bool = True,
    num_features: Union[int, List[int]] = 32,
    type_embed_num_features: Optional[int] = None,
    radial_mlp_depth: int = 1,
    radial_mlp_width: int = 128,
    **kwargs,
) -> GraphModel:
    """NequIP GNN model that can predict energies only or energies with forces/stresses.

    The uniform hyperparameters given here are expanded into the per-layer lists expected by ``FullNequIPGNNModel``, to which all remaining ``**kwargs`` are forwarded.

    Args:
        seed (int): seed for reproducibility
        model_dtype (str): ``float32`` or ``float64``
        r_max (float): cutoff radius
        per_edge_type_cutoff (Dict): one can optionally specify cutoffs for each edge type [must be smaller than ``r_max``] (default ``None``)
        type_names (Sequence[str]): list of atom type names
        num_layers (int): number of interaction blocks, we find 3-5 to work best (default ``4``)
        l_max (int): the maximum rotation order for the network's features, ``1`` is a good default, ``2`` is more accurate but slower (default ``1``)
        parity (bool): whether to include features with odd mirror parity -- often turning parity off gives equally good results but faster networks, so it's worth testing (default ``True``)
        num_features (int/List[int]): multiplicity of the features, smaller is faster (default ``32``); it is also possible to provide the multiplicity for each irrep, e.g. for ``l_max=2`` and ``parity=False``, ``num_features=[5, 2, 7]`` refers to ``5x0e``, ``2x1o`` and ``7x2e`` features
        type_embed_num_features (int): number of features for the type embedding layer; if not provided, defaults to ``num_features[0]`` (default ``None``)
        radial_mlp_depth (int): number of radial layers, usually 1-3 works best, smaller is faster (default ``1``)
        radial_mlp_width (int): number of hidden neurons in radial function, smaller is faster (default ``128``)
        readout_mlp_hidden_layers_depth (int): number of hidden layers in the readout MLP (default ``0``)
        readout_mlp_hidden_layers_width (int): width of hidden layers in the readout MLP (default 0e contribution of ``num_features``)
        readout_mlp_nonlinearity (str): ``silu``, ``mish``, ``gelu``, or ``None`` (default ``silu``)
        num_bessels (int): number of Bessel basis functions (default ``8``)
        bessel_trainable (bool): whether the Bessel roots are trainable (default ``False``)
        polynomial_cutoff_p (int): p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance (default ``6``)
        avg_num_neighbors (float/Dict[str, float]): used to normalize edge sums for better numerics (default ``None``)
        per_type_energy_scales (float/List[float]): per-atom energy scales, which could be derived from the force RMS of the data (default ``None``)
        per_type_energy_shifts (float/List[float]): per-atom energy shifts, which should generally be isolated atom reference energies or estimated from average per-atom energies of the data (default ``None``)
        per_type_energy_scales_trainable (bool): whether the per-atom energy scales are trainable (default ``False``)
        per_type_energy_shifts_trainable (bool): whether the per-atom energy shifts are trainable (default ``False``)
        pair_potential (torch.nn.Module): additional pair potential term, e.g. :class:`~nequip.nn.pair_potential.ZBL` (default ``None``)
        do_derivatives (bool): whether to compute forces and stresses via autograd (default ``True``)
    """
    # === sanity checks and warnings ===
    assert num_layers > 0, (
        f"at least one convnet layer required, but found `num_layers={num_layers}`"
    )

    # === spherical harmonics ===
    irreps_edge_sh = repr(o3.Irreps.spherical_harmonics(lmax=l_max))

    # === handle `num_features` ===
    # a single int is broadcast to one multiplicity per rotation order
    if isinstance(num_features, int):
        num_features = [num_features] * (l_max + 1)
    assert len(num_features) == l_max + 1, (
        f"`num_features` should be of length `l_max + 1` ({l_max + 1}), but found `num_features={num_features}` with {len(num_features)} entries."
    )

    # === type embedding ===
    if type_embed_num_features is None:
        type_embed_num_features = num_features[0]

    # === convnet ===
    # collect (multiplicity, (l, parity)) entries for the hidden features
    hidden_irreps = []
    for ell in range(l_max + 1):
        if parity:
            # both mirror parities for every rotation order
            parities = (1, -1)
        else:
            # p = 1 for even l, -1 for odd l, with parity = False
            parities = (1,) if ell % 2 == 0 else (-1,)
        for p in parities:
            hidden_irreps.append((num_features[ell], (ell, p)))
    feature_irreps_hidden = repr(o3.Irreps(hidden_irreps))

    # === post convnets ===
    # the last layer outputs even scalars only
    scalar_irreps_out = repr(o3.Irreps([(num_features[0], (0, 1))]))

    # apply the single set of parameters uniformly to every layer
    feature_irreps_hidden_list = [feature_irreps_hidden] * (num_layers - 1)
    feature_irreps_hidden_list.append(scalar_irreps_out)
    radial_mlp_depth_list = [radial_mlp_depth] * num_layers
    radial_mlp_width_list = [radial_mlp_width] * num_layers

    # === build model ===
    return FullNequIPGNNModel(
        irreps_edge_sh=irreps_edge_sh,
        type_embed_num_features=type_embed_num_features,
        feature_irreps_hidden=feature_irreps_hidden_list,
        radial_mlp_depth=radial_mlp_depth_list,
        radial_mlp_width=radial_mlp_width_list,
        **kwargs,
    )
# The preset builder's docstring is generated from the preset tables, so it is
# attached here, after `PresetNequIPGNNModel` has been defined.
PresetNequIPGNNModel.__doc__ = _format_nequip_gnn_preset_docstring()
@model_builder
def FullNequIPGNNModel(
    r_max: float,
    type_names: Sequence[str],
    # convnet params
    radial_mlp_depth: Sequence[int],
    radial_mlp_width: Sequence[int],
    feature_irreps_hidden: Sequence[Union[str, o3.Irreps]],
    # irreps and dims
    irreps_edge_sh: Union[int, str, o3.Irreps],
    type_embed_num_features: int,
    categorical_graph_field_embed: Optional[List[Dict[str, int]]] = None,
    # readout
    readout_mlp_hidden_layers_depth: int = 0,
    readout_mlp_hidden_layers_width: Optional[int] = None,
    readout_mlp_nonlinearity: Optional[str] = "silu",
    # edge length encoding
    per_edge_type_cutoff: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    num_bessels: int = 8,
    bessel_trainable: bool = False,
    polynomial_cutoff_p: int = 6,
    # edge sum normalization
    avg_num_neighbors: Optional[Union[float, Dict[str, float]]] = None,
    # per atom energy params
    per_type_energy_scales: Optional[Union[float, Sequence[float]]] = None,
    per_type_energy_shifts: Optional[Union[float, Sequence[float]]] = None,
    per_type_energy_scales_trainable: Optional[bool] = False,
    per_type_energy_shifts_trainable: Optional[bool] = False,
    pair_potential: Optional[Dict] = None,
    # derivatives
    do_derivatives: bool = True,
    # developmental params
    convnet_sc: bool = True,
    learnable_shift: bool = False,
    # == things that generally shouldn't be changed ==
    # convnet
    convnet_resnet: bool = False,
    convnet_nonlinearity_type: str = "gate",
    # NOTE: these dict defaults are shared between calls; this function only
    # forwards them to `ConvNetLayer` and never mutates them
    convnet_nonlinearity_scalars: Dict[str, str] = {"e": "silu", "o": "tanh"},
    convnet_nonlinearity_gates: Dict[str, str] = {"e": "silu", "o": "tanh"},
) -> GraphModel:
    """NequIP GNN model that predicts energies based on a more extensive set of arguments.

    The network is assembled as a ``SequentialGraphNetwork`` of, in order: node type embedding, spherical harmonic edge attributes, edge length normalization, Bessel edge length encoding (rescaled by ``2 * pi / r_max**2``), one ``ConvNetLayer`` per entry of the per-layer lists, a ``ScalarMLP`` per-atom energy readout and a ``PerTypeScaleShift``.
    The result is passed through ``_append_energy_modules`` and wrapped in ``ForceStressOutput``.

    Args:
        r_max (float): cutoff radius
        type_names (Sequence[str]): list of atom type names; every name must be alphanumeric
        radial_mlp_depth (Sequence[int]): ``radial_mlp_depth`` of each convnet layer; the number of entries defines the number of layers
        radial_mlp_width (Sequence[int]): ``radial_mlp_width`` of each convnet layer, same length as ``radial_mlp_depth``
        feature_irreps_hidden (Sequence[str/o3.Irreps]): hidden feature irreps of each convnet layer, same length as ``radial_mlp_depth``; the last entry must contain only scalars (``l = 0``)
        irreps_edge_sh (int/str/o3.Irreps): forwarded to ``SphericalHarmonicEdgeAttrs`` as ``irreps_edge_sh``
        type_embed_num_features (int): number of features for the type embedding layer
        categorical_graph_field_embed (List[Dict[str, int]]): forwarded to ``NodeTypeEmbed`` (default ``None``)
        readout_mlp_hidden_layers_depth (int): number of hidden layers in the readout MLP (default ``0``)
        readout_mlp_hidden_layers_width (int): width of hidden layers in the readout MLP; if ``None``, the dimension of the last entry of ``feature_irreps_hidden`` is used (default ``None``)
        readout_mlp_nonlinearity (str): nonlinearity of the readout MLP (default ``silu``)
        per_edge_type_cutoff (Dict): optional per-edge-type cutoffs, forwarded to ``EdgeLengthNormalizer`` (default ``None``)
        num_bessels (int): number of Bessel basis functions (default ``8``)
        bessel_trainable (bool): whether the Bessel roots are trainable (default ``False``)
        polynomial_cutoff_p (int): p-exponent used in polynomial cutoff function (default ``6``)
        avg_num_neighbors (float/Dict[str, float]): normalization parameter forwarded to every convolution (default ``None``)
        per_type_energy_scales (float/Sequence[float]): per-atom energy scales; a warning is issued if ``None`` (default ``None``)
        per_type_energy_shifts (float/Sequence[float]): per-atom energy shifts; a warning is issued if ``None`` (default ``None``)
        per_type_energy_scales_trainable (bool): whether the per-atom energy scales are trainable (default ``False``)
        per_type_energy_shifts_trainable (bool): whether the per-atom energy shifts are trainable (default ``False``)
        pair_potential: additional pair potential term, forwarded to ``_append_energy_modules`` (default ``None``)
        do_derivatives (bool): forwarded to ``ForceStressOutput`` (default ``True``)
        convnet_sc (bool): value of ``use_sc`` for the convolutions; forced to ``False`` on the first layer unless ``learnable_shift`` is set (default ``True``)
        learnable_shift (bool): if ``True``, the first layer keeps ``convnet_sc`` / ``convnet_resnet`` instead of having them disabled; requires at least one of the two to be ``True`` (default ``False``)
        convnet_resnet (bool): value of ``resnet`` for the convnet layers; forced to ``False`` on the first layer unless ``learnable_shift`` is set (default ``False``)
        convnet_nonlinearity_type (str): forwarded to ``ConvNetLayer`` as ``nonlinearity_type`` (default ``gate``)
        convnet_nonlinearity_scalars (Dict[str, str]): forwarded to ``ConvNetLayer`` as ``nonlinearity_scalars``
        convnet_nonlinearity_gates (Dict[str, str]): forwarded to ``ConvNetLayer`` as ``nonlinearity_gates``

    Returns:
        GraphModel: the energy model wrapped in ``ForceStressOutput``

    Raises:
        AssertionError: if a type name is not alphanumeric, if ``learnable_shift`` is set without ``convnet_sc`` or ``convnet_resnet``, if the per-layer lists differ in length or are empty, or if the last layer's irreps are not all scalars
    """
    # === sanity checks and warnings ===
    assert all(tn.isalnum() for tn in type_names), (
        "`type_names` must contain only alphanumeric characters"
    )
    # learnable_shift requires skip connections to be enabled
    assert not learnable_shift or (convnet_sc or convnet_resnet), (
        "`learnable_shift=True` requires at least one of `convnet_sc` or `convnet_resnet` to be True"
    )

    # require every convnet layer to be specified explicitly in a list
    # infer num_layers from the list size
    assert (
        len(radial_mlp_depth) == len(radial_mlp_width) == len(feature_irreps_hidden)
    ), (
        f"radial_mlp_depth: {radial_mlp_depth}, radial_mlp_width: {radial_mlp_width}, feature_irreps_hidden: {feature_irreps_hidden} should all have the same length"
    )
    num_layers = len(radial_mlp_depth)
    # fail with a clear message instead of an IndexError on `feature_irreps_hidden[-1]`
    assert num_layers > 0, (
        "at least one convnet layer required, but `radial_mlp_depth`, `radial_mlp_width` and `feature_irreps_hidden` are empty"
    )

    # assert that last convnet produces only scalars
    assert all(ell == 0 for ell in o3.Irreps(feature_irreps_hidden[-1]).ls), (
        f"last convnet layer output must only contain scalars but found {feature_irreps_hidden[-1]}"
    )

    if per_type_energy_scales is None:
        warnings.warn(
            "Found `per_type_energy_scales=None` -- it is recommended to set `per_type_energy_scales` for better numerics during training."
        )
    if per_type_energy_shifts is None:
        warnings.warn(
            "Found `per_type_energy_shifts=None` -- it is HIGHLY recommended to set `per_type_energy_shifts` as it determines the per-atom energies approaching the isolated atom regime."
        )

    # === encode and embed features ===
    # == node scalar embedding ==
    # NOTE: node embed is done first in case we need to pass in categorical graph fields as inputs
    # see how `irreps_in` is registered in the `NodeTypeEmbed` class
    type_embed = NodeTypeEmbed(
        type_names=type_names,
        num_features=type_embed_num_features,
        categorical_graph_field_embed=categorical_graph_field_embed,
    )

    # == edge tensor embedding ==
    spharm = SphericalHarmonicEdgeAttrs(
        irreps_edge_sh=irreps_edge_sh,
        irreps_in=type_embed.irreps_out,
    )

    # == edge scalar embedding ==
    edge_norm = EdgeLengthNormalizer(
        r_max=r_max,
        type_names=type_names,
        per_edge_type_cutoff=per_edge_type_cutoff,
        irreps_in=spharm.irreps_out,
    )
    bessel_encode = BesselEdgeLengthEncoding(
        num_bessels=num_bessels,
        trainable=bessel_trainable,
        cutoff=PolynomialCutoff(polynomial_cutoff_p),
        edge_invariant_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
        irreps_in=edge_norm.irreps_out,
    )
    # for backwards compatibility of NequIP's bessel encoding
    factor = ApplyFactor(
        in_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
        factor=(2 * math.pi) / (r_max * r_max),
        irreps_in=bessel_encode.irreps_out,
    )

    # insertion order is the execution order of the sequential network
    modules = {
        "type_embed": type_embed,
        "spharm": spharm,
        "edge_norm": edge_norm,
        "bessel_encode": bessel_encode,
        "factor": factor,
    }
    prev_irreps_out = factor.irreps_out

    # === convnet layers ===
    for layer_i in range(num_layers):
        # to ensure isolated atom limit, the first layer gets no skip
        # connections unless `learnable_shift` is enabled
        allow_skip = learnable_shift or layer_i != 0
        current_convnet = ConvNetLayer(
            irreps_in=prev_irreps_out,
            feature_irreps_hidden=feature_irreps_hidden[layer_i],
            convolution_kwargs={
                "radial_mlp_depth": radial_mlp_depth[layer_i],
                "radial_mlp_width": radial_mlp_width[layer_i],
                "use_sc": allow_skip and convnet_sc,
                "is_first_layer": layer_i == 0,
                # normalization parameters
                "avg_num_neighbors": avg_num_neighbors,
                "type_names": type_names,
            },
            resnet=allow_skip and convnet_resnet,
            nonlinearity_type=convnet_nonlinearity_type,
            nonlinearity_scalars=convnet_nonlinearity_scalars,
            nonlinearity_gates=convnet_nonlinearity_gates,
        )
        prev_irreps_out = current_convnet.irreps_out
        modules[f"layer{layer_i}_convnet"] = current_convnet

    # === readout ===
    # default the hidden width to the size of the last layer's scalar features
    if readout_mlp_hidden_layers_width is None:
        readout_mlp_hidden_layers_width = o3.Irreps(feature_irreps_hidden[-1]).dim

    per_atom_energy_readout = ScalarMLP(
        output_dim=1,
        hidden_layers_depth=readout_mlp_hidden_layers_depth,
        hidden_layers_width=readout_mlp_hidden_layers_width,
        nonlinearity=readout_mlp_nonlinearity,
        bias=False,
        forward_weight_init=True,
        field=AtomicDataDict.NODE_FEATURES_KEY,
        out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        irreps_in=prev_irreps_out,
    )

    # scales and shifts the per-atom energy field in place (same in/out field)
    per_type_energy_scale_shift = PerTypeScaleShift(
        type_names=type_names,
        field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        scales=per_type_energy_scales,
        shifts=per_type_energy_shifts,
        scales_trainable=per_type_energy_scales_trainable,
        shifts_trainable=per_type_energy_shifts_trainable,
        irreps_in=per_atom_energy_readout.irreps_out,
    )

    modules["per_atom_energy_readout"] = per_atom_energy_readout
    modules["per_type_energy_scale_shift"] = per_type_energy_scale_shift

    energy_model = SequentialGraphNetwork(modules)
    energy_model = _append_energy_modules(
        model=energy_model,
        type_names=type_names,
        pair_potential=pair_potential,
    )
    return ForceStressOutput(energy_model, do_derivatives)