Custom Models

Extension packages that implement custom models should be aware of the following pieces of NequIP's model-building infrastructure.

Model Builders

Model builders are functions decorated with model_builder() that construct models with proper handling of floating point precision, seeding, and compilation options:

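import nequip.model

@nequip.model.model_builder
def my_new_model_builder(arg1, arg2):
    # `model` is a placeholder for your own torch.nn.Module construction;
    # seed, model_dtype, and compile_mode are consumed by the decorator
    return model(arg1, arg2)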
nequip.model.model_builder(func=None, *, wrapper_class=None, compile_wrapper_class=None)

Decorator for model builder functions in the nequip ecosystem.

Handles seeding and floating point precision (float32 or float64) during model building, and wraps the result with GraphModel (or with wrapper_class, if given). Requires the seed, model_dtype, and type_names arguments. Supports eager and compile modes via compile_mode.

The seed, model_dtype, and compile_mode arguments are consumed by the decorator and not passed to the decorated function.
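To make the argument handling concrete, here is a minimal pure-Python sketch of a decorator that consumes seed, model_dtype, and compile_mode and forwards everything else (including type_names) to the builder. The name sketch_model_builder, the "eager" default for compile_mode, and the use of random.seed in place of torch seeding are all assumptions for illustration; this is not nequip's implementation.

```python
import functools
import random


def sketch_model_builder(func):
    """Hypothetical stand-in for model_builder, NOT nequip's implementation.

    Consumes seed, model_dtype, and compile_mode; forwards everything else,
    including the required type_names, to the decorated builder.
    """

    @functools.wraps(func)
    def wrapper(*, seed, model_dtype, compile_mode="eager", **kwargs):
        if model_dtype not in ("float32", "float64"):
            raise ValueError(f"unsupported model_dtype: {model_dtype!r}")
        if "type_names" not in kwargs:
            raise TypeError("type_names is required")
        random.seed(seed)  # stand-in for seeding torch's RNG
        # A real decorator would also set the dtype, use compile_mode to
        # choose a wrapper, and wrap the result (e.g. in GraphModel).
        return func(**kwargs)

    return wrapper


@sketch_model_builder
def my_new_model_builder(**kwargs):
    # Report which argument names actually reached the builder.
    return sorted(kwargs)


print(my_new_model_builder(seed=123, model_dtype="float32", type_names=["C", "H"], arg1=1, arg2=2))
# ['arg1', 'arg2', 'type_names']
```

Keeping the consumed arguments keyword-only in the wrapper means a builder author never has to declare or ignore them.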

Can be used in two ways:

  • @model_builder (uses the GraphModel wrapper; backward compatible)

  • @model_builder(wrapper_class=CustomGraphModel) (uses a custom wrapper)

Parameters:
  • func – The function to decorate (when used without parentheses)

  • wrapper_class – Custom GraphModel subclass to use for wrapping (default: GraphModel)

  • compile_wrapper_class – Custom wrapper for compile mode (default: CompileGraphModel)
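The func=None, *, wrapper_class=None signature follows the common Python pattern for a decorator that works both bare and with arguments. The sketch below shows that pattern in isolation; GraphModel, CustomGraphModel, and sketch_builder here are hypothetical stand-ins, not nequip's classes, and compile_wrapper_class is omitted for brevity.

```python
import functools


class GraphModel:
    """Hypothetical stand-in for nequip's GraphModel wrapper."""

    def __init__(self, inner):
        self.inner = inner


class CustomGraphModel(GraphModel):
    """Hypothetical custom wrapper subclass."""


def sketch_builder(func=None, *, wrapper_class=None):
    """Hypothetical sketch of the optional-argument decorator pattern."""
    if wrapper_class is None:
        wrapper_class = GraphModel

    def decorate(f):
        @functools.wraps(f)
        def build(*args, **kwargs):
            return wrapper_class(f(*args, **kwargs))

        return build

    if func is not None:
        # Bare use (@sketch_builder): Python passes the function directly.
        return decorate(func)
    # Use with arguments (@sketch_builder(...)): return the real decorator.
    return decorate


@sketch_builder
def plain_builder():
    return "model"


@sketch_builder(wrapper_class=CustomGraphModel)
def custom_builder():
    return "model"


print(type(plain_builder()).__name__)   # GraphModel
print(type(custom_builder()).__name__)  # CustomGraphModel
```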

Model Modifiers

Model modifiers are class methods decorated with model_modifier() that can modify loaded models on-the-fly:

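# defined as a class method on the torch.nn.Module subclass it modifies
@model_modifier(persistent=True)
@classmethod
def modify_PerTypeScaleShift(cls, model, scales=None, shifts=None):  # further arguments elided (...)
    # Implementation here; return the modified model
    return model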

The persistent parameter determines whether the modifier is applied during model packaging:

  • Non-persistent (persistent=False): Applied only at runtime, often for accelerations

  • Persistent (persistent=True): Applied during packaging, for structural changes
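One way to picture the persistent flag is as a tag on each modifier that packaging filters on. The pure-Python sketch below assumes a hypothetical sketch_model_modifier decorator that records the flag as an attribute, and a hypothetical modifiers_kept_at_packaging helper; it applies to plain functions rather than class methods and does not reproduce nequip's internals.

```python
def sketch_model_modifier(persistent):
    """Hypothetical sketch: tag a modifier function with its persistence."""

    def decorate(func):
        func.persistent = persistent
        return func

    return decorate


@sketch_model_modifier(persistent=True)
def scale(model, factor=1.0):
    # Toy "structural" change: the model is just a list of numbers here.
    return [x * factor for x in model]


@sketch_model_modifier(persistent=False)
def accelerate(model):
    # Toy runtime-only change: structure and values are left untouched.
    return model


def modifiers_kept_at_packaging(requested):
    # Only persistent modifiers become part of the packaged model.
    return [m.__name__ for m in requested if m.persistent]


print(modifiers_kept_at_packaging([scale, accelerate]))  # ['scale']
```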

Important: Model modifiers MUST preserve model state by transferring weights when there are trainable parameters.

Use replace_submodules() to help implement modifiers that replace specific module types.
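The weight-transfer requirement can be illustrated with plain PyTorch. The sketch below swaps every torch.nn.Linear for a hypothetical MyFastLinear subclass and copies each layer's state_dict so the model's outputs are unchanged. replace_linear_layers is a hypothetical helper written for illustration; it is not nequip's replace_submodules(), whose signature may differ.

```python
import torch


class MyFastLinear(torch.nn.Linear):
    """Hypothetical drop-in replacement (e.g. backed by a faster kernel)."""


def replace_linear_layers(module: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: recursively swap torch.nn.Linear for MyFastLinear."""
    for name, child in list(module.named_children()):
        if type(child) is torch.nn.Linear:
            new = MyFastLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            # Transfer trainable parameters so the model state is preserved.
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_linear_layers(child)
    return module


model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
x = torch.randn(5, 3)
before = model(x)
model = replace_linear_layers(model)
after = model(x)
print(torch.allclose(before, after))  # True
print(type(model[0]).__name__)        # MyFastLinear
```

Checking `type(child) is torch.nn.Linear` rather than isinstance avoids re-replacing layers that are already MyFastLinear.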

Usage:

Train-time via config:

model:
  _target_: nequip.model.modify
  modifiers:
    - modifier: modify_PerTypeScaleShift
      shifts:
        C: 1.23
        H: 0.12
  model:
    _target_: nequip.model.ModelFromPackage
    # ...
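Conceptually, each entry in the modifiers list names a modifier plus its keyword arguments, applied in order. The pure-Python sketch below assumes a hypothetical name-to-function registry and a toy dict-based "model"; nequip resolves modifiers from the model's own module classes, so treat apply_modifiers and toy_modify_PerTypeScaleShift as illustrations only.

```python
def apply_modifiers(model, modifiers, registry):
    """Hypothetical sketch: apply a list of modifier entries in order."""
    for entry in modifiers:
        params = dict(entry)  # copy so the config entry itself is not mutated
        name = params.pop("modifier")
        model = registry[name](model, **params)
    return model


def toy_modify_PerTypeScaleShift(model, scales=None, shifts=None):
    # Toy stand-in: the "model" is a plain dict of per-type scales and shifts.
    new = {"scales": dict(model["scales"]), "shifts": dict(model["shifts"])}
    if scales is not None:
        new["scales"].update(scales)
    if shifts is not None:
        new["shifts"].update(shifts)
    return new


config_modifiers = [
    {"modifier": "modify_PerTypeScaleShift", "shifts": {"C": 1.23, "H": 0.12}},
]
registry = {"modify_PerTypeScaleShift": toy_modify_PerTypeScaleShift}
base = {"scales": {"C": 1.0, "H": 1.0}, "shifts": {"C": 0.0, "H": 0.0}}
print(apply_modifiers(base, config_modifiers, registry)["shifts"])
# {'C': 1.23, 'H': 0.12}
```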

Compile-time for accelerations:

nequip-compile model.ckpt compiled.pth --modifiers enable_OpenEquivariance

Implementation Tips

When implementing custom torch.nn.Module subclasses, override torch.nn.Module.extra_repr() to include key configuration (such as dimensions or hyperparameters) in the module's printed representation. This aids debugging and is visible when _NEQUIP_LOG_LEVEL=DEBUG is set.
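A minimal example of overriding extra_repr() is shown below; MyRadialBasis and its num_basis and r_max attributes are hypothetical and serve only to show where the returned string appears when the module is printed.

```python
import torch


class MyRadialBasis(torch.nn.Module):
    """Hypothetical module used only to illustrate extra_repr()."""

    def __init__(self, num_basis: int, r_max: float):
        super().__init__()
        self.num_basis = num_basis
        self.r_max = r_max
        self.weights = torch.nn.Parameter(torch.ones(num_basis))

    def extra_repr(self) -> str:
        # Inserted between the parentheses when the module is printed.
        return f"num_basis={self.num_basis}, r_max={self.r_max}"


print(MyRadialBasis(num_basis=8, r_max=5.0))
# MyRadialBasis(num_basis=8, r_max=5.0)
```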