# Source code for nequip.model.utils

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from lightning.pytorch.utilities.seed import isolate_rng

from nequip.nn.graph_model import GraphModel
from nequip.nn.compile import CompileGraphModel
from nequip.utils import (
    dtype_from_name,
    torch_default_dtype,
    conditional_torchscript_mode,
)
from nequip.utils.global_state import (
    global_state_initialized,
    get_latest_global_state,
    TF32_KEY,
)

import functools
import contextvars
import contextlib

from typing import Optional, Final

# True while the outermost `@model_builder` call is running; nested builder
# calls read it to decide whether to act as plain pass-throughs.
_IS_BUILDING_MODEL = contextvars.ContextVar(
    "_IS_BUILDING_MODEL",
    default=False,
)
# `seed`, `model_dtype` and `type_names` of the running outer build (``None``
# when no build is active), so fresh nested builds can inherit them.
_CURRENT_MODEL_BUILDER_DEFAULTS = contextvars.ContextVar(
    "_CURRENT_MODEL_BUILDER_DEFAULTS", default=None
)

# the following is the set of model build types ("compile modes") for specific purposes:
# - `eager`: the built module is wrapped with the builder's plain `wrapper_class`
# - `compile`: train-time compilation, wrapped with the builder's `compile_wrapper_class`
_EAGER_MODEL_KEY: Final[str] = "eager"
_TRAIN_TIME_COMPILE_KEY: Final[str] = "compile"

# every value accepted for `compile_mode`
_COMPILE_MODE_OPTIONS = {
    _EAGER_MODEL_KEY,
    _TRAIN_TIME_COMPILE_KEY,
}


# True while an `override_model_compile_mode` context is active
_OVERRIDE_COMPILE_MODE = contextvars.ContextVar("_OVERRIDE_COMPILE_MODE", default=False)
# compile mode used when a builder call does not pass `compile_mode`
# (and the forced mode while an override is active)
_CURRENT_COMPILE_MODE = contextvars.ContextVar(
    "_CURRENT_COMPILE_MODE", default=_EAGER_MODEL_KEY
)


@contextlib.contextmanager
def override_model_compile_mode(compile_mode: Optional[str]):
    """
    Force ``compile_mode`` for every model built inside this context.

    ``compile_mode`` must be one of the supported compile modes; ``None`` is
    rejected by the assertion despite the ``Optional`` hint.
    When several of these contexts are nested, only the outermost one takes
    effect; inner ones are validated but otherwise ignored.
    The intended client is `ModelFromCheckpoint`. The behavior is designed for
    loading models from checkpoints and packages correctly, so other callers
    should use it with care.
    """
    assert compile_mode in _COMPILE_MODE_OPTIONS
    already_overridden = _OVERRIDE_COMPILE_MODE.get()
    if already_overridden:
        # an enclosing override wins; leave the global state untouched
        yield
        return

    saved_mode = _CURRENT_COMPILE_MODE.get()
    _OVERRIDE_COMPILE_MODE.set(True)
    _CURRENT_COMPILE_MODE.set(compile_mode)
    try:
        yield
    finally:
        # restore exactly what was active before entering, even on error
        _OVERRIDE_COMPILE_MODE.set(already_overridden)
        _CURRENT_COMPILE_MODE.set(saved_mode)


@contextlib.contextmanager
def fresh_model_builder_context():
    """Make nested model-builder calls behave like fresh top-level builds.

    Normally a `@model_builder` function called from inside another builder
    just returns its raw module. Inside this context it instead gets the full
    `@model_builder` treatment (dtype, seeding, wrapping). This is an explicit
    opt-in for composing models out of independently built sub-models.

    Any of `seed`, `model_dtype`, and `type_names` that the inner call does not
    pass explicitly are taken from the enclosing model-builder context.
    """
    # TODO: decide compile_mode semantics for fresh nested builds:
    # should they inherit outer builder compile_mode or use current default/override?
    was_building = _IS_BUILDING_MODEL.get()
    _IS_BUILDING_MODEL.set(False)
    try:
        yield
    finally:
        # put back whatever the flag was on entry
        _IS_BUILDING_MODEL.set(was_building)


def get_current_compile_mode(return_override: bool = False):
    """Return the compile mode currently in effect for model building.

    Args:
        return_override: if ``True``, also report whether the mode is being
            forced by an active ``override_model_compile_mode`` context.

    Returns:
        The current compile mode string when ``return_override`` is ``False``;
        otherwise the tuple ``(compile_mode, is_overridden)``. The mode comes
        first and the override flag second.
    """
    compile_mode = _CURRENT_COMPILE_MODE.get()
    if not return_override:
        return compile_mode
    return compile_mode, _OVERRIDE_COMPILE_MODE.get()


