Source code for nequip.train.callbacks.tf32_scheduler

# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import lightning
from lightning.pytorch.callbacks import Callback
from nequip.train import NequIPLightningModule
from nequip.utils.global_state import set_global_state

from typing import Dict, Optional


class TF32Scheduler(Callback):
    """Schedule TF32 precision during training.

    The ``TF32Scheduler`` takes a single argument ``schedule``, which is a ``Dict[int, bool]`` where the keys are the epochs at which TF32 changes and the values are:

    - ``True``: Enable TF32 (faster but less precise)
    - ``False``: Disable TF32 (slower but more precise)

    Basic example to enable TF32 for all training:

    .. code-block:: yaml

        callbacks:
          - _target_: nequip.train.callbacks.TF32Scheduler
            schedule:
              0: true  # Enable TF32 throughout training

    Dynamic scheduling example for two-stage training:

    .. code-block:: yaml

        callbacks:
          - _target_: nequip.train.callbacks.TF32Scheduler
            schedule:
              0: true     # Start with TF32 enabled
              100: false  # Disable TF32 at epoch 100
              200: true   # Re-enable TF32 at epoch 200

    .. note::
        The schedule must start at epoch 0. The initial setting will be applied at the beginning of training.

    .. note::
        If the first training epoch seen by this callback is not a key of the schedule and no
        setting has been restored from a checkpoint, the most recent schedule entry at or before
        that epoch is applied once, so that the schedule is honored on restarts as well.

    .. note::
        This callback is currently in beta testing. Please report any unexpected behavior or issues.

    Args:
        schedule (Dict[int, bool]): map of epoch to TF32 enabled/disabled

    Raises:
        ValueError: if any epoch in ``schedule`` is negative, or if epoch 0 is missing.
    """

    def __init__(self, schedule: Dict[int, bool]):
        # ensure that the keys are `int`s
        self.schedule = {int(k): v for k, v in schedule.items()}

        # sanity checks - raised explicitly (instead of `assert`) so that they are
        # still enforced when Python runs with `-O`, which strips assert statements
        negative_epochs = sorted(epoch for epoch in self.schedule if epoch < 0)
        if negative_epochs:
            raise ValueError(
                f"Epochs in TF32 scheduler must be >= 0, but found {negative_epochs}"
            )
        if 0 not in self.schedule:
            raise ValueError("First epoch in TF32 scheduler must be 0")

        # The TF32Scheduler is now the sole authority for TF32 settings during training.
        # Initialize state for restarts
        self.last_tf32_setting = self.schedule[0]
        # becomes True the first time `_set_tf32` runs on this instance; used to detect
        # a start at a non-scheduled epoch where nothing has been applied yet
        self._tf32_applied = False

    def on_train_epoch_start(
        self,
        trainer: lightning.Trainer,
        pl_module: NequIPLightningModule,
    ) -> None:
        """Apply the scheduled TF32 setting at the start of a training epoch.

        If ``trainer.current_epoch`` is a key of the schedule, its value is applied.
        Otherwise, if this instance has not applied any setting yet, the entry with the
        largest epoch not exceeding the current epoch is applied once.

        Args:
            trainer (lightning.Trainer): trainer whose ``current_epoch`` is looked up
            pl_module (NequIPLightningModule): module used to log ``tf32_enabled``
        """
        epoch = trainer.current_epoch
        if epoch in self.schedule:
            self._set_tf32(self.schedule[epoch], pl_module)
        elif not self._tf32_applied:
            # NOTE(review): presumably reached when resuming from a checkpoint that carries
            # no state for this callback, so `load_state_dict` never ran - confirm against
            # Lightning's callback-restore behavior.
            # `default=0` is only a guard; epoch 0 is guaranteed to be in the schedule.
            latest = max((e for e in self.schedule if e <= epoch), default=0)
            self._set_tf32(self.schedule[latest], pl_module)

    def _set_tf32(
        self, enabled: bool, pl_module: Optional[NequIPLightningModule] = None
    ):
        """Set the global TF32 flag, remember it, and optionally log it.

        Args:
            enabled (bool): value passed to ``set_global_state(allow_tf32=...)``
            pl_module (Optional[NequIPLightningModule]): if given, ``tf32_enabled`` is
                logged through it as a per-epoch float; if ``None``, nothing is logged
        """
        set_global_state(allow_tf32=enabled)
        self.last_tf32_setting = enabled
        self._tf32_applied = True
        if pl_module is not None:
            pl_module.log(
                "tf32_enabled",
                float(enabled),
                on_step=False,
                on_epoch=True,
            )

    def state_dict(self) -> Dict:
        """Return the callback state to checkpoint: the last applied TF32 setting."""
        return {
            "last_tf32_setting": self.last_tf32_setting,
        }

    def load_state_dict(self, state_dict: Dict) -> None:
        """Restore the last TF32 setting from ``state_dict`` and re-apply it.

        Args:
            state_dict (Dict): state with a ``"last_tf32_setting"`` entry, as produced
                by :meth:`state_dict`
        """
        # restore the last TF32 state from checkpoint
        self.last_tf32_setting = state_dict["last_tf32_setting"]
        # no module is passed, so the setting is applied without logging
        self._set_tf32(self.last_tf32_setting)