# Source code for nequip.model.nequip_models

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import math
from e3nn import o3

from nequip.data import AtomicDataDict

from nequip.nn import (
    GraphModel,
    SequentialGraphNetwork,
    ScalarMLP,
    PerTypeScaleShift,
    ConvNetLayer,
    ForceStressOutput,
    ApplyFactor,
)
from nequip.nn.embedding import (
    NodeTypeEmbed,
    PolynomialCutoff,
    EdgeLengthNormalizer,
    BesselEdgeLengthEncoding,
    SphericalHarmonicEdgeAttrs,
)

from .utils import model_builder
from .energy_modules import _append_energy_modules
import warnings
from typing import Sequence, Optional, List, Dict, Union, Callable


# Size-dependent architecture hyperparameters, keyed by preset name.
# `PresetNequIPGNNModel` upper-cases its `preset` argument before looking it up here.
# Each `num_features` list has `l_max + 1` entries, i.e. one multiplicity per
# rotation order, as required by `NequIPGNNModel`.
_NEQUIP_GNN_PRESETS = {
    "S": dict(num_layers=2, l_max=1, num_features=[128, 64]),
    "M": dict(num_layers=4, l_max=2, num_features=[128, 64, 32]),
    "L": dict(num_layers=6, l_max=3, num_features=[128, 64, 32, 32]),
    "XL": dict(num_layers=6, l_max=4, num_features=[320, 96, 64, 32, 32]),
}

# Hyperparameters shared by every preset. `PresetNequIPGNNModel` merges these
# first, so an entry in `_NEQUIP_GNN_PRESETS` with the same key would win.
_NEQUIP_GNN_STANDARD_PRESET = dict(
    parity=False,
    type_embed_num_features=32,
    radial_mlp_depth=1,
    radial_mlp_width=128,
)


def _format_nequip_gnn_preset_docstring() -> str:
    """Render the reST docstring that is attached to ``PresetNequIPGNNModel``.

    The preset names, shared defaults, and per-preset defaults are read from
    ``_NEQUIP_GNN_PRESETS`` and ``_NEQUIP_GNN_STANDARD_PRESET``. The documentation
    therefore stays in sync with the tables instead of being maintained by hand.

    Returns:
        str: the docstring text.
    """
    # one indented reST bullet per shared default, e.g. "- ``parity``: ``False``"
    shared_defaults = "\n".join(
        f"        - ``{key}``: ``{value!r}``"
        for key, value in _NEQUIP_GNN_STANDARD_PRESET.items()
    )
    # one indented reST bullet per preset, showing its full defaults dict
    preset_defaults = "\n".join(
        f"        - ``{preset}``: ``{defaults!r}``"
        for preset, defaults in _NEQUIP_GNN_PRESETS.items()
    )
    # computed up front to keep the template below free of nested expressions
    preset_names = ", ".join(f"``{name}``" for name in _NEQUIP_GNN_PRESETS)
    return f"""Build :func:`NequIPGNNModel` from a named architecture preset.

    This is a wrapper of :func:`NequIPGNNModel` that injects preset hyperparameters based on model sizes of the NequIP foundation potentials.
    All arguments are the same as :func:`NequIPGNNModel`, except this builder also requires ``preset`` and applies preset defaults before ``**kwargs``.
    For full argument documentation, see :func:`NequIPGNNModel`.
    Users can override the preset defaults by providing arguments for the fields to be overridden.

    Preset argument:
        preset (str): one of {preset_names}

    Override order:
        1. shared defaults
        2. per-preset defaults
        3. explicit ``**kwargs`` (highest priority)

    Shared defaults:
{shared_defaults}

    Per-preset defaults:
{preset_defaults}
    """


@model_builder
def PresetNequIPGNNModel(
    preset: str,
    **kwargs,
) -> GraphModel:
    # NOTE: the public docstring of this builder is generated by
    # `_format_nequip_gnn_preset_docstring` and assigned to
    # `PresetNequIPGNNModel.__doc__` elsewhere in this module.
    key = preset.upper()  # preset names are matched case-insensitively
    assert key in _NEQUIP_GNN_PRESETS, (
        f"`preset` must be one of {list(_NEQUIP_GNN_PRESETS.keys())}, but found `{key}`"
    )
    # precedence, lowest to highest: shared defaults < size preset < explicit kwargs
    merged = dict(_NEQUIP_GNN_STANDARD_PRESET)
    merged.update(_NEQUIP_GNN_PRESETS[key])
    merged.update(kwargs)
    return NequIPGNNModel(**merged)
