# Source code for nequip.model.saved_models.package

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
"""
Functions for loading models from package files.
"""

import contextlib
import io
import warnings
from typing import Any, Dict, List

import torch
import yaml

from nequip.data import AtomicDataDict
from nequip.model.utils import (
    get_current_compile_mode,
    _EAGER_MODEL_KEY,
)
from nequip.scripts._workflow_utils import get_workflow_state
from nequip.utils.logger import RankedLogger

from ._utils import _check_compile_mode, _check_file_exists
from nequip.utils.asserts import assert_package_extension

# === setup logging ===
# module-level logger used by the loaders below (e.g. to announce which package file is being loaded);
# NOTE(review): `rank_zero_only=True` presumably restricts output to the rank-0 process in
# distributed runs -- confirm against `RankedLogger`
logger = RankedLogger(__name__, rank_zero_only=True)


@contextlib.contextmanager
def _cpu_deserialize_if_no_cuda():
    """Force CUDA-saved storages inside packaged models to load on CPU when CUDA is unavailable."""
    if torch.cuda.is_available():
        yield
        return

    orig = torch.storage._load_from_bytes

    def _load_from_bytes_cpu(b):
        return torch.load(io.BytesIO(b), map_location="cpu", weights_only=False)

    torch.storage._load_from_bytes = _load_from_bytes_cpu
    try:
        yield
    finally:
        torch.storage._load_from_bytes = orig


# === package importer utilities ===
# most of the complexity for `ModelFromPackage` is due to the need to keep track of the `Importer` if we ever repackage
# see `nequip/scripts/package.py` to get the full picture of how they interact
# we expect the following variable to only be used during `nequip-package`

# `torch.package.PackageImporter` reused by every `ModelFromPackage` call made while the workflow state
# is "package"; it stays `None` until the first such call creates it from that call's `package_path`
_PACKAGE_TIME_SHARED_IMPORTER = None


def _get_shared_importer():
    """Return the importer shared by ``ModelFromPackage`` calls during ``nequip-package``.

    Returns:
        the module-level ``_PACKAGE_TIME_SHARED_IMPORTER``: ``None`` until ``ModelFromPackage``
        has loaded a package while the workflow state is ``"package"``, and the importer
        it created afterwards
    """
    # read-only access to a module global needs no `global` declaration
    return _PACKAGE_TIME_SHARED_IMPORTER


def _get_package_metadata(imp) -> Dict[str, Any]:
    """Load packaged model metadata from an existing PackageImporter.

    Args:
        imp: importer opened on a NequIP package file; the metadata is read with
            ``imp.load_text(package="model", resource="package_metadata.txt")``

    Returns:
        the YAML-parsed contents of ``model/package_metadata.txt``

    Raises:
        ValueError: if the metadata is not a mapping, or if its ``package_version_id``
            is missing or not positive
    """
    pkg_metadata = yaml.safe_load(
        imp.load_text(package="model", resource="package_metadata.txt")
    )
    # extra sanity check since saving metadata in txt files was implemented in packaging version 1;
    # an explicit raise (rather than `assert`) keeps the check active under `python -O`
    version_id = (
        pkg_metadata.get("package_version_id")
        if isinstance(pkg_metadata, dict)
        else None
    )
    if version_id is None or int(version_id) <= 0:
        raise ValueError(
            "Invalid package metadata in `model/package_metadata.txt`: expected a positive "
            f"`package_version_id`, found {version_id!r}."
        )
    return pkg_metadata


# === warning management ===


@contextlib.contextmanager
def _suppress_package_importer_exporter_warnings():
    # Ideally this ceases to exist or becomes a no-op in future versions of PyTorch
    with warnings.catch_warnings():
        # suppress torch.package TypedStorage warning
        warnings.filterwarnings(
            "ignore",
            message="TypedStorage is deprecated.*",
            category=UserWarning,
            module=r"torch\.package\.(package_exporter|package_importer)",
        )
        yield


# === loading models from package files ===