def model_builder(func=None, *, wrapper_class=None, compile_wrapper_class=None):
    """Decorator for model builder functions in the ``nequip`` ecosystem.

    Handles model building with proper seeding, floating point precision
    (``float32`` or ``float64``), and wraps the result with ``GraphModel``.
    Requires ``seed``, ``model_dtype``, and ``type_names`` arguments.
    Supports ``eager`` and ``compile`` modes via ``compile_mode``.
    The ``seed``, ``model_dtype``, and ``compile_mode`` arguments are consumed by
    the decorator and not passed to the decorated function.

    If a decorated builder is called while another decorated builder is already
    running (and no ``fresh_model_builder_context`` is active), the decorator is a
    pass-through. All arguments are forwarded unchanged, and the raw module
    returned by the decorated function is returned without wrapping.

    Can be used in two ways:

    - @model_builder (uses GraphModel wrapper, backward compatible)
    - @model_builder(wrapper_class=CustomGraphModel) (uses custom wrapper)

    Args:
        func: The function to decorate (when used without parentheses)
        wrapper_class: Custom GraphModel subclass to use for wrapping (default: GraphModel)
        compile_wrapper_class: Custom wrapper for compile mode (default: CompileGraphModel)
    """
    # default wrapper classes
    if wrapper_class is None:
        wrapper_class = GraphModel
    if compile_wrapper_class is None:
        compile_wrapper_class = CompileGraphModel

    mandatory_keys = ("seed", "model_dtype", "type_names")

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # this means we're in an inner model, so we shouldn't apply the model
            # builder operations, and just pass the function
            if _IS_BUILDING_MODEL.get():
                return f(*args, **kwargs)

            # this means we're in the outer model, and have to apply the model builder operations
            _IS_BUILDING_MODEL.set(True)
            prev_builder_defaults = _CURRENT_MODEL_BUILDER_DEFAULTS.get()
            try:
                # fill in mandatory args the caller omitted from an enclosing build
                # (only non-None while an outer build is running, e.g. when called
                # inside `fresh_model_builder_context`)
                if prev_builder_defaults is not None:
                    for key in mandatory_keys:
                        if key not in kwargs and key in prev_builder_defaults:
                            kwargs[key] = prev_builder_defaults[key]
                # snapshot of the full build config, taken before `seed`,
                # `model_dtype` and `compile_mode` are popped below
                model_cfg = kwargs.copy()

                # === sanity checks ===
                assert global_state_initialized(), (
                    "global state must be initialized before building models"
                )
                missing = [key for key in mandatory_keys if key not in kwargs]
                assert not missing, (
                    "`seed`, `model_dtype`, and `type_names` are mandatory model arguments."
                    f" Missing: {missing}"
                )
                if get_latest_global_state().get(TF32_KEY, False):
                    assert kwargs["model_dtype"] == "float32", (
                        "`allow_tf32=True` only works with `model_dtype=float32`"
                    )

                # seed and model_dtype are removed from kwargs, so they will NOT get passed to inner models
                seed = kwargs.pop("seed")
                model_dtype = kwargs.pop("model_dtype")
                dtype = dtype_from_name(model_dtype)
                # expose this build's mandatory args to fresh nested builds
                _CURRENT_MODEL_BUILDER_DEFAULTS.set(
                    {
                        "seed": seed,
                        "model_dtype": model_dtype,
                        "type_names": kwargs["type_names"],
                    }
                )

                # === compilation options ===
                # `compile_mode` dictates the optimization path chosen
                # users can set this with the `compile_mode` arg to the model builder
                # devs can override it with `override_model_compile_mode`
                # always pop because inner models won't need `compile_mode` arg
                compile_mode = kwargs.pop("compile_mode", _CURRENT_COMPILE_MODE.get())
                # compile mode overriding logic
                if _OVERRIDE_COMPILE_MODE.get():
                    compile_mode = _CURRENT_COMPILE_MODE.get()
                assert compile_mode in _COMPILE_MODE_OPTIONS, (
                    f"`compile_mode` can only be any of {_COMPILE_MODE_OPTIONS}, but `{compile_mode}` found"
                )

                # use custom wrapper class or default
                if compile_mode == _TRAIN_TIME_COMPILE_KEY:
                    # === torch version check ===
                    from nequip.utils.versions import check_pt2_compile_compatibility

                    check_pt2_compile_compatibility()
                    graph_model_module = compile_wrapper_class
                else:
                    graph_model_module = wrapper_class

                # never script
                with conditional_torchscript_mode(False):
                    # set dtype and seed
                    with torch_default_dtype(dtype):
                        with isolate_rng():
                            torch.manual_seed(seed)
                            model = f(*args, **kwargs)
                        # wrap with GraphModel
                        # NOTE(review): nesting reconstructed from a flattened listing --
                        # confirm the wrapper is meant to be built under the dtype
                        # context but outside `isolate_rng`
                        graph_model = graph_model_module(
                            model=model,
                            model_config=model_cfg,
                            model_input_fields=model.irreps_in,
                        )
                return graph_model
            finally:
                _CURRENT_MODEL_BUILDER_DEFAULTS.set(prev_builder_defaults)
                # reset to default in case of failure (the flag was False on entry,
                # otherwise we would have returned early above)
                _IS_BUILDING_MODEL.set(False)

        return wrapper

    # handle both @model_builder and @model_builder(...)
    if func is None:
        # called with arguments: @model_builder(wrapper_class=X)
        return decorator
    else:
        # called without arguments: @model_builder
        return decorator(func)