# Source code for nequip.train.callbacks.wandb_watch
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from typing import Optional

import lightning
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.wandb import WandbLogger

from nequip.train import NequIPLightningModule
# [docs]
class WandbWatch(Callback):
    """Monitor and log weights and gradients during training with PyTorch Lightning's :class:`~lightning.pytorch.loggers.WandbLogger`.

    This class provides a way to call https://docs.wandb.ai/ref/python/watch/ when using a :class:`~lightning.pytorch.loggers.WandbLogger` for monitoring weights and gradients over the course of training.

    If the trainer has several loggers configured, the first
    :class:`~lightning.pytorch.loggers.WandbLogger` found in ``trainer.loggers`` is used.

    Args:
        log_freq (int): frequency (in batches) to log gradients and parameters
        log (str or None): specifies whether to log ``"gradients"``, ``"parameters"``, or ``"all"``; ``None`` is also accepted and passed through unchanged to ``wandb.watch``
        log_graph (bool): whether to log the model's computational graph

    Raises:
        ValueError: if ``log`` is not one of the accepted settings.
    """

    # settings accepted for the ``log`` argument of ``wandb.watch``
    _VALID_LOG_SETTINGS = ("gradients", "parameters", "all", None)

    def __init__(
        self,
        log_freq: int,
        log: Optional[str] = "gradients",
        log_graph: bool = False,
    ):
        super().__init__()
        # explicit check instead of ``assert`` so validation survives ``python -O``
        if log not in self._VALID_LOG_SETTINGS:
            raise ValueError(
                f"`log` must be one of {self._VALID_LOG_SETTINGS}, got {log!r}"
            )
        self.log_freq = log_freq
        self.wandb_log_setting = log  # use 'wandb_log_setting' as attr name to avoid conflict with LightningModule.log()
        self.log_graph = log_graph

    @staticmethod
    def _get_wandb_logger(trainer: lightning.Trainer) -> Optional[WandbLogger]:
        # Return the first WandbLogger among the trainer's loggers, or None.
        # Searching ``trainer.loggers`` (rather than only ``trainer.logger``,
        # which is just the first logger) lets this callback work when
        # several loggers are configured.
        return next(
            (lg for lg in trainer.loggers if isinstance(lg, WandbLogger)),
            None,
        )

    def on_train_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """"""
        # Starts wandb watching of ``pl_module.model``.
        # Raises TypeError if the trainer has no WandbLogger.
        wandb_logger = self._get_wandb_logger(trainer)
        if wandb_logger is None:
            raise TypeError(
                "NequIP's `WandbWatch` callback only works for `WandbLogger` loggers"
            )
        # see https://docs.wandb.ai/ref/python/watch/
        wandb_logger.watch(
            pl_module.model,
            log=self.wandb_log_setting,
            log_freq=self.log_freq,
            log_graph=self.log_graph,
        )

    def on_train_end(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """"""
        # Stops wandb watching of ``pl_module.model``.
        # see unwatch syntax
        # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html#lightning.pytorch.loggers.wandb.WandbLogger
        wandb_logger = self._get_wandb_logger(trainer)
        # nothing was watched if no WandbLogger is attached; skip rather than
        # failing with an AttributeError at the end of training
        if wandb_logger is not None:
            wandb_logger.experiment.unwatch(pl_module.model)