Source code for nequip.model.saved_models.load_utils
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import contextlib
import pathlib
import requests
from tqdm.auto import tqdm
from nequip.model.utils import _EAGER_MODEL_KEY
from nequip.model.saved_models import ModelFromPackage, ModelFromCheckpoint
from nequip.model.modify_utils import only_apply_persistent_modifiers
from nequip.train.lightning import _SOLE_MODEL_KEY
from nequip.utils import model_repository
from nequip.utils.logger import RankedLogger
from nequip.utils.model_cache import get_cached_model, cache_model
# Module-level logger used by the loading helpers below.
# NOTE(review): ``rank_zero_only=True`` presumably restricts output to rank 0 in
# distributed runs so messages are not duplicated per process — confirm against RankedLogger.
logger = RankedLogger(__name__, rank_zero_only=True)
@contextlib.contextmanager
def _get_model_file_path(input_path):
    """Context manager that yields a local file path for a saved model.

    Local files are yielded directly as a :class:`pathlib.Path`. Existence is
    not checked here; callers are expected to do that.

    For nequip.net model IDs, the download URL is fetched from the nequip.net
    API. If a cached copy exists for that model ID and URL, its path is
    yielded. Otherwise the artifact is downloaded, cached, and the cached path
    is yielded. The default cache location is ``~/.nequip/model_cache``,
    configurable via ``NEQUIP_CACHE_DIR``.

    Args:
        input_path: path to the model checkpoint or package file, or nequip.net
            model ID (format: ``nequip.net:group-name/model-name:version``)

    Yields:
        pathlib.Path: Path to the model file (either original or cached)

    Raises:
        requests.HTTPError: if downloading the artifact returns an HTTP error status.
    """
    nequip_net_prefix = "nequip.net:"
    input_str = str(input_path)

    # local file: hand the path straight back
    if not input_str.startswith(nequip_net_prefix):
        logger.info(f"Loading model from {input_path} ...")
        yield pathlib.Path(input_path)
        return

    # nequip.net download: strip the prefix to get the model ID
    model_id = input_str[len(nequip_net_prefix) :]
    logger.info(f"Fetching {model_id} from nequip.net...")

    # query the API for the download URL (and whether a newer version exists)
    with model_repository.NequIPNetAPIClient() as client:
        model_info = client.get_model_download_info(model_id)
    if model_info.newer_version_id is not None:
        logger.info(
            f"Model {model_id} has a newer version available: {model_info.newer_version_id}"
        )
    download_url = model_info.artifact.download_url

    def download_fn(target_path: pathlib.Path):
        """Stream the artifact at ``download_url`` into ``target_path``."""
        # using the response as a context manager releases the streamed
        # connection even if raise_for_status() or a write fails
        with requests.get(download_url, stream=True) as response:
            response.raise_for_status()
            # a missing or zero content-length means the size is unknown; with
            # ``None`` tqdm shows only byte count and rate instead of a 0-total bar
            total_size = int(response.headers.get("content-length", 0)) or None
            with open(target_path, "wb") as f, tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                desc=f"Downloading from {model_info.artifact.host_name}",
            ) as pbar:
                for chunk in response.iter_content(chunk_size=65536):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
                        pbar.update(len(chunk))

    # check cache first; only download on a cache miss
    cached_path = get_cached_model(model_id, download_url)
    if cached_path is None:
        # download and cache (cache_model will skip caching if NEQUIP_NO_CACHE is set)
        cached_path = cache_model(model_id, download_url, download_fn)
        logger.info("Download complete, loading model...")
    yield cached_path
[docs]
def load_saved_model(
    input_path,
    compile_mode: str = _EAGER_MODEL_KEY,
    model_key: str = _SOLE_MODEL_KEY,
    return_data_dict: bool = False,
):
    """Load a saved model from a checkpoint, a package, or nequip.net.

    Supported sources:

    - **Checkpoint files** (``.ckpt``): saved during training runs
    - **Package files** (``.nequip.zip``): created with ``nequip-package``
    - **nequip.net models**: model IDs of the form ``nequip.net:group-name/model-name:version`` from `nequip.net <https://www.nequip.net/>`__

    Any resolved file whose name does not end in ``.nequip.zip`` is treated as a
    checkpoint.

    Args:
        input_path: path to the model checkpoint or package file, or nequip.net model ID
            (format: ``nequip.net:group-name/model-name:version``)
        compile_mode (str): compile mode for the model (default: ``"eager"``)
        model_key (str): key used to pick one model out of the loaded ``ModuleDict``
            (default: ``"sole_model"``); pass ``None`` to get the whole ``ModuleDict``
        return_data_dict (bool): if ``True``, also return the data dict for compilation (default: ``False``)

    Returns:
        torch.nn.Module or tuple: the loaded model, or ``(model, data)`` tuple if ``return_data_dict=True``

    Raises:
        ValueError: if the resolved model file does not exist.
    """
    with _get_model_file_path(input_path) as resolved_path:
        if not resolved_path.exists():
            raise ValueError(
                f"Model file does not exist: {input_path} (resolved to: {resolved_path})"
            )

        is_packaged = str(resolved_path).endswith(".nequip.zip")

        if is_packaged:
            # packages never carry non-persistent modifiers, so load as-is
            model = ModelFromPackage(resolved_path, compile_mode=compile_mode)
        else:
            # checkpoints are rebuilt with persistent modifiers only; acceleration
            # (non-persistent) modifiers must be requested at compile time instead
            with only_apply_persistent_modifiers(persistent_only=True):
                model = ModelFromCheckpoint(resolved_path, compile_mode=compile_mode)

        # the loaded object is a `ModuleDict` of `GraphModel`s; narrow to one
        # entry unless the caller asked for the whole dict (model_key=None)
        if model_key is not None:
            model = model[model_key]

        if not return_data_dict:
            return model

        from nequip.model.saved_models.checkpoint import data_dict_from_checkpoint
        from nequip.model.saved_models.package import data_dict_from_package

        read_data_dict = (
            data_dict_from_package if is_packaged else data_dict_from_checkpoint
        )
        return model, read_data_dict(str(resolved_path))