Source code for nequip.model.saved_models.checkpoint

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
"""
Functions for loading models from checkpoint files.
"""

import torch
import hydra
import warnings

from nequip.model.utils import (
    override_model_compile_mode,
    _EAGER_MODEL_KEY,
)
from nequip.data import AtomicDataDict
from nequip.data.transforms import NonPeriodicCellTransform


from nequip.utils.global_dtype import _GLOBAL_DTYPE
from nequip.utils.logger import RankedLogger

from ._utils import _check_compile_mode, _check_file_exists

# === setup logging ===
logger = RankedLogger(__name__, rank_zero_only=True)


[docs] def ModelFromCheckpoint(checkpoint_path: str, compile_mode: str = _EAGER_MODEL_KEY): """Builds model from a NequIP framework checkpoint file. This function can be used in the config file as follows. .. code-block:: yaml model: _target_: nequip.model.ModelFromCheckpoint checkpoint_path: path/to/ckpt compile_mode: eager/compile .. warning:: DO NOT CHANGE the directory structure or location of the checkpoint file if this model loader is used for training. Any process that loads a checkpoint produced from training runs originating from a package file will look for the original package file at the location specified during training. It is also recommended to use full paths (instead or relative paths) to avoid potential errors. Args: checkpoint_path (str): path to a ``nequip`` framework checkpoint file compile_mode (str): ``eager`` or ``compile`` allowed for training """ # === sanity checks === _check_file_exists(file_path=checkpoint_path, file_type="checkpoint") _check_compile_mode(compile_mode, "ModelFromCheckpoint") logger.info(f"Loading model from checkpoint file: {checkpoint_path} ...") # === load checkpoint and extract info === checkpoint = torch.load( checkpoint_path, map_location="cpu", weights_only=False, ) # === versions === ckpt_versions = checkpoint["hyper_parameters"]["info_dict"]["versions"] from nequip.utils import get_current_code_versions session_versions = get_current_code_versions(verbose=False) for code, session_version in session_versions.items(): if code in ckpt_versions: ckpt_version = ckpt_versions[code] # sanity check that versions for current build matches versions from ckpt if ckpt_version != session_version: warnings.warn( f"`{code}` versions differ between the checkpoint file ({ckpt_version}) and the current run ({session_version}) -- `ModelFromCheckpoint` will be built with the current run's versions, but please check that this decision is as intended." ) # === load model via lightning module === training_module = hydra.utils.get_class( checkpoint["hyper_parameters"]["info_dict"]["training_module"]["_target_"] ) # ensure that model is built with correct `compile_mode` with override_model_compile_mode(compile_mode): lightning_module = training_module.load_from_checkpoint(checkpoint_path) model = lightning_module.evaluation_model return model
def data_dict_from_checkpoint(ckpt_path: str) -> AtomicDataDict.Type: from nequip.utils.dtype import torch_default_dtype with torch_default_dtype(_GLOBAL_DTYPE): # === get data from checkpoint === checkpoint = torch.load( ckpt_path, map_location="cpu", weights_only=False, ) data_config = checkpoint["hyper_parameters"]["info_dict"]["data"].copy() if "train_dataloader" not in data_config: data_config["train_dataloader"] = { "_target_": "torch.utils.data.DataLoader" } data_config["train_dataloader"]["batch_size"] = 1 datamodule = hydra.utils.instantiate(data_config, _recursive_=False) # TODO: better way of doing this? # instantiate the datamodule, dataset, and get train dataloader try: datamodule.prepare_data() # instantiate train dataset datamodule.setup(stage="fit") dloader = datamodule.train_dataloader() for data in dloader: if AtomicDataDict.num_nodes(data) > 3: break finally: datamodule.teardown(stage="fit") # === sanitize data === if AtomicDataDict.CELL_KEY not in data: # try to construct sensible cell for nonperiodic system transform = NonPeriodicCellTransform(padding=10.0, override_cell=False) data = transform(data) # if still no cell (transform was no-op), create a large cell if AtomicDataDict.CELL_KEY not in data: data[AtomicDataDict.CELL_KEY] = 1e5 * torch.eye( 3, dtype=_GLOBAL_DTYPE, device=data[AtomicDataDict.POSITIONS_KEY].device, ).unsqueeze(0) data[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = torch.zeros( (AtomicDataDict.num_edges(data), 3), dtype=_GLOBAL_DTYPE, device=data[AtomicDataDict.POSITIONS_KEY].device, ) return data