# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
import lightning
from lightning.pytorch.utilities.warnings import PossibleUserWarning
from hydra.utils import instantiate
from hydra.utils import get_method, get_class
from nequip.data import AtomicDataDict
from nequip.utils import RankedLogger
import warnings
from typing import Optional, Dict
logger = RankedLogger(__name__, rank_zero_only=True)
# metrics are already synced before logging, but Lightning still sends a PossibleUserWarning about setting sync_dist=True in self.logdict()
warnings.filterwarnings(
"ignore",
message=".*when logging on epoch level in distributed setting to accumulate the metric across.*",
category=PossibleUserWarning,
)
_SOLE_MODEL_KEY = "sole_model"
[docs]
class NequIPLightningModule(lightning.LightningModule):
""":class:`~lightning.pytorch.core.LightningModule` for training, validating, testing and predicting with models constructed in the NequIP ecosystem.
**Data**
The ``NequIPLightningModule`` supports a single ``train`` dataset, but multiple ``val`` and ``test`` datasets.
**Run Types and Metrics**
- For ``train`` runs, users must provide ``loss`` and ``val_metrics``. The ``loss`` is computed on the training dataset to train the model, and requires each metric to have a corresponding coefficient that will be used to generate a ``weighted_sum``. This ``weighted_sum`` is the loss function that will be minimized over the course of training. ``val_metrics`` is computed on the validation dataset(s) for monitoring. Additionally, users may provide ``train_metrics`` to monitor metrics on the training dataset.
- For ``val`` runs, users must provide ``val_metrics``.
- For ``test`` runs, users must provide ``test_metrics``.
**Logging Conventions**
Logging is performed for the ``train``, ``val``, and ``test`` datasets.
During ``train`` runs,
* logging occurs at each batch ``step`` and at each ``epoch``,
* there is only one training set, so no ``data_idx`` is used in the logging.
For ``val`` and ``test`` runs,
* logging only occurs at each validation or testing ``epoch``, i.e. one pass over the entirety of each validation/testing dataset,
* there can be multiple validation and testing sets, so a zero-based ``data_idx`` index is used in the logging.
Logging Format
* ``/`` is used as a delimiter for to exploit the automatic grouping functionality of most loggers. Logged metrics will have the form ``train_{loss/metric}_{step/epoch}/{metric_name}`` and ``{val/test}{data_idx}_epoch/{metric_name}``. For example, ``train_loss_step/force_MSE``, ``train_metric_epoch/E_MAE``, ``val0_epoch/F_RMSE``, etc.
* Note that this may have implications on how one would set the parameters for the `ModelCheckpoint <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html>`_ callback, i.e. if the name of a metric is used in the checkpoint file's name, the ``/`` will cause a directory to be created when instead a file is desired.
"""
def __init__(
self,
model: Dict,
num_datasets: Dict[str, int],
optimizer: Optional[Dict] = None,
lr_scheduler: Optional[Dict] = None,
loss: Optional[Dict] = None,
train_metrics: Optional[Dict] = None,
val_metrics: Optional[Dict] = None,
test_metrics: Optional[Dict] = None,
# for caching training info
info_dict: Optional[Dict] = None,
):
super().__init__()
# save arguments to instantiate LightningModule from checkpoint automatically
self.save_hyperparameters()
# === instantiate model ===
model_object = self._build_model(model)
# === account for multiple models ===
# contract:
# - for multiple models, they must be in the form of a `ModuleDict` of `GraphModel`s
# - if a single `GraphModel` is provided, we wrap it in a `ModuleDict`
# - all models must have the same `type_names`
# the reason for `hasattr(x, "is_graph_model")` and not just `isinstance(x, GraphModel)`
# is to support `GraphModel` from a `nequip-package`d model (see https://pytorch.org/docs/stable/package.html#torch-package-sharp-edges)
assert isinstance(model_object, torch.nn.ModuleDict) or hasattr(
model_object, "is_graph_model"
)
if not isinstance(model_object, torch.nn.ModuleDict):
model_object = torch.nn.ModuleDict({_SOLE_MODEL_KEY: model_object})
self.model = model_object
type_names_list = []
for k, v in self.model.items():
assert hasattr(v, "is_graph_model")
type_names_list.append(v.type_names)
logger.debug(f"Built Model Details ({k}):\n{str(v)}")
assert all(
[
all(
[
name1 == name2
for (name1, name2) in zip(type_names_list[0], type_names)
]
)
for type_names in type_names_list
]
), "If multiple models are used, they must have the same type names parameter."
type_names = type_names_list[0] # passed to `MetricsManager`s later
# === optimizer and lr scheduler ===
self.optimizer_config = optimizer
self.lr_scheduler_config = lr_scheduler
# === instantiate MetricsManager objects ===
# must have separate MetricsManagers for each dataloader
# num_datasets goes in order [train, val, test, predict]
self.num_datasets = num_datasets
assert self.num_datasets["train"] == 1, (
"currently only support one training dataset"
)
# == DDP concerns for loss ==
# to account for loss contributions from multiple ranks later on
# NOTE: this must be updated externally by the script that sets up the training run
self.world_size = 1
# == instantiate loss ==
self.loss = instantiate(loss, type_names=type_names)
if self.loss is not None:
assert self.loss.do_weighted_sum, (
"`coeff` must be set for entries of the `loss` MetricsManager for a weighted sum of metrics components to be used as the loss."
)
# set `dist_sync_on_step=True` for loss metrics
# to ensure correct DDP syncing of loss function for batch steps
for metric in self.loss.values():
metric.dist_sync_on_step = True
# == instantiate other metrics ==
self.train_metrics = instantiate(train_metrics, type_names=type_names)
# may need to instantate multiple instances to account for multiple val and test datasets
self.val_metrics = torch.nn.ModuleList(
[
instantiate(val_metrics, type_names=type_names)
for _ in range(self.num_datasets["val"])
]
)
self.test_metrics = torch.nn.ModuleList(
[
instantiate(test_metrics, type_names=type_names)
for _ in range(self.num_datasets["test"])
]
)
# use "/" as delimiter for loggers to automatically categorize logged metrics
self.logging_delimiter = "/"
# for statefulness of the run stage
self.register_buffer("run_stage", torch.zeros((1), dtype=torch.long))
def _build_model(self, model_config: Dict) -> torch.nn.ModuleDict:
"""Constructs a ``torch.nn.ModuleDict[str, nequip.nn.GraphModel]`` from a pure Python dictionary.
Subclasses that require more control over how the model is built can override this method.
"""
# reason for following implementation instead of just `hydra.utils.instantiate(model)` is to prevent omegaconf from being a model dependency
model_config = model_config.copy() # make a copy because of `pop` mutation
model_builder = get_method(model_config.pop("_target_"))
model = model_builder(**model_config)
return model
def configure_optimizers(self):
""""""
# currently support 1 optimizer and 1 scheduler
# potentially support N optimzier and N scheduler
# (see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers)
optimizer_config = self.optimizer_config.copy()
param_groups = optimizer_config.pop(
"param_groups",
{"_target_": "nequip.train.lightning._default_param_group_factory"},
)
param_groups = instantiate(param_groups, model=self.model)
optimizer_class = optimizer_config.pop("_target_")
optim = get_class(optimizer_class)(params=param_groups, **optimizer_config)
if self.lr_scheduler_config is None:
return optim
def _instantiate_scheduler(scheduler_config: dict, optimizer):
scheduler_config = dict(
scheduler_config
) # just in case, because of pop mutation
# NOTE: This assumes that nested schedulers always have a "schedulers" key
inner_configs = scheduler_config.pop("schedulers", None)
# Recursively instantiate inner schedulers if we use nested schedulers (e.g. ChainedScheduler, SequentialLR)
if inner_configs is not None:
inner_schedulers = [
_instantiate_scheduler(inner_config, optimizer)
for inner_config in inner_configs
]
return instantiate(
scheduler_config, optimizer=optimizer, schedulers=inner_schedulers
)
# Base case: instantiate a regular scheduler
return instantiate(scheduler_config, optimizer=optimizer)
# instantiate lr scheduler object separately to pass the optimizer to it during instantiation
lr_scheduler_config = dict(self.lr_scheduler_config.copy())
scheduler_config = lr_scheduler_config.pop("scheduler")
scheduler = _instantiate_scheduler(scheduler_config, optim)
lr_scheduler = dict(instantiate(lr_scheduler_config))
lr_scheduler.update({"scheduler": scheduler})
return {"optimizer": optim, "lr_scheduler": lr_scheduler}
def forward(self, inputs: AtomicDataDict.Type):
""""""
# enable grad for forces, stress, etc
with torch.enable_grad():
# multi-model subclasses will need to override this function
return self.model[_SOLE_MODEL_KEY](inputs)
@property
def evaluation_model(self) -> torch.nn.Module:
return self.model
def process_target(
self, batch: AtomicDataDict.Type, batch_idx: int, dataloader_idx: int = 0
) -> AtomicDataDict.Type:
""""""
# subclasses can override this function
return batch.copy()
def training_step(
self, batch: AtomicDataDict.Type, batch_idx: int, dataloader_idx: int = 0
):
""""""
target = self.process_target(batch, batch_idx, dataloader_idx)
output = self(batch)
# optionally compute training metrics
if self.train_metrics is not None:
with torch.no_grad():
train_metric_dict = self.train_metrics(
output, target, prefix=f"train_metric_step{self.logging_delimiter}"
)
self.log_dict(train_metric_dict)
# compute loss and return
loss_dict = self.loss(
output, target, prefix=f"train_loss_step{self.logging_delimiter}"
)
self.log_dict(loss_dict)
# In DDP training, because gradients are averaged rather than summed over nodes,
# we get an effective factor of 1/n_rank applied to the loss. Because our loss already
# manages correct accumulation of the metric over ranks, we want to cancel out this
# unnecessary 1/n_rank term. If DDP is disabled, this is 1 and has no effect.
loss = (
loss_dict[f"train_loss_step{self.logging_delimiter}weighted_sum"]
* self.world_size
)
return loss
def on_train_epoch_end(self):
""""""
# optionally compute training metrics
if self.train_metrics is not None:
train_metric_dict = self.train_metrics.compute(
prefix=f"train_metric_epoch{self.logging_delimiter}"
)
self.log_dict(train_metric_dict)
self.train_metrics.reset()
# loss
loss_dict = self.loss.compute(
prefix=f"train_loss_epoch{self.logging_delimiter}"
)
self.log_dict(loss_dict)
self.loss.reset()
def validation_step(
self, batch: AtomicDataDict.Type, batch_idx: int, dataloader_idx: int = 0
):
""""""
target = self.process_target(batch, batch_idx, dataloader_idx)
# === update basic val metrics ===
output = self(batch)
with torch.no_grad():
metric_dict = self.val_metrics[dataloader_idx](
output,
target,
prefix=f"val{dataloader_idx}_step{self.logging_delimiter}",
)
metric_dict.update({f"val_{dataloader_idx}_output": output})
return metric_dict
def on_validation_epoch_end(self):
""""""
# === reset basic val metrics ===
for idx, metrics in enumerate(self.val_metrics):
metric_dict = metrics.compute(
prefix=f"val{idx}_epoch{self.logging_delimiter}"
)
self.log_dict(metric_dict)
metrics.reset()
def test_step(
self, batch: AtomicDataDict.Type, batch_idx: int, dataloader_idx: int = 0
):
""""""
target = self.process_target(batch, batch_idx, dataloader_idx)
# === update basic test metrics ===
output = self(batch)
with torch.no_grad():
metric_dict = self.test_metrics[dataloader_idx](
output,
target,
prefix=f"test{dataloader_idx}_step{self.logging_delimiter}",
)
metric_dict.update({f"test_{dataloader_idx}_output": output})
return metric_dict
def on_test_epoch_end(self):
""""""
# === reset basic test metrics ===
for idx, metrics in enumerate(self.test_metrics):
metric_dict = metrics.compute(
prefix=f"test{idx}_epoch{self.logging_delimiter}"
)
self.log_dict(metric_dict)
metrics.reset()
def _default_param_group_factory(model):
return model.parameters()