Saved Models
nequip-train can be used to train, validate, and test models loaded from saved model files.
Loading Saved Models
Two main forms of saved model can be loaded for training, validation, and/or testing with nequip-train or with custom Python scripts: checkpoint files, which are saved during nequip-train training runs, and package files, which are constructed with nequip-package and have the .nequip.zip extension.
These files can be loaded using the following ModelFromCheckpoint and ModelFromPackage model loaders.
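As a rough illustration of how the two file types map onto the two loaders, the sketch below builds a plain dictionary mirroring the model: config section used by the loaders documented in this section. loader_config is a hypothetical helper, not part of NequIP, and it assumes that any file not ending in .nequip.zip is a checkpoint.

```python
from pathlib import Path


def loader_config(saved_model_path: str, compile_mode: str = "eager") -> dict:
    """Return a dict mirroring the `model:` config section for a saved model.

    Hypothetical helper: files ending in ".nequip.zip" are treated as packages
    made by nequip-package; every other file is assumed to be a checkpoint.
    """
    # Absolute path, per the recommendations in this section.
    path = Path(saved_model_path).resolve()
    if path.name.endswith(".nequip.zip"):
        return {
            "_target_": "nequip.model.ModelFromPackage",
            "package_path": str(path),
            "compile_mode": compile_mode,
        }
    return {
        "_target_": "nequip.model.ModelFromCheckpoint",
        "checkpoint_path": str(path),
        "compile_mode": compile_mode,
    }


print(loader_config("path/to/pkg.nequip.zip")["_target_"])
# nequip.model.ModelFromPackage
```

Dispatching on the file name keeps the sketch simple; a real workflow would more likely just pick the loader explicitly in the config file.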
Important
The following model loaders store the paths you provide to them in the checkpoint file created for the new training run. Loading that new checkpoint file, for example to restart training or to package a model, therefore also requires loading the file at the original path you passed to ModelFromCheckpoint or ModelFromPackage.
Moving or modifying those original checkpoint or package files will cause the new checkpoint to fail to load.
As a result, the following is recommended when using ModelFromCheckpoint or ModelFromPackage.
- Use absolute paths instead of relative paths.
- Do not change the directory structure or move your files when using the model loaders.
- Ideally, store the original checkpoint/package files somewhere that makes their association with the new training run clear to you.
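The first two recommendations can be checked before launching a run. The sketch below is a hypothetical pre-flight helper (check_loader_path is not part of NequIP) that flags relative paths and missing files for a path you intend to pass as checkpoint_path or package_path.

```python
from pathlib import Path
from typing import List


def check_loader_path(path_str: str) -> List[str]:
    """Flag paths that would make a ModelFromCheckpoint/ModelFromPackage run fragile.

    Hypothetical pre-flight helper, not part of NequIP.
    """
    problems: List[str] = []
    path = Path(path_str)
    if not path.is_absolute():
        # A relative path depends on the working directory, so a later restart
        # launched from a different directory may look in the wrong place.
        problems.append(f"relative path {path_str!r}; consider {str(path.resolve())!r}")
    if not path.is_file():
        problems.append(f"no file at {path_str!r}")
    return problems


for problem in check_loader_path("path/to/ckpt"):
    print(problem)
```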
Be aware that repeatedly using ModelFromCheckpoint on checkpoints that were themselves produced by runs using ModelFromCheckpoint creates a chain of checkpoints: loading the checkpoint at the end of the chain requires successfully loading every intermediate checkpoint file in the chain. If necessary, break the chain by converting a checkpoint file into a packaged model with nequip-package and then loading it with ModelFromPackage.
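The chaining behavior can be modeled as a walk over "started from" links. In the sketch below, files_needed_to_load and all file names are hypothetical; following the text above, a packaged model is treated as self-contained, so it ends the chain.

```python
from typing import Dict, List, Optional


def files_needed_to_load(start: str, started_from: Dict[str, Optional[str]]) -> List[str]:
    """Return every file that must be loadable to load `start`, newest first.

    `started_from` maps each file to the file its training run was initialized
    from via ModelFromCheckpoint or ModelFromPackage (None if trained from scratch).
    A packaged model is modeled as self-contained, so it maps to None.
    """
    needed: List[str] = []
    current: Optional[str] = start
    while current is not None and current not in needed:  # guard against cycles
        needed.append(current)
        current = started_from.get(current)
    return needed


chained = {
    "run3/last.ckpt": "run2/last.ckpt",
    "run2/last.ckpt": "run1/last.ckpt",
    "run1/last.ckpt": None,
}
print(files_needed_to_load("run3/last.ckpt", chained))
# ['run3/last.ckpt', 'run2/last.ckpt', 'run1/last.ckpt']

# Packaging run2 and starting run3 from the package with ModelFromPackage:
packaged = {"run3/last.ckpt": "run2.nequip.zip", "run2.nequip.zip": None}
print(files_needed_to_load("run3/last.ckpt", packaged))
# ['run3/last.ckpt', 'run2.nequip.zip']
```

Every file in the returned list must stay at its recorded path, which is why shortening the chain with nequip-package reduces the number of files you have to keep in place.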
- nequip.model.ModelFromCheckpoint(checkpoint_path: str, compile_mode: str = 'eager')
Builds model from a NequIP framework checkpoint file.
This function can be used in the config file as follows.
model:
  _target_: nequip.model.ModelFromCheckpoint
  checkpoint_path: path/to/ckpt
  compile_mode: eager/compile
Warning
DO NOT CHANGE the directory structure or location of the checkpoint file if this model loader is used for training. Any process that loads a checkpoint produced by a training run originating from a checkpoint file will look for the original checkpoint file at the location specified during training. It is also recommended to use full paths (instead of relative paths) to avoid potential errors.
- nequip.model.ModelFromPackage(package_path: str, compile_mode: str = 'eager')
Builds model from a NequIP framework packaged zip file constructed with nequip-package.
This function can be used in the config file as follows.
model:
  _target_: nequip.model.ModelFromPackage
  package_path: path/to/pkg
  compile_mode: eager/compile
Warning
DO NOT CHANGE the directory structure or location of the package file if this model loader is used for training. Any process that loads a checkpoint produced by a training run originating from a package file will look for the original package file at the location specified during training. It is also recommended to use full paths (instead of relative paths) to avoid potential errors.
Modifying Saved Models
- nequip.model.modify(model: Dict[str, Module] | Module, modifiers: List[Dict[str, Any]] | Dict[str, List[Dict[str, Any]]]) → Dict[str, Module] | Module
Applies a sequence of model modifier functions to a model.
The modifiers will be applied in the specified order. Whether the order of modifiers matters depends on the specific modifiers used.
- Parameters:
model (Union[Dict[str, torch.nn.Module], torch.nn.Module]) – The model(s) to modify.
modifiers (Union[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]) – A list of modifier configurations (if model is a single model) or a dictionary mapping model names to lists of modifier configurations (if model is a dictionary). Each modifier configuration is a dictionary. The dictionary must contain a key “modifier” that specifies the name of the modifier function to apply as a string. All other keys in the dictionary are passed as keyword arguments to the modifier function.
- Returns:
The modified model(s).
- Return type:
Union[Dict[str, torch.nn.Module], torch.nn.Module]
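The dispatch rules described above (modifiers applied in order, the “modifier” key naming the function, remaining keys passed as keyword arguments, and a dictionary of models paired with a dictionary of modifier lists) can be sketched in plain Python. This is a hypothetical re-implementation for illustration, not NequIP's code: the toy "models" are dictionaries standing in for torch.nn.Module objects, set_cutoff and scale_cutoff are made-up modifiers, and leaving models without an entry unchanged is an assumption. The two made-up modifiers do not commute, which shows how the order can matter.

```python
from typing import Any, Callable, Dict, List, Union

ToyModel = Dict[str, float]  # stand-in for torch.nn.Module in this sketch

# Hypothetical registry mapping modifier names to functions.
MODIFIERS: Dict[str, Callable[..., ToyModel]] = {}


def register(fn: Callable[..., ToyModel]) -> Callable[..., ToyModel]:
    MODIFIERS[fn.__name__] = fn
    return fn


@register
def set_cutoff(model: ToyModel, r_max: float) -> ToyModel:
    # Made-up modifier: overwrite the cutoff radius.
    return {**model, "r_max": r_max}


@register
def scale_cutoff(model: ToyModel, factor: float) -> ToyModel:
    # Made-up modifier: multiply the current cutoff radius.
    return {**model, "r_max": model["r_max"] * factor}


def apply_modifiers(model: ToyModel, modifiers: List[Dict[str, Any]]) -> ToyModel:
    for config in modifiers:  # applied strictly in list order
        kwargs = dict(config)  # copy so the caller's config is not mutated
        name = kwargs.pop("modifier")  # required key: modifier name as a string
        model = MODIFIERS[name](model, **kwargs)  # all other keys become kwargs
    return model


def modify_sketch(
    model: Union[Dict[str, ToyModel], ToyModel],
    modifiers: Union[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]],
) -> Union[Dict[str, ToyModel], ToyModel]:
    if isinstance(modifiers, dict):
        # Dictionary of models: each model gets its own modifier list;
        # models without an entry are returned unchanged (an assumption).
        return {name: apply_modifiers(m, modifiers.get(name, [])) for name, m in model.items()}
    return apply_modifiers(model, modifiers)


base = {"r_max": 5.0}
set_then_scale = [
    {"modifier": "set_cutoff", "r_max": 4.0},
    {"modifier": "scale_cutoff", "factor": 0.5},
]
print(modify_sketch(base, set_then_scale))  # {'r_max': 2.0}
print(modify_sketch(base, set_then_scale[::-1]))  # {'r_max': 4.0}
```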