@model_builder
def NequIPGNNModel(
    num_layers: int = 4,
    l_max: int = 1,
    parity: bool = True,
    num_features: Union[int, List[int]] = 32,
    type_embed_num_features: Optional[int] = None,
    radial_mlp_depth: int = 1,
    radial_mlp_width: int = 128,
    **kwargs,
) -> GraphModel:
    """NequIP GNN model that can predict energies only or energies with forces/stresses.

    This is a compact front-end to :func:`FullNequIPGNNModel`. The scalar
    hyperparameters below are expanded into per-layer lists. Every layer shares
    the same hidden irreps and radial MLP shape, except the last layer, which
    keeps only ``num_features[0]`` even scalars (``0e``). All remaining keyword
    arguments are forwarded to :func:`FullNequIPGNNModel` unchanged.

    Args:
        seed (int): seed for reproducibility
        model_dtype (str): ``float32`` or ``float64``
        r_max (float): cutoff radius
        per_edge_type_cutoff (Dict): optional cutoffs for each edge type; each
            must be smaller than ``r_max`` (default ``None``)
        type_names (Sequence[str]): list of atom type names
        num_layers (int): number of interaction blocks; 3-5 tends to work best
            (default ``4``)
        l_max (int): maximum rotation order of the network's features; ``1`` is
            a good default, ``2`` is more accurate but slower (default ``1``)
        parity (bool): whether to include features with odd mirror parity.
            Turning parity off often gives equally good results with faster
            networks, so it is worth testing (default ``True``)
        num_features (int/List[int]): feature multiplicity; smaller is faster
            (default ``32``). A list gives one multiplicity per rotation order,
            e.g. for ``l_max=2`` and ``parity=False``, ``num_features=[5, 2, 7]``
            means ``5x0e``, ``2x1o`` and ``7x2e`` features
        type_embed_num_features (int): number of features of the type embedding
            layer; falls back to ``num_features[0]`` when not given
            (default ``None``)
        radial_mlp_depth (int): number of radial layers; usually 1-3 works best,
            and smaller is faster (default ``1``)
        radial_mlp_width (int): number of hidden neurons in the radial function;
            smaller is faster (default ``128``)
        readout_mlp_hidden_layers_depth (int): number of hidden layers in the
            readout MLP (default ``0``)
        readout_mlp_hidden_layers_width (int): width of the readout MLP hidden
            layers (default: the ``0e`` multiplicity, i.e. ``num_features[0]``)
        readout_mlp_nonlinearity (str): ``silu``, ``mish``, ``gelu``, or ``None``
            (default ``silu``)
        num_bessels (int): number of Bessel basis functions (default ``8``)
        bessel_trainable (bool): whether the Bessel roots are trainable
            (default ``False``)
        polynomial_cutoff_p (int): p-exponent of the polynomial cutoff function;
            smaller p corresponds to stronger decay with distance (default ``6``)
        avg_num_neighbors (float/Dict[str, float]): used to normalize edge sums
            for better numerics (default ``None``)
        per_type_energy_scales (float/List[float]): per-atom energy scales, which
            could be derived from the force RMS of the data (default ``None``)
        per_type_energy_shifts (float/List[float]): per-atom energy shifts. These
            should generally be isolated atom reference energies, or be estimated
            from average per-atom energies of the data (default ``None``)
        per_type_energy_scales_trainable (bool): whether the per-atom energy
            scales are trainable (default ``False``)
        per_type_energy_shifts_trainable (bool): whether the per-atom energy
            shifts are trainable (default ``False``)
        pair_potential (torch.nn.Module): additional pair potential term, e.g.
            :class:`~nequip.nn.pair_potential.ZBL` (default ``None``)
        do_derivatives (bool): whether to compute forces and stresses via
            autograd (default ``True``)
    """
    assert num_layers > 0, (
        f"at least one convnet layer required, but found `num_layers={num_layers}`"
    )

    # rotation orders 0..l_max -> one feature multiplicity each
    num_orders = l_max + 1

    # irreps of the edge spherical harmonics
    irreps_edge_sh = repr(o3.Irreps.spherical_harmonics(lmax=l_max))

    # broadcast a single multiplicity to every rotation order
    if isinstance(num_features, int):
        num_features = [num_features] * num_orders
    assert len(num_features) == num_orders, (
        f"`num_features` should be of length `l_max + 1` ({l_max + 1}), but found `num_features={num_features}` with {len(num_features)} entries."
    )

    if type_embed_num_features is None:
        type_embed_num_features = num_features[0]

    # hidden irreps shared by all but the last convnet layer
    irrep_specs = []
    for l in range(num_orders):
        if parity:
            allowed_parities = (1, -1)
        else:
            # without parity: p = 1 for even l, -1 for odd l
            allowed_parities = (1,) if l % 2 == 0 else (-1,)
        irrep_specs.extend((num_features[l], (l, p)) for p in allowed_parities)
    hidden_irreps = repr(o3.Irreps(irrep_specs))

    # the final convnet layer must emit only scalars for the readout
    scalar_irreps = repr(o3.Irreps([(num_features[0], (0, 1))]))
    per_layer_irreps = [hidden_irreps] * (num_layers - 1) + [scalar_irreps]

    return FullNequIPGNNModel(
        irreps_edge_sh=irreps_edge_sh,
        type_embed_num_features=type_embed_num_features,
        feature_irreps_hidden=per_layer_irreps,
        radial_mlp_depth=[radial_mlp_depth] * num_layers,
        radial_mlp_width=[radial_mlp_width] * num_layers,
        **kwargs,
    )
