Custom Models
Extension packages implementing custom models should be aware of the following pieces of NequIP's model-building infrastructure.
Model Builders
Model builders are functions decorated with model_builder() that construct models with proper handling of floating point precision, seeding, and compilation options:
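```python
import nequip.model

@nequip.model.model_builder
def my_new_model_builder(type_names, arg1, arg2):
    # seed/model_dtype/compile_mode are consumed by the decorator; MyModel is a placeholder
    return MyModel(type_names=type_names, arg1=arg1, arg2=arg2)
```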
- nequip.model.model_builder(func=None, *, wrapper_class=None, compile_wrapper_class=None)
Decorator for model builder functions in the nequip ecosystem.
Handles model building with proper seeding and floating point precision (float32 or float64), and wraps the result with GraphModel. Requires seed, model_dtype, and type_names arguments. Supports eager and compile modes via compile_mode.
The seed, model_dtype, and compile_mode arguments are consumed by the decorator and not passed to the decorated function.
Can be used in two ways:
- @model_builder (uses the GraphModel wrapper; backward compatible)
- @model_builder(wrapper_class=CustomGraphModel) (uses a custom wrapper)
- Parameters:
func – The function to decorate (when used without parentheses)
wrapper_class – Custom GraphModel subclass to use for wrapping (default: GraphModel)
compile_wrapper_class – Custom wrapper for compile mode (default: CompileGraphModel)
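To make the decorator's contract concrete, here is a minimal, self-contained sketch of the same pattern in plain Python. It is not NequIP's implementation: `toy_model_builder`, `ToyGraphModel`, `ToyCompileGraphModel`, and `CustomGraphModel` are hypothetical stand-ins, and seeding Python's `random` module stands in for seeding torch. It shows the two calling styles (bare, and with keyword arguments), how `seed`, `model_dtype`, and `compile_mode` are consumed while `type_names` and the remaining arguments are forwarded, and how the wrapper class is chosen per mode.

```python
import functools
import random
from typing import Any, Callable, List, Optional


class ToyGraphModel:
    """Hypothetical stand-in for GraphModel: stores the wrapped model and metadata."""

    def __init__(self, model: Any, type_names: List[str], model_dtype: str):
        self.model = model
        self.type_names = type_names
        self.model_dtype = model_dtype


class ToyCompileGraphModel(ToyGraphModel):
    """Hypothetical stand-in for CompileGraphModel."""


def toy_model_builder(
    func: Optional[Callable] = None,
    *,
    wrapper_class: Optional[type] = None,
    compile_wrapper_class: Optional[type] = None,
):
    eager_wrapper = wrapper_class or ToyGraphModel
    compile_wrapper = compile_wrapper_class or ToyCompileGraphModel

    def decorate(f: Callable) -> Callable:
        @functools.wraps(f)
        def builder(*args, seed: int, model_dtype: str, type_names: List[str],
                    compile_mode: str = "eager", **kwargs):
            if model_dtype not in ("float32", "float64"):
                raise ValueError(f"unsupported model_dtype: {model_dtype}")
            if compile_mode not in ("eager", "compile"):
                raise ValueError(f"unsupported compile_mode: {compile_mode}")
            random.seed(seed)  # seed before the inner model draws any random numbers
            # seed, model_dtype and compile_mode stop here; type_names is forwarded
            inner = f(*args, type_names=type_names, **kwargs)
            wrapper = compile_wrapper if compile_mode == "compile" else eager_wrapper
            return wrapper(inner, type_names=type_names, model_dtype=model_dtype)

        return builder

    # Bare `@toy_model_builder` receives the function directly;
    # `@toy_model_builder(...)` receives only keywords and must return a decorator.
    return decorate(func) if func is not None else decorate


class CustomGraphModel(ToyGraphModel):
    """Hypothetical custom wrapper."""


@toy_model_builder
def my_new_model_builder(type_names, width):
    # "weights" drawn from the RNG that the decorator seeded
    return [random.random() for _ in range(width)]


@toy_model_builder(wrapper_class=CustomGraphModel)
def my_custom_builder(type_names, width):
    return [random.random() for _ in range(width)]


a = my_new_model_builder(seed=0, model_dtype="float64", type_names=["C", "H"], width=3)
b = my_new_model_builder(seed=0, model_dtype="float64", type_names=["C", "H"], width=3)
# a.model == b.model: the same seed reproduces the same "weights"
```

Returning either `decorate(func)` or `decorate` itself is what lets one function serve both as `@decorator` and as `@decorator(...)`.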
Model Modifiers
Model modifiers are class methods decorated with model_modifier() that can modify loaded models on the fly:
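```python
import torch
from nequip.model.modify_utils import model_modifier  # import path assumed

class MyModule(torch.nn.Module):  # hypothetical module that owns the modifier
    @model_modifier(persistent=True)
    @classmethod
    def modify_PerTypeScaleShift(cls, model, scales=None, shifts=None):  # further args elided
        # Implementation here
        pass
```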
The persistent parameter determines whether the modifier is applied during model packaging:
- Non-persistent (persistent=False): applied only at runtime, often for accelerations.
- Persistent (persistent=True): applied during packaging, for structural changes.
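A minimal plain-Python sketch of how a `persistent` flag can separate structural modifiers from runtime-only ones. `toy_model_modifier`, `MODIFIER_REGISTRY`, `ToyModel`, `modify_Shift`, `enable_FastKernels`, and `modifiers_to_package` are hypothetical names, and NequIP's own bookkeeping may differ; the point is only that persistent modifiers are baked into the packaged model, while non-persistent ones are left to be applied at runtime.

```python
from typing import Callable, Dict, List

# Hypothetical registry: modifier name -> persistent flag (not NequIP's internals)
MODIFIER_REGISTRY: Dict[str, bool] = {}


def toy_model_modifier(persistent: bool) -> Callable[[classmethod], classmethod]:
    """Record whether a classmethod modifier should survive packaging."""

    def decorator(method: classmethod) -> classmethod:
        # @classmethod is applied first, so this receives a classmethod object
        MODIFIER_REGISTRY[method.__func__.__name__] = persistent
        return method

    return decorator


class ToyModel:
    def __init__(self, shift: float = 0.0, fast_kernels: bool = False):
        self.shift = shift
        self.fast_kernels = fast_kernels

    @toy_model_modifier(persistent=True)
    @classmethod
    def modify_Shift(cls, model: "ToyModel", shift: float) -> "ToyModel":
        # structural change: should be stored in the packaged model
        return cls(shift=shift, fast_kernels=model.fast_kernels)

    @toy_model_modifier(persistent=False)
    @classmethod
    def enable_FastKernels(cls, model: "ToyModel") -> "ToyModel":
        # runtime-only acceleration: not stored in the package
        return cls(shift=model.shift, fast_kernels=True)


def modifiers_to_package(names: List[str]) -> List[str]:
    """Keep only the modifiers whose effect belongs in the package."""
    return [name for name in names if MODIFIER_REGISTRY[name]]


print(modifiers_to_package(["modify_Shift", "enable_FastKernels"]))  # ['modify_Shift']
```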
Important: Model modifiers MUST preserve model state by transferring weights when there are trainable parameters.
Use replace_submodules() to help implement modifiers that replace specific module types.
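The weight-transfer requirement can be illustrated with plain PyTorch. The sketch below is a hypothetical stand-in for replace_submodules(), whose actual signature is not shown here: `replace_submodules_sketch` walks the module tree, swaps every instance of a target type for a module built by a factory, and copies the old module's `state_dict()` into the new one so trained parameters survive. `ScaledLinear` is an illustrative replacement class.

```python
from typing import Callable

import torch


class ScaledLinear(torch.nn.Linear):
    """Hypothetical replacement module: a Linear followed by a fixed scale."""

    def __init__(self, in_features: int, out_features: int, scale: float = 1.0):
        super().__init__(in_features, out_features)
        self.scale = scale  # plain attribute, not part of the state_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) * self.scale


def replace_submodules_sketch(
    module: torch.nn.Module,
    target_type: type,
    factory: Callable[[torch.nn.Module], torch.nn.Module],
) -> torch.nn.Module:
    """Recursively swap every `target_type` child for `factory(old)`, keeping weights."""
    for name, child in list(module.named_children()):
        if isinstance(child, target_type):
            new = factory(child)
            # transfer trainable state so the modified model keeps what it learned
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_submodules_sketch(child, target_type, factory)
    return module


torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.SiLU(), torch.nn.Sequential(torch.nn.Linear(8, 1))
)
x = torch.randn(5, 4)
before = model(x)
model = replace_submodules_sketch(
    model, torch.nn.Linear, lambda old: ScaledLinear(old.in_features, old.out_features)
)
after = model(x)  # matches `before`: weights were transferred and scale defaults to 1.0
```

Without the `load_state_dict` call the replacement would start from freshly initialized weights, which is exactly the loss of model state the rule above warns against.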
Usage:
Train-time via config:
```yaml
model:
  _target_: nequip.model.modify
  modifiers:
    - modifier: modify_PerTypeScaleShift
      shifts:
        C: 1.23
        H: 0.12
  model:
    _target_: nequip.model.ModelFromPackage
    # ...
```
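As a rough model of what the modifiers: list expresses, the plain-Python sketch below applies each entry in order, assuming `modifier` names a classmethod on the model's class and every other key becomes a keyword argument. `ToyScaleShiftModel` and `apply_modifier_specs` are hypothetical; NequIP's own nequip.model.modify may look modifiers up and validate arguments differently.

```python
import copy
from typing import Any, Dict, List


class ToyScaleShiftModel:
    """Hypothetical model exposing a modifier named like the one in the config."""

    def __init__(self, type_names: List[str]):
        self.type_names = type_names
        self.shifts = {t: 0.0 for t in type_names}

    @classmethod
    def modify_PerTypeScaleShift(cls, model, scales=None, shifts=None):
        new = copy.deepcopy(model)  # leave the input model untouched
        # scales handling is omitted in this sketch
        for type_name, value in (shifts or {}).items():
            if type_name not in new.shifts:
                raise KeyError(f"unknown type: {type_name}")
            new.shifts[type_name] = value
        return new


def apply_modifier_specs(model, specs: List[Dict[str, Any]]):
    """Apply entries in order: `modifier` names the method, other keys are kwargs."""
    for spec in specs:
        kwargs = dict(spec)  # copy so the caller's spec is not mutated
        name = kwargs.pop("modifier")
        model = getattr(type(model), name)(model, **kwargs)
    return model


# The `modifiers:` list from the YAML above, as Python data
specs = [{"modifier": "modify_PerTypeScaleShift", "shifts": {"C": 1.23, "H": 0.12}}]
base = ToyScaleShiftModel(["C", "H", "O"])
modified = apply_modifier_specs(base, specs)
print(modified.shifts)  # {'C': 1.23, 'H': 0.12, 'O': 0.0}
```

Types not listed under `shifts` keep their previous value in this sketch.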
Compile-time for accelerations:
```bash
nequip-compile model.ckpt compiled.pth --modifiers enable_OpenEquivariance
```
Implementation Tips
When implementing custom torch.nn.Module subclasses, override torch.nn.Module.extra_repr() to expose key model information for debugging (visible when _NEQUIP_LOG_LEVEL=DEBUG is set).
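For example, a hypothetical per-type shift module can report its configuration through `extra_repr()`; PyTorch places the returned string inside the parentheses of the module's `repr()`, so it appears wherever the model is printed. `PerTypeShift` and its `field` argument are illustrative only.

```python
import torch


class PerTypeShift(torch.nn.Module):
    """Hypothetical module: adds a learnable per-type shift to per-atom values."""

    def __init__(self, type_names, field: str = "atomic_energy"):
        super().__init__()
        self.type_names = list(type_names)
        self.field = field
        self.shifts = torch.nn.Parameter(torch.zeros(len(self.type_names)))

    def forward(self, values: torch.Tensor, types: torch.Tensor) -> torch.Tensor:
        # index the per-type shift for each atom and add it
        return values + self.shifts[types]

    def extra_repr(self) -> str:
        # rendered inside the parentheses of repr(module)
        return f"type_names={self.type_names}, field={self.field!r}"


print(PerTypeShift(["C", "H"]))
# PerTypeShift(type_names=['C', 'H'], field='atomic_energy')
```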