# Source code for nequip.model.inference_models.compiled
# This file is a part of the `nequip` package. Please see LICENSE and README
# at the root for information on using it.
import torch
from pathlib import Path
from typing import Union, Tuple, List, Optional
from .torchscript import load_torchscript_model
from .aotinductor import load_aotinductor_model
from nequip.utils.global_state import TF32_KEY, set_global_state
[docs]
def load_compiled_model(
    compile_path: Union[str, Path],
    device: Union[str, torch.device],
    input_keys: Optional[List[str]] = None,
    output_keys: Optional[List[str]] = None,
) -> Tuple[torch.nn.Module, dict]:
    """Load a compiled model from either TorchScript or AOTInductor format.

    This function can load compiled models created with ``nequip-compile``:

    - **TorchScript models** (``.nequip.pth``): legacy compiled format
    - **AOT Inductor models** (``.nequip.pt2``): modern compiled format with better performance

    The format is selected from the file-name suffix. After loading, the TF32
    global state is set from the model metadata, and the model is moved to
    ``device`` and switched to eval mode.

    Args:
        compile_path: path to compiled model file (``.nequip.pth`` or ``.nequip.pt2``),
            given as a ``str`` or ``pathlib.Path``
        device: the device to use
        input_keys: optional input field names for AOTInductor models (for ``.nequip.pt2``);
            not forwarded to the TorchScript loader
        output_keys: optional output field names for AOTInductor models (for ``.nequip.pt2``);
            not forwarded to the TorchScript loader

    Returns:
        tuple: ``(model, metadata)`` with model prepared for inference

    Raises:
        ValueError: if the file name ends in neither ``.nequip.pth`` nor ``.nequip.pt2``
    """
    # normalize so the format-specific loaders always receive a plain string,
    # whether the caller passed a ``str`` or a ``pathlib.Path``
    compile_path = str(compile_path)
    compile_fname = Path(compile_path).name

    if compile_fname.endswith(".nequip.pth"):
        model, metadata = load_torchscript_model(compile_path, device)
    elif compile_fname.endswith(".nequip.pt2"):
        model, metadata = load_aotinductor_model(
            compile_path, device, input_keys, output_keys
        )
    else:
        raise ValueError(
            f"Unknown file type: {compile_fname} "
            "(expected `*.nequip.pth` or `*.nequip.pt2`)"
        )

    # set global state from metadata; the int() -> bool() conversion suggests
    # the flag may be stored as a numeric string such as "0"/"1" (presumably —
    # confirm against the loaders' metadata format)
    set_global_state(
        **{
            TF32_KEY: bool(int(metadata[TF32_KEY])),
        }
    )

    # prepare model for inference
    model = model.to(device)
    model.eval()
    return model, metadata