# attach the table-driven docstring to the preset builder
PresetNequIPGNNModel.__doc__ = _format_nequip_gnn_preset_docstring()


@model_builder
def FullNequIPGNNModel(
    r_max: float,
    type_names: Sequence[str],
    # convnet params
    radial_mlp_depth: Sequence[int],
    radial_mlp_width: Sequence[int],
    feature_irreps_hidden: Sequence[Union[str, o3.Irreps]],
    # irreps and dims
    irreps_edge_sh: Union[int, str, o3.Irreps],
    type_embed_num_features: int,
    categorical_graph_field_embed: Optional[List[Dict[str, int]]] = None,
    # readout
    readout_mlp_hidden_layers_depth: int = 0,
    readout_mlp_hidden_layers_width: Optional[int] = None,
    readout_mlp_nonlinearity: Optional[str] = "silu",
    # edge length encoding
    per_edge_type_cutoff: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    num_bessels: int = 8,
    bessel_trainable: bool = False,
    polynomial_cutoff_p: int = 6,
    # edge sum normalization
    avg_num_neighbors: Optional[Union[float, Dict[str, float]]] = None,
    # per atom energy params
    per_type_energy_scales: Optional[Union[float, Sequence[float]]] = None,
    per_type_energy_shifts: Optional[Union[float, Sequence[float]]] = None,
    per_type_energy_scales_trainable: Optional[bool] = False,
    per_type_energy_shifts_trainable: Optional[bool] = False,
    pair_potential: Optional[Dict] = None,
    # derivatives
    do_derivatives: bool = True,
    # developmental params
    convnet_sc: bool = True,
    learnable_shift: bool = False,
    # == things that generally shouldn't be changed ==
    # convnet
    convnet_resnet: bool = False,
    convnet_nonlinearity_type: str = "gate",
    convnet_nonlinearity_scalars: Optional[Dict[str, Union[str, Callable]]] = None,
    convnet_nonlinearity_gates: Optional[Dict[str, Union[str, Callable]]] = None,
) -> GraphModel:
    """NequIP GNN energy model built from an explicit, per-layer set of arguments.

    The network is a :class:`SequentialGraphNetwork` made of the following modules, in order:

    1. node type embedding
    2. spherical-harmonic edge attributes
    3. edge length normalization
    4. Bessel edge length encoding, followed by a constant factor
       ``2 * pi / r_max**2`` kept for backwards compatibility
    5. the convnet layers
    6. a scalar readout MLP producing per-atom energies
    7. a per-type energy scale/shift

    That network is then passed, together with ``pair_potential``, to
    ``_append_energy_modules``, and the result is wrapped in :class:`ForceStressOutput`.

    Unlike :func:`NequIPGNNModel`, every convnet layer is specified explicitly.
    ``radial_mlp_depth``, ``radial_mlp_width`` and ``feature_irreps_hidden`` must
    all have the same (non-zero) length, which sets the number of layers. The
    last entry of ``feature_irreps_hidden`` must contain only scalars (``l = 0``).

    Args:
        r_max (float): cutoff radius
        type_names (Sequence[str]): atom type names; each must be alphanumeric
        radial_mlp_depth (Sequence[int]): radial MLP depth, one entry per layer
        radial_mlp_width (Sequence[int]): radial MLP width, one entry per layer
        feature_irreps_hidden (Sequence[str/Irreps]): hidden feature irreps, one
            entry per layer; the last one must be scalars only
        irreps_edge_sh (int/str/Irreps): irreps of the edge spherical harmonics
        type_embed_num_features (int): number of type embedding features
        categorical_graph_field_embed (List[Dict[str, int]]): optional
            categorical graph fields forwarded to the node type embedding
            (default ``None``)
        readout_mlp_hidden_layers_width (int): readout MLP hidden width; defaults
            to the dimension of the last entry of ``feature_irreps_hidden``
        convnet_sc (bool): whether convnet layers use self-connections
        learnable_shift (bool): if ``False``, the first convnet layer uses
            neither self-connection nor residual update (to ensure the isolated
            atom limit); if ``True``, ``convnet_sc`` or ``convnet_resnet`` must be set
        convnet_nonlinearity_scalars (Dict[str, str]): nonlinearities for even
            (``"e"``) / odd (``"o"``) scalars; ``None`` means
            ``{"e": "silu", "o": "tanh"}`` (a fresh dict per call)
        convnet_nonlinearity_gates (Dict[str, str]): same as above, for gates
        other arguments: see :func:`NequIPGNNModel`

    Returns:
        the energy network wrapped in :class:`ForceStressOutput`
    """
    # === sanity checks and warnings ===
    assert all(tn.isalnum() for tn in type_names), (
        "`type_names` must contain only alphanumeric characters"
    )
    # learnable_shift requires skip connections to be enabled
    assert not learnable_shift or (convnet_sc or convnet_resnet), (
        "`learnable_shift=True` requires at least one of `convnet_sc` or `convnet_resnet` to be True"
    )
    # require every convnet layer to be specified explicitly in a list
    # infer num_layers from the list size
    assert (
        len(radial_mlp_depth) == len(radial_mlp_width) == len(feature_irreps_hidden)
    ), (
        f"radial_mlp_depth: {radial_mlp_depth}, radial_mlp_width: {radial_mlp_width}, feature_irreps_hidden: {feature_irreps_hidden} should all have the same length"
    )
    num_layers = len(radial_mlp_depth)
    # fail with a clear message instead of an IndexError on `feature_irreps_hidden[-1]`
    assert num_layers > 0, (
        f"at least one convnet layer required, but found `num_layers={num_layers}`"
    )
    # assert that last convnet produces only scalars
    assert all(l == 0 for l in o3.Irreps(feature_irreps_hidden[-1]).ls), (
        f"last convnet layer output must only contain scalars but found {feature_irreps_hidden[-1]}"
    )

    if per_type_energy_scales is None:
        warnings.warn(
            "Found `per_type_energy_scales=None` -- it is recommended to set `per_type_energy_scales` for better numerics during training."
        )
    if per_type_energy_shifts is None:
        warnings.warn(
            "Found `per_type_energy_shifts=None` -- it is HIGHLY recommended to set `per_type_energy_shifts` as it determines the per-atom energies approaching the isolated atom regime."
        )

    # build the default nonlinearity maps per call so no dict is shared
    # between models (a dict literal as a default argument is created only once)
    if convnet_nonlinearity_scalars is None:
        convnet_nonlinearity_scalars = {"e": "silu", "o": "tanh"}
    if convnet_nonlinearity_gates is None:
        convnet_nonlinearity_gates = {"e": "silu", "o": "tanh"}

    # === encode and embed features ===
    # == node scalar embedding ==
    # NOTE: node embed is done first in case we need to pass in categorical graph fields as inputs
    # see how `irreps_in` is registered in the `NodeTypeEmbed` class
    type_embed = NodeTypeEmbed(
        type_names=type_names,
        num_features=type_embed_num_features,
        categorical_graph_field_embed=categorical_graph_field_embed,
    )
    # == edge tensor embedding ==
    spharm = SphericalHarmonicEdgeAttrs(
        irreps_edge_sh=irreps_edge_sh,
        irreps_in=type_embed.irreps_out,
    )
    # == edge scalar embedding ==
    edge_norm = EdgeLengthNormalizer(
        r_max=r_max,
        type_names=type_names,
        per_edge_type_cutoff=per_edge_type_cutoff,
        irreps_in=spharm.irreps_out,
    )
    bessel_encode = BesselEdgeLengthEncoding(
        num_bessels=num_bessels,
        trainable=bessel_trainable,
        cutoff=PolynomialCutoff(polynomial_cutoff_p),
        edge_invariant_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
        irreps_in=edge_norm.irreps_out,
    )
    # for backwards compatibility of NequIP's bessel encoding
    factor = ApplyFactor(
        in_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
        factor=(2 * math.pi) / (r_max * r_max),
        irreps_in=bessel_encode.irreps_out,
    )
    modules = {
        "type_embed": type_embed,
        "spharm": spharm,
        "edge_norm": edge_norm,
        "bessel_encode": bessel_encode,
        "factor": factor,
    }
    prev_irreps_out = factor.irreps_out

    # === convnet layers ===
    for layer_i in range(num_layers):
        is_first = layer_i == 0
        current_convnet = ConvNetLayer(
            irreps_in=prev_irreps_out,
            feature_irreps_hidden=feature_irreps_hidden[layer_i],
            convolution_kwargs={
                "radial_mlp_depth": radial_mlp_depth[layer_i],
                "radial_mlp_width": radial_mlp_width[layer_i],
                # to ensure isolated atom limit: without a learnable shift, the
                # first layer gets no self-connection
                "use_sc": convnet_sc if learnable_shift else (not is_first and convnet_sc),
                "is_first_layer": is_first,
                # normalization parameters
                "avg_num_neighbors": avg_num_neighbors,
                "type_names": type_names,
            },
            # same isolated-atom reasoning for the residual update
            resnet=convnet_resnet if learnable_shift else (not is_first and convnet_resnet),
            nonlinearity_type=convnet_nonlinearity_type,
            nonlinearity_scalars=convnet_nonlinearity_scalars,
            nonlinearity_gates=convnet_nonlinearity_gates,
        )
        prev_irreps_out = current_convnet.irreps_out
        modules.update({f"layer{layer_i}_convnet": current_convnet})

    # === readout ===
    if readout_mlp_hidden_layers_width is None:
        readout_mlp_hidden_layers_width = o3.Irreps(feature_irreps_hidden[-1]).dim
    per_atom_energy_readout = ScalarMLP(
        output_dim=1,
        hidden_layers_depth=readout_mlp_hidden_layers_depth,
        hidden_layers_width=readout_mlp_hidden_layers_width,
        nonlinearity=readout_mlp_nonlinearity,
        bias=False,
        forward_weight_init=True,
        field=AtomicDataDict.NODE_FEATURES_KEY,
        out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        irreps_in=prev_irreps_out,
    )
    per_type_energy_scale_shift = PerTypeScaleShift(
        type_names=type_names,
        field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
        scales=per_type_energy_scales,
        shifts=per_type_energy_shifts,
        scales_trainable=per_type_energy_scales_trainable,
        shifts_trainable=per_type_energy_shifts_trainable,
        irreps_in=per_atom_energy_readout.irreps_out,
    )
    modules.update(
        {
            "per_atom_energy_readout": per_atom_energy_readout,
            "per_type_energy_scale_shift": per_type_energy_scale_shift,
        }
    )

    energy_model = SequentialGraphNetwork(modules)
    energy_model = _append_energy_modules(
        model=energy_model,
        type_names=type_names,
        pair_potential=pair_potential,
    )
    return ForceStressOutput(energy_model, do_derivatives)