# Source code for nequip.scripts._compile_utils
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from nequip.data import AtomicDataDict
from typing import Dict, List, Callable, Union
# === Inputs and Outputs for AOT Compile ===
# standard sets of input and output fields for specific integrations

# Names of the built-in AOT compile targets; these are the keys under which
# the target specs are stored in ``COMPILE_TARGET_DICT``.
AOTI_PAIR_NEQUIP_TARGET = "pair_nequip"  # LAMMPS pair style integration
AOTI_ASE_TARGET = "ase"  # ASE integration
AOTI_BATCH_TARGET = "batch"  # multi-frame (batched) inference
# Input fields of the compiled model for the single-frame targets
# (pair_nequip and ase).
PAIR_NEQUIP_INPUTS = [
    AtomicDataDict.POSITIONS_KEY,
    AtomicDataDict.EDGE_INDEX_KEY,
    AtomicDataDict.ATOM_TYPE_KEY,
    AtomicDataDict.CELL_KEY,
    AtomicDataDict.EDGE_CELL_SHIFT_KEY,
]

# The batch target takes the same fields plus the batching bookkeeping keys,
# which the single-frame targets strip out (see ``single_frame_data_settings``).
BATCH_INPUTS = [
    *PAIR_NEQUIP_INPUTS,
    AtomicDataDict.BATCH_KEY,
    AtomicDataDict.NUM_NODES_KEY,
]
# Output fields requested for the LAMMPS pair style target.
LMP_OUTPUTS = [
    AtomicDataDict.PER_ATOM_ENERGY_KEY,
    AtomicDataDict.FORCE_KEY,
    AtomicDataDict.VIRIAL_KEY,
]

# Output fields requested for the ASE target (also reused by the batch target).
ASE_OUTPUTS = [
    AtomicDataDict.PER_ATOM_ENERGY_KEY,
    AtomicDataDict.TOTAL_ENERGY_KEY,
    AtomicDataDict.FORCE_KEY,
    AtomicDataDict.STRESS_KEY,
]
# === batch map rules ===
def single_frame_batch_map_settings(batch_map):
# make num_frames batch dims static, for single frame case
# relevant for single-frame use cases, e.g. pair_nequip and ase
batch_map["graph"] = torch.export.Dim.STATIC
return batch_map
# === data rules ===
def single_frame_data_settings(data):
    """Strip batching metadata from ``data`` for the single-frame targets.

    Because of the 0/1 specialization problem, and because the LAMMPS pair
    style (and ASE) require ``num_frames=1``, ``BATCH_KEY`` and
    ``NUM_NODES_KEY`` are removed so the model takes more optimized code paths.

    Args:
        data: ``AtomicDataDict`` (dict-like); modified in place.

    Returns:
        The same ``data`` object, with ``BATCH_KEY`` and ``NUM_NODES_KEY``
        removed if they were present.
    """
    # Pop each key independently (with a default) so data carrying only one
    # of the two keys is handled: neither a KeyError nor a leftover key.
    data.pop(AtomicDataDict.BATCH_KEY, None)
    data.pop(AtomicDataDict.NUM_NODES_KEY, None)
    return data
def batched_data_settings(data):
    """Duplicate ``data`` into a two-frame batch for the batch target.

    A batch of 2 frames is used to avoid the 0/1 specialization problem later
    on, so that the frame dimension is not specialized to a constant.

    Args:
        data: ``AtomicDataDict`` that must contain ``BATCH_KEY`` and
            ``NUM_NODES_KEY``.

    Returns:
        The result of ``AtomicDataDict.batched_from_list([data, data])``.

    Raises:
        AssertionError: if ``data`` lacks ``BATCH_KEY`` or ``NUM_NODES_KEY``.
    """
    # Explicit check rather than ``assert`` so the validation is not stripped
    # under ``python -O``; AssertionError is kept for backward compatibility.
    missing = [
        key
        for key in (AtomicDataDict.BATCH_KEY, AtomicDataDict.NUM_NODES_KEY)
        if key not in data
    ]
    if missing:
        raise AssertionError(
            f"batched compile target requires data keys {missing}, which are missing"
        )
    # just make a batch of 2 frames to avoid 0/1 specialization problem later on
    return AtomicDataDict.batched_from_list([data, data])
def _no_static_batch_map_settings(batch_map):
    # batch target: leave every batch dimension dynamic (no static shapes)
    return batch_map


# Compile target specs: the input/output field lists plus the hooks that
# adjust the batch map and the example data before AOT compilation.

# single-frame LAMMPS pair style target
PAIR_NEQUIP_TARGET = dict(
    input=PAIR_NEQUIP_INPUTS,
    output=LMP_OUTPUTS,
    batch_map_settings=single_frame_batch_map_settings,
    data_settings=single_frame_data_settings,
)

# single-frame ASE target
ASE_TARGET = dict(
    input=PAIR_NEQUIP_INPUTS,
    output=ASE_OUTPUTS,
    batch_map_settings=single_frame_batch_map_settings,
    data_settings=single_frame_data_settings,
)

# multi-frame target
BATCH_TARGET = dict(
    input=BATCH_INPUTS,
    output=ASE_OUTPUTS,
    batch_map_settings=_no_static_batch_map_settings,
    data_settings=batched_data_settings,
)
# Registry of compile targets: target name -> target spec.
# Extension packages add entries at runtime via ``register_compile_targets``.
COMPILE_TARGET_DICT = {}
COMPILE_TARGET_DICT[AOTI_PAIR_NEQUIP_TARGET] = PAIR_NEQUIP_TARGET
COMPILE_TARGET_DICT[AOTI_ASE_TARGET] = ASE_TARGET
COMPILE_TARGET_DICT[AOTI_BATCH_TARGET] = BATCH_TARGET
# [docs]
def register_compile_targets(
    target_dict: Dict[str, Dict[str, Union[List[str], Callable]]],
) -> None:
    """Register compile targets for AOT compilation.

    The intended clients of this function are NequIP extension packages to
    register their custom compilation targets.

    Args:
        target_dict: mapping from compile target name (the key under which the
            target is stored in ``COMPILE_TARGET_DICT``) to a target spec. Each
            spec is a dict containing keys ``input``, ``output``,
            ``batch_map_settings``, ``data_settings`` (see
            ``PAIR_NEQUIP_TARGET`` for an example). An existing target with the
            same name is overwritten.

    Raises:
        TypeError: if a target spec is not a mapping.
        ValueError: if a target spec is missing any of the required keys.
    """
    from collections.abc import Mapping

    required_keys = ("input", "output", "batch_map_settings", "data_settings")
    # Validate every entry before mutating, so a bad call registers nothing.
    # This also catches passing a single flat spec (values would be lists /
    # callables rather than specs), which would otherwise silently register
    # bogus targets named "input", "output", ...
    for name, spec in target_dict.items():
        if not isinstance(spec, Mapping):
            raise TypeError(
                f"compile target {name!r} must map to a dict-like target spec, "
                f"got {type(spec).__name__}"
            )
        missing = [key for key in required_keys if key not in spec]
        if missing:
            raise ValueError(
                f"compile target {name!r} is missing required key(s) {missing}"
            )
    # Mutate in place (no ``global`` rebinding needed) so that modules holding
    # a reference to ``COMPILE_TARGET_DICT`` see the newly registered targets.
    COMPILE_TARGET_DICT.update(target_dict)