Source code for nequip.train.callbacks.loss_coeff_scheduler
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import lightning
from lightning.pytorch.callbacks import Callback
from nequip.train import NequIPLightningModule
from typing import Dict
[docs]
class LossCoefficientScheduler(Callback):
    """Schedule loss coefficients during training.

    The ``LossCoefficientScheduler`` takes a single argument ``schedule``, which is a ``Dict[int, Dict[str, float]]`` where the keys are the epochs at which the loss coefficients change and the values are dictionaries mapping loss metric names (corresponding to how the loss was configured) to their coefficients.

    When the trainer's epoch counter matches any of the keys (representing epochs), the loss coefficients will be changed to the values (representing the coefficients for each loss term).
    The coefficients will be normalized to sum up to 1 in line with the convention of :class:`~nequip.train.MetricsManager`.

    Example usage in config where there are two loss contributions:

    .. code-block:: yaml

        callbacks:
          - _target_: nequip.train.callbacks.LossCoefficientScheduler
            schedule:
              100:
                per_atom_energy_mse: 1.0
                forces_mse: 5.0
              200:
                per_atom_energy_mse: 5.0
                forces_mse: 1.0

    Args:
        schedule (Dict[int, Dict[str,float]]): map of epoch to loss coefficient dictionary

    Raises:
        ValueError: if any epoch key of ``schedule`` is negative (or cannot be converted to ``int``)
    """

    def __init__(self, schedule: Dict[int, Dict[str, float]]):
        super().__init__()
        # coerce keys to `int` in case they arrive as another type (e.g. strings from a config)
        self.schedule = {int(k): v for k, v in schedule.items()}
        # sanity check - epochs are >= 0
        # (explicit raise instead of `assert`, which is stripped under `python -O`)
        negative_epochs = sorted(epoch for epoch in self.schedule if epoch < 0)
        if negative_epochs:
            raise ValueError(
                f"`schedule` epochs must be non-negative, but got {negative_epochs}"
            )

    def on_train_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """"""
        # Sets the loss coefficients to the scheduled values when the current epoch is a
        # key of `schedule`; does nothing on any other epoch.
        # NOTE(review): coefficients are only set on the exact scheduled epochs, so whether a
        # schedule change persists across a checkpoint resume depends on how `pl_module.loss`
        # saves its coefficients -- confirm.
        epoch = trainer.current_epoch
        if epoch not in self.schedule:
            return
        pl_module.loss.set_coeffs(self.schedule[epoch])
