# Source code for nequip.train.callbacks.softadapt

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
from math import sqrt, exp
import lightning
from lightning.pytorch.callbacks import Callback
from nequip.data import AtomicDataDict
from nequip.train import NequIPLightningModule
from typing import List, Dict


class SoftAdapt(Callback):
    """Adaptively modify loss coefficients over a training run using the `SoftAdapt <https://www.sciencedirect.com/science/article/pii/S0927025624003768>`_ scheme.

    Note that the implementation here differs from the original ``SoftAdapt`` scheme (which tends to 1:1:1 loss coefficient ratios), where the coefficient updates are weighted by the input loss coefficients (see `PR #515 <https://github.com/mir-group/nequip/pull/515>`_).

    .. warning::
        The SoftAdapt requires that all components of the loss function contribute to the loss function, i.e. that their ``coeff`` in the :class:`~nequip.train.MetricsManager` is not ``None``.

    .. warning::
        It is dangerous to restart training (with SoftAdapt) and use a differently configured loss function for the restart because SoftAdapt's loaded checkpoint state will become ill-suited for the new loss function.

    Example usage in config where the loss coefficients are updated every 5 epochs:

    .. code-block:: yaml

        callbacks:
          - _target_: nequip.train.callbacks.SoftAdapt
            beta: 1.1
            interval: epoch
            frequency: 5

    Args:
        beta (float): ``SoftAdapt`` hyperparameter (see paper)
        interval (str): ``batch`` or ``epoch``
        frequency (int): number of intervals between loss coefficient updates (``1`` updates the coefficients at every interval)
        eps (float): small value to avoid division by zero
    """

    def __init__(
        self,
        beta: float,
        interval: str,
        frequency: int,
        eps: float = 1e-8,
    ):
        super().__init__()
        assert interval in ["batch", "epoch"], (
            f"`interval` must be `batch` or `epoch`, but found `{interval}`"
        )
        assert frequency >= 1, f"`frequency` must be >= 1, but found `{frequency}`"

        self.beta = beta
        self.interval = interval
        self.frequency = frequency
        self.eps = eps

        # loss components seen at the previous call; stays `None` until the
        # first call of `_softadapt_update` has stored a reference point
        self.prev_losses: Dict[str, float] = None
        # one coefficient dict per call since the cache was last emptied
        self.cached_coeffs: List[Dict[str, float]] = []

    def _compute_coeffs(
        self,
        new_losses: Dict[str, float],
        pl_module: NequIPLightningModule,
    ) -> Dict[str, float]:
        """Compute normalised loss coefficients from the change in each loss component.

        Must be called before ``self.prev_losses`` is overwritten with ``new_losses``.

        Args:
            new_losses: current value of every loss component, keyed by metric name
            pl_module: module whose ``loss.entries[k].coeff`` weights the softmax output

        Returns:
            coefficients keyed like ``new_losses``, scaled to sum to one
        """
        # compute loss component changes
        loss_changes = {
            name: value - self.prev_losses[name] for name, value in new_losses.items()
        }

        # scale by the L2 norm of the changes (floored at `eps`) and apply softmax
        sum_of_squares = sum(change * change for change in loss_changes.values())
        factor = self.beta / max(sqrt(sum_of_squares), self.eps)
        exps = {name: exp(factor * change) for name, change in loss_changes.items()}
        softmax_denom = sum(exps.values()) + self.eps

        # weight the softmax output by the coefficients currently set on the loss
        weighted = {
            name: (exp_term / softmax_denom) * pl_module.loss.entries[name].coeff
            for name, exp_term in exps.items()
        }

        # ensure normalised (the total is computed once, not once per key)
        total = sum(weighted.values())
        return {name: value / total for name, value in weighted.items()}

    def _softadapt_update(
        self,
        new_losses: Dict[str, float],
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ):
        """Cache coefficients for the current step and periodically apply their average.

        The first call only records ``new_losses`` as the reference point and returns.
        Every later call appends one coefficient dict to ``self.cached_coeffs``; on the
        trigger step the cached dicts are averaged and passed to ``pl_module.loss.set_coeffs``.

        Args:
            new_losses: current value of every loss component, keyed by metric name
            trainer: supplies ``current_epoch`` or ``global_step`` depending on ``self.interval``
            pl_module: module owning the loss whose coefficients are updated
        """
        # === sanity checks ===
        assert all(
            entry.coeff is not None for entry in pl_module.loss.entries.values()
        ), (
            "all components of loss must have `coeff!=None` to use the SoftAdapt callback"
        )

        # use epochs or batches as the step counter
        step = (
            trainer.current_epoch if self.interval == "epoch" else trainer.global_step
        )

        # empty list of cached weights to store for next cycle
        if step % self.frequency == 0:
            self.cached_coeffs = []

        # nothing to compare against yet: store the reference losses and stop
        if self.prev_losses is None:
            self.prev_losses = new_losses
            return

        # TODO (maybe): the check could be stronger by matching the keys themselves, but might add overhead
        assert len(new_losses) == len(self.prev_losses)

        # compute and cache new loss weights, then update previous loss components
        self.cached_coeffs.append(self._compute_coeffs(new_losses, pl_module))
        self.prev_losses = new_losses

        # average the cached weights and update.
        # `1 % frequency` equals 1 for every frequency >= 2 (unchanged behaviour) and 0
        # for frequency == 1; a plain `== 1` can never hold when frequency == 1, so the
        # coefficients would silently never be updated in that configuration.
        # NOTE(review): the cache is emptied at `step % frequency == 0`, so for
        # frequency > 2 only the entries from the last two steps are averaged here —
        # confirm that this is the intended window.
        if step % self.frequency == 1 % self.frequency:
            num_updates = len(self.cached_coeffs)
            softadapt_weights = {
                metric_name: sum(coeffs[metric_name] for coeffs in self.cached_coeffs)
                / num_updates
                for metric_name in pl_module.loss.keys()
            }
            pl_module.loss.set_coeffs(softadapt_weights)

    def on_train_batch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
        batch: AtomicDataDict.Type,
        batch_idx: int,
    ):
        """ """
        # NOTE(review): the docstring above is empty in the original, presumably on
        # purpose (to keep the hook out of the rendered docs) — confirm before filling in.
        # Runs a per-batch SoftAdapt update when `interval == "batch"`; the very first
        # global step is skipped.
        del batch, batch_idx  # unused but required by Callback interface
        if trainer.global_step == 0:
            return
        if self.interval == "batch":
            self._softadapt_update(
                pl_module.loss.metrics_values_step, trainer, pl_module
            )

    def on_train_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ):
        """ """
        # NOTE(review): the docstring above is empty in the original, presumably on
        # purpose (to keep the hook out of the rendered docs) — confirm before filling in.
        # Runs a per-epoch SoftAdapt update when `interval == "epoch"`; epoch 0 is skipped.
        if trainer.current_epoch == 0:
            return
        if self.interval == "epoch":
            self._softadapt_update(
                pl_module.loss.metrics_values_epoch, trainer, pl_module
            )

    def state_dict(self):
        """ """
        # hyperparameters plus the running state needed to resume the update cycle
        return {
            "beta": self.beta,
            "interval": self.interval,
            "frequency": self.frequency,
            "eps": self.eps,
            "prev_losses": self.prev_losses,
            "cached_coeffs": self.cached_coeffs,
        }

    def load_state_dict(self, state_dict):
        """ """
        # restore everything written by `state_dict`; the values loaded here
        # overwrite the ones passed to `__init__`
        self.beta = state_dict["beta"]
        self.interval = state_dict["interval"]
        self.frequency = state_dict["frequency"]
        self.eps = state_dict["eps"]
        self.prev_losses = state_dict["prev_losses"]
        self.cached_coeffs = state_dict["cached_coeffs"]