NequIP Message Passing GNN Models
- nequip.model.NequIPGNNModel(num_layers: int = 4, l_max: int = 1, parity: bool = True, num_features: int | List[int] = 32, type_embed_num_features: int | None = None, radial_mlp_depth: int = 1, radial_mlp_width: int = 128, **kwargs) -> GraphModel
NequIP GNN model that can predict energies only or energies with forces/stresses.
- Parameters:
seed (int) – seed for reproducibility
model_dtype (str) – float32 or float64
r_max (float) – cutoff radius
per_edge_type_cutoff (Dict) – optional cutoffs for each edge type, which must be smaller than r_max (default None)
type_names (Sequence[str]) – list of atom type names
num_layers (int) – number of interaction blocks; we find 3-5 to work best (default 4)
l_max (int) – the maximum rotation order for the network’s features; 1 is a good default, 2 is more accurate but slower (default 1)
parity (bool) – whether to include features with odd mirror parity; turning parity off often gives equally good results with faster networks, so it is worth testing (default True)
num_features (int/List[int]) – multiplicity of the features, smaller is faster (default 32); it is also possible to provide the multiplicity for each irrep, e.g. for l_max=2 and parity=False, num_features=[5, 2, 7] refers to 5x0e, 2x1o and 7x2e features
type_embed_num_features (int) – number of features for the type embedding layer; if not provided, defaults to num_features[0] (default None)
radial_mlp_depth (int) – number of radial layers; usually 1-3 works best, smaller is faster (default 1)
radial_mlp_width (int) – number of hidden neurons in the radial function; smaller is faster (default 128)
readout_mlp_hidden_layers_depth (int) – number of hidden layers in the readout MLP (default 0)
readout_mlp_hidden_layers_width (int) – width of the hidden layers in the readout MLP (defaults to the 0e multiplicity of num_features)
readout_mlp_nonlinearity (str) – silu, mish, gelu, or None (default silu)
num_bessels (int) – number of Bessel basis functions (default 8)
bessel_trainable (bool) – whether the Bessel roots are trainable (default False)
polynomial_cutoff_p (int) – p-exponent used in the polynomial cutoff function; smaller p corresponds to stronger decay with distance (default 6)
avg_num_neighbors (float/Dict[str, float]) – used to normalize edge sums for better numerics (default None)
per_type_energy_scales (float/List[float]) – per-atom energy scales, which could be derived from the force RMS of the data (default None)
per_type_energy_shifts (float/List[float]) – per-atom energy shifts, which should generally be isolated-atom reference energies or be estimated from the average per-atom energies of the data (default None)
per_type_energy_scales_trainable (bool) – whether the per-atom energy scales are trainable (default False)
per_type_energy_shifts_trainable (bool) – whether the per-atom energy shifts are trainable (default False)
pair_potential (torch.nn.Module) – additional pair potential term, e.g. ZBL (default None)
do_derivatives (bool) – whether to compute forces and stresses via autograd (default True)
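To make the num_features convention concrete, the sketch below expands l_max and num_features into an irreps string for parity=False. The helper feature_irreps is hypothetical (it is not part of nequip). It assumes that with parity=False each rotation order l contributes a single irrep of parity (-1)^l, which matches the 5x0e, 2x1o, 7x2e example above, and that a single int applies the same multiplicity to every l. The parity=True case is deliberately not covered.

```python
from typing import List, Union


def feature_irreps(l_max: int, num_features: Union[int, List[int]]) -> str:
    """Hypothetical helper: spell out the feature irreps for parity=False.

    Assumes one irrep per rotation order l with parity (-1)**l
    (0e, 1o, 2e, ...), as in the num_features example.
    """
    if isinstance(num_features, int):
        # Assumption: a single int means the same multiplicity for every l.
        multiplicities = [num_features] * (l_max + 1)
    else:
        multiplicities = list(num_features)
    if len(multiplicities) != l_max + 1:
        raise ValueError(
            f"expected {l_max + 1} multiplicities for l_max={l_max}, "
            f"got {len(multiplicities)}"
        )
    terms = []
    for l, mul in enumerate(multiplicities):
        parity_label = "e" if l % 2 == 0 else "o"  # (-1)**l parity
        terms.append(f"{mul}x{l}{parity_label}")
    return "+".join(terms)


print(feature_irreps(2, [5, 2, 7]))  # 5x0e+2x1o+7x2e
print(feature_irreps(1, 32))  # 32x0e+32x1o
```

The first multiplicity (the 0e entry) is the value that type_embed_num_features and readout_mlp_hidden_layers_width fall back to when they are not given.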
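The radial hyperparameters (r_max, num_bessels, bessel_trainable, polynomial_cutoff_p) can be illustrated with a plain-Python sketch of a radial edge embedding. It assumes the sine-based Bessel basis (2 / r_max) · sin(nπr / r_max) / r and the DimeNet-style polynomial envelope. We believe these match the forms NequIP's radial embedding builds on, but this is a stand-alone illustration rather than nequip's implementation, and the function names are hypothetical.

```python
import math
from typing import List


def bessel_basis(r: float, r_max: float, num_bessels: int = 8) -> List[float]:
    """Sine-based Bessel basis (2 / r_max) * sin(n * pi * r / r_max) / r, for r > 0."""
    # The roots n * pi (n = 1..num_bessels) are fixed here; with
    # bessel_trainable=True they would presumably become learnable parameters.
    return [
        (2.0 / r_max) * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_bessels + 1)
    ]


def polynomial_cutoff(r: float, r_max: float, p: int = 6) -> float:
    """Polynomial envelope: 1 at r = 0, decaying smoothly to 0 at r = r_max."""
    x = r / r_max
    if x >= 1.0:
        return 0.0
    return (
        1.0
        - (p + 1) * (p + 2) / 2 * x**p
        + p * (p + 2) * x ** (p + 1)
        - p * (p + 1) / 2 * x ** (p + 2)
    )


def radial_embedding(
    r: float, r_max: float, num_bessels: int = 8, p: int = 6
) -> List[float]:
    """Bessel features multiplied by the smooth cutoff envelope."""
    envelope = polynomial_cutoff(r, r_max, p)
    return [b * envelope for b in bessel_basis(r, r_max, num_bessels)]


# Compare envelopes halfway to a 5.0 cutoff for two exponents.
print(polynomial_cutoff(2.5, 5.0, p=6))
print(polynomial_cutoff(2.5, 5.0, p=2))
```

Evaluating the envelope at half the cutoff shows the documented trend: p=6 gives about 0.855, while p=2 gives 0.3125, so the smaller p damps contributions more strongly with distance.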
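When isolated-atom reference energies are not available, the per-type energy parameters can be estimated from training data along the lines the descriptions above suggest. The sketch below uses NumPy and a made-up four-structure dataset: it fits per-type shifts by least squares on atom-type counts (one way of estimating them from the per-atom energies of the data) and takes a single scale from the RMS of the force components. apply_scale_shift is a hypothetical helper showing how we assume the two are applied per atom (scale_t · e_i + shift_t) before summing to a total energy.

```python
import numpy as np

# Made-up training data for two atom types (index 0 = "H", index 1 = "O"):
# each row counts the atoms of each type in one structure.
type_counts = np.array([[2, 1], [4, 2], [2, 2], [0, 1]], dtype=float)
total_energies = np.array([-14.0, -28.0, -24.0, -10.0])

# Per-type shifts: least-squares fit of E_total ~ sum_t n_t * shift_t.
shifts, *_ = np.linalg.lstsq(type_counts, total_energies, rcond=None)

# One energy scale from the RMS of all force components
# (made-up forces with shape (n_atoms, 3)), shared by both types.
forces = np.array([[0.1, -0.2, 0.0], [0.3, 0.1, -0.1]])
force_rms = float(np.sqrt(np.mean(forces**2)))
scales = np.full(2, force_rms)


def apply_scale_shift(atomic_outputs, atom_types, scales, shifts):
    """Hypothetical helper: per-atom scale_t * e_i + shift_t, summed to a total energy."""
    return float(np.sum(scales[atom_types] * atomic_outputs + shifts[atom_types]))


print("per-type shifts:", shifts)  # should be close to [-2, -10] for this toy data
print("force RMS scale:", force_rms)
```

Because the toy data are exactly consistent, the fit recovers shifts of -2 and -10; on real data the residual of the least-squares fit is what the network's per-atom outputs then have to model.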
Preset NequIP Message Passing GNN Model
- nequip.model.PresetNequIPGNNModel(preset: str, **kwargs) -> GraphModel
Build NequIPGNNModel() from a named architecture preset.
This is a wrapper of NequIPGNNModel() that injects preset hyperparameters based on the model sizes of the NequIP foundation potentials. All arguments are the same as for NequIPGNNModel(), except that this builder also requires preset and applies the preset defaults before **kwargs. For full argument documentation, see NequIPGNNModel(). Users can override the preset defaults by providing arguments for the fields to be overridden.
- Preset argument:
preset (str): one of S, M, L, XL
- Override order:
shared defaults
per-preset defaults
explicit **kwargs (highest priority)
- Shared defaults:
parity: False
type_embed_num_features: 32
radial_mlp_depth: 1
radial_mlp_width: 128
- Per-preset defaults:
S: {'num_layers': 2, 'l_max': 1, 'num_features': [128, 64]}
M: {'num_layers': 4, 'l_max': 2, 'num_features': [128, 64, 32]}
L: {'num_layers': 6, 'l_max': 3, 'num_features': [128, 64, 32, 32]}
XL: {'num_layers': 6, 'l_max': 4, 'num_features': [320, 96, 64, 32, 32]}
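The override order amounts to a left-to-right dictionary merge in which later sources win. The sketch below copies the documented shared and per-preset defaults and merges them with explicit keyword arguments. resolve_preset_kwargs is a hypothetical name, and this is an illustration of the documented precedence, not nequip's actual implementation.

```python
import copy
from typing import Any, Dict

# Values copied from the documented shared and per-preset defaults.
SHARED_DEFAULTS: Dict[str, Any] = {
    "parity": False,
    "type_embed_num_features": 32,
    "radial_mlp_depth": 1,
    "radial_mlp_width": 128,
}
PRESET_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "S": {"num_layers": 2, "l_max": 1, "num_features": [128, 64]},
    "M": {"num_layers": 4, "l_max": 2, "num_features": [128, 64, 32]},
    "L": {"num_layers": 6, "l_max": 3, "num_features": [128, 64, 32, 32]},
    "XL": {"num_layers": 6, "l_max": 4, "num_features": [320, 96, 64, 32, 32]},
}


def resolve_preset_kwargs(preset: str, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge defaults in the documented override order."""
    if preset not in PRESET_DEFAULTS:
        raise ValueError(
            f"unknown preset {preset!r}; expected one of {list(PRESET_DEFAULTS)}"
        )
    # Shared defaults first, then per-preset defaults (later keys win).
    # Deep-copy so callers cannot mutate the module-level lists.
    merged = copy.deepcopy({**SHARED_DEFAULTS, **PRESET_DEFAULTS[preset]})
    # Explicit kwargs have the highest priority.
    merged.update(kwargs)
    return merged


print(resolve_preset_kwargs("M", l_max=1, num_features=[64, 32]))
```

Overriding l_max presumably also calls for a num_features list of matching length (l_max + 1 entries with parity=False), which is why the usage line overrides both.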