[docs]
class LinearLossCoefficientScheduler(Callback):
    """Linearly schedule loss coefficients during training.

    The ``LinearLossCoefficientScheduler`` linearly interpolates loss coefficients from the current values at ``start_epoch`` to the specified ``final_coeffs`` over ``transition_epochs`` epochs.
    This callback is stateful and captures the loss coefficients at ``start_epoch`` for interpolation.

    For ``start_epoch < epoch <= start_epoch + transition_epochs`` the coefficient of each loss term is set to
    ``(1 - alpha) * initial + alpha * final`` with ``alpha = (epoch - start_epoch) / transition_epochs``,
    so the final coefficients are reached exactly at ``start_epoch + transition_epochs``.

    .. note::
        This callback is currently in beta testing. Please report any unexpected behavior or issues.

    Example usage in config to transition to energy:force:stress = 1:1:1 over 200 epochs starting at epoch 100 (from whatever coefficients they were originally at):

    .. code-block:: yaml

        callbacks:
          - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
            final_coeffs:
              per_atom_energy_mse: 1.0
              forces_mse: 1.0
              stress_mse: 1.0
            start_epoch: 100
            transition_epochs: 200

    Multiple ``LinearLossCoefficientScheduler`` callbacks can be composed for multi-stage scheduling:

    .. code-block:: yaml

        callbacks:
          # First transition: current -> 1:5:1 from epoch 50-150
          - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
            final_coeffs:
              per_atom_energy_mse: 1.0
              forces_mse: 5.0
              stress_mse: 1.0
            start_epoch: 50
            transition_epochs: 100
          # Second transition: current -> 1:1:1 from epoch 200-400
          - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
            final_coeffs:
              per_atom_energy_mse: 1.0
              forces_mse: 1.0
              stress_mse: 1.0
            start_epoch: 200
            transition_epochs: 200

    .. warning::
        When composing multiple schedulers, ensure their epoch ranges do not overlap. No safety checks are performed to validate scheduler composition. Additionally, callback execution order is not guaranteed and training protocols should not rely on specific callback execution orders.

    Args:
        final_coeffs (Dict[str, float]): target loss coefficient dictionary
        start_epoch (int): epoch at which to start the transition (default: 0)
        transition_epochs (int): number of epochs over which to transition

    Raises:
        ValueError: if ``start_epoch`` is negative, ``transition_epochs`` is not positive,
            ``final_coeffs`` does not sum to a positive value, or (at ``start_epoch``) a key of
            ``final_coeffs`` is not one of the loss metrics
        RuntimeError: if an epoch inside the transition window is reached without the initial
            coefficients having been captured at ``start_epoch``
    """

    def __init__(
        self,
        final_coeffs: Dict[str, float],
        transition_epochs: int,
        start_epoch: int = 0,
    ):
        super().__init__()
        # validate before normalizing so bad configs fail with a clear message
        # (explicit raises instead of `assert`, which is stripped under `python -O`)
        if start_epoch < 0:
            raise ValueError("Start epoch must be non-negative")
        if transition_epochs <= 0:
            raise ValueError("Transition epochs must be positive")
        final_total = sum(final_coeffs.values())
        if final_total <= 0:
            # also catches an empty `final_coeffs`, which would otherwise divide by zero
            raise ValueError(
                f"`final_coeffs` must sum to a positive value, but got {final_total}"
            )
        # normalize final coefficients since captured coefficients will be normalized
        self.final_coeffs = {
            key: val / final_total for key, val in final_coeffs.items()
        }
        self.start_epoch = start_epoch
        self.transition_epochs = transition_epochs
        # filled lazily at `start_epoch` (or restored through `load_state_dict`)
        self.captured_initial_coeffs = None

    def _capture_initial_coeffs(
        self, pl_module: NequIPLightningModule
    ) -> Dict[str, float]:
        """Snapshot the current coefficients of the loss entries named in ``final_coeffs``."""
        available = {
            metric_name: entry.coeff
            for metric_name, entry in pl_module.loss.entries.items()
        }
        # sanity check that all `final_coeffs` keys are present in the metrics
        missing = set(self.final_coeffs) - set(available)
        if missing:
            raise ValueError(
                f"`final_coeffs` keys {sorted(missing)} are not loss metrics; "
                f"available loss metrics are {sorted(available)}"
            )
        return {
            metric_name: coeff
            for metric_name, coeff in available.items()
            if metric_name in self.final_coeffs
        }

    def on_train_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """"""
        current_epoch = trainer.current_epoch
        # NOTE: initial coeffs captured should already be normalized
        # final coeffs were normalized at __init__
        if current_epoch == self.start_epoch and self.captured_initial_coeffs is None:
            # lazily capture the current coefficients when we start
            self.captured_initial_coeffs = self._capture_initial_coeffs(pl_module)

        end_epoch = self.start_epoch + self.transition_epochs
        # only act strictly after `start_epoch` and up to (including) `end_epoch`
        if not (self.start_epoch < current_epoch <= end_epoch):
            return

        if self.captured_initial_coeffs is None:
            # capture only happens at exactly `start_epoch`, so reaching this point means that
            # epoch was skipped in this run and no captured state was restored
            raise RuntimeError(
                f"`LinearLossCoefficientScheduler` reached epoch {current_epoch} inside its "
                f"transition window ({self.start_epoch}, {end_epoch}] without having captured "
                f"the initial loss coefficients at `start_epoch`={self.start_epoch}; was training "
                f"started or resumed past `start_epoch` without this callback's saved state?"
            )

        # linear interpolation during transition period; the `(1 - alpha) * a + alpha * b`
        # form reproduces the final coefficients exactly at `alpha == 1`
        alpha = (current_epoch - self.start_epoch) / self.transition_epochs
        interpolated_coeffs = {
            key: (1.0 - alpha) * self.captured_initial_coeffs[key] + alpha * final_val
            for key, final_val in self.final_coeffs.items()
        }
        pl_module.loss.set_coeffs(interpolated_coeffs)

    def state_dict(self) -> dict:
        """"""
        # persist the configuration and the captured coefficients so a resumed run can
        # continue the interpolation from the same starting point
        return {
            "final_coeffs": self.final_coeffs,
            "start_epoch": self.start_epoch,
            "transition_epochs": self.transition_epochs,
            "captured_initial_coeffs": self.captured_initial_coeffs,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """"""
        # NOTE: values from the checkpoint override those passed to `__init__`
        self.final_coeffs = state_dict["final_coeffs"]
        self.start_epoch = state_dict["start_epoch"]
        self.transition_epochs = state_dict["transition_epochs"]
        self.captured_initial_coeffs = state_dict["captured_initial_coeffs"]

    @property
    def state_key(self) -> str:
        """"""
        # See https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback.state_key
        # This definition assumes that each start epoch is unique.
        return f"{self.__class__.__qualname__}_{self.start_epoch}"