def ModelFromPackage(package_path: str, compile_mode: str = _EAGER_MODEL_KEY):
    """Builds model from a NequIP framework packaged zip file constructed with ``nequip-package``.

    This function can be used in the config file as follows.

    .. code-block:: yaml

      model:
        _target_: nequip.model.ModelFromPackage
        package_path: path/to/pkg
        compile_mode: eager/compile

    .. warning::
        DO NOT CHANGE the directory structure or location of the package file if this model loader is used for training. Any process that loads a checkpoint produced from training runs originating from a package file will look for the original package file at the location specified during training. It is also recommended to use full paths (instead of relative paths) to avoid potential errors.

    Args:
        package_path (str): path to NequIP framework packaged model with the ``.nequip.zip`` extension (an error will be thrown if the file has a different extension)
        compile_mode (str): ``eager`` or ``compile`` allowed for training

    Returns:
        the model unpickled from ``model/{compile_mode}_model.pkl`` inside the package; if the
        requested ``compile_mode`` is not listed in the package's ``available_models``, a warning
        is emitted and the eager model is returned instead
    """
    # declared up front: this function may create and store the shared importer during `nequip-package`
    global _PACKAGE_TIME_SHARED_IMPORTER

    # === sanity checks ===
    _check_file_exists(file_path=package_path, file_type="package")
    assert_package_extension(package_path)

    # === account for checkpoint loading ===
    # if `ModelFromPackage` is used by itself, `override=False` and the input `compile_mode` argument is used
    # if this function is called at the end of checkpoint loading via `ModelFromCheckpoint`, `override=True` and the overridden `compile_mode` takes precedence
    cm, override = get_current_compile_mode(return_override=True)
    compile_mode = cm if override else compile_mode

    # === sanity check compile modes ===
    workflow_state = get_workflow_state()
    _check_compile_mode(compile_mode, "ModelFromPackage")

    # === load model ===
    logger.info(f"Loading model from package file: {package_path} ...")
    with _suppress_package_importer_exporter_warnings():
        # during `nequip-package`, we need to use the same importer for all the models for successful repackaging
        # see https://pytorch.org/docs/stable/package.html#re-export-an-imported-object
        if workflow_state == "package":
            imp = _PACKAGE_TIME_SHARED_IMPORTER
            if imp is None:
                # we load the importer from `package_path` for the first time
                imp = torch.package.PackageImporter(package_path)
                _PACKAGE_TIME_SHARED_IMPORTER = imp
            # if it's not `None`, it means we've previously loaded a model during `nequip-package` and should keep using the same importer
        else:
            # if not doing `nequip-package`, we just load a new importer every time `ModelFromPackage` is called
            imp = torch.package.PackageImporter(package_path)

        # do sanity checking with available models
        pkg_metadata = _get_package_metadata(imp)
        available_models = pkg_metadata["available_models"]

        # throw warning if desired `compile_mode` is not available, and default to eager
        if compile_mode not in available_models:
            warnings.warn(
                f"Requested `{compile_mode}` model is not present in the package file ({package_path}). `nequip-{workflow_state}` task will default to using the `{_EAGER_MODEL_KEY}` model."
            )
            compile_mode = _EAGER_MODEL_KEY

        with _cpu_deserialize_if_no_cuda():
            model = imp.load_pickle(
                package="model",
                resource=f"{compile_mode}_model.pkl",
                map_location="cpu",
            )

    # NOTE: model returned is not a GraphModel object tied to the `nequip` in current Python env, but a GraphModel object from the packaged zip file
    return model
def data_dict_from_package(package_path: str) -> AtomicDataDict.Type:
    """Load example data from a .nequip.zip package file.

    Args:
        package_path (str): path to NequIP framework packaged model file

    Returns:
        the object unpickled from ``model/example_data.pkl`` inside the package
    """
    # same early existence check as the other package loaders in this module
    _check_file_exists(file_path=package_path, file_type="package")
    with _suppress_package_importer_exporter_warnings():
        imp = torch.package.PackageImporter(package_path)
        with _cpu_deserialize_if_no_cuda():
            data = imp.load_pickle(package="model", resource="example_data.pkl")
    return data


def ModelTypeNamesFromPackage(package_path: str) -> List[str]:
    """Extract model type names from a packaged model file.

    Useful for setting up type mappers when fine-tuning models or when you need to know
    what atom types a model was trained on.

    Args:
        package_path (str): path to packaged model file

    Returns:
        list of type names, ordered by type index
    """
    _check_file_exists(file_path=package_path, file_type="package")
    with _suppress_package_importer_exporter_warnings():
        imp = torch.package.PackageImporter(package_path)
        pkg_metadata = _get_package_metadata(imp)
    atom_types_dict = pkg_metadata["atom_types"]
    # convert dict {idx: name} to list [name, ...]; assumes indices run 0..len-1
    type_names: List[str] = [atom_types_dict[i] for i in range(len(atom_types_dict))]
    return type_names