Source code for nequip.train.callbacks.loss_coeff_monitor

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
import lightning
from lightning.pytorch.callbacks import Callback
from nequip.data import AtomicDataDict
from nequip.train import NequIPLightningModule


class LossCoefficientMonitor(Callback):
    """Monitor and log loss coefficients during training.

    For every entry in ``pl_module.loss.entries`` whose ``coeff`` is not
    ``None``, the coefficient is logged under the key ``loss_coeff/<name>``.

    Example usage in config to log loss coefficients every 5 epochs:

    .. code-block:: yaml

      callbacks:
        - _target_: nequip.train.callbacks.LossCoefficientMonitor
          interval: epoch
          frequency: 5

    Args:
        interval (str): ``batch`` or ``epoch``
        frequency (int): number of intervals between each instance of loss coefficient logging

    Raises:
        ValueError: if ``interval`` is not ``batch`` or ``epoch``, or if
            ``frequency`` is less than 1
    """

    def __init__(
        self,
        interval: str,
        frequency: int,
    ):
        super().__init__()
        # Explicit raises rather than ``assert`` so that the validation is not
        # stripped when Python runs with ``-O``; a bad ``frequency`` would
        # otherwise only surface mid-training as a ZeroDivisionError.
        if interval not in ("batch", "epoch"):
            raise ValueError(
                f"`interval` must be either 'batch' or 'epoch', but got {interval!r}"
            )
        if frequency < 1:
            raise ValueError(f"`frequency` must be >= 1, but got {frequency!r}")
        self.interval = interval
        self.frequency = frequency

    def on_train_batch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
        outputs: torch.Tensor,
        batch: AtomicDataDict.Type,
        batch_idx: int,
    ) -> None:
        """Log loss coefficients when ``interval == "batch"`` and ``trainer.global_step`` is a multiple of ``frequency``."""
        # NOTE(review): ``global_step`` counts optimizer steps in Lightning, so with
        # gradient accumulation several consecutive batches may share the same
        # step value and log repeatedly -- confirm whether that matters here.
        if self.interval == "batch" and trainer.global_step % self.frequency == 0:
            self._log_coefficients(pl_module)

    def on_train_epoch_end(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """Log loss coefficients when ``interval == "epoch"`` and ``trainer.current_epoch`` is a multiple of ``frequency``."""
        if self.interval == "epoch" and trainer.current_epoch % self.frequency == 0:
            self._log_coefficients(pl_module)

    def _log_coefficients(self, pl_module: NequIPLightningModule) -> None:
        """Log each non-``None`` loss coefficient as ``loss_coeff/<metric_name>``."""
        for metric_name, entry in pl_module.loss.entries.items():
            # entries without a coefficient are skipped rather than logged
            if entry.coeff is not None:
                pl_module.log("loss_coeff/" + metric_name, entry.coeff)