# Source code for nequip.train.schedulefree
from .lightning import NequIPLightningModule
from typing import Dict, Any
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from nequip.utils import RankedLogger
import torch
# Module-level logger named after this module. It is constructed with
# rank_zero_only=True (presumably so that messages are emitted by a single
# rank in distributed runs — confirm against nequip.utils.RankedLogger).
logger = RankedLogger(__name__, rank_zero_only=True)
# Note: Manual `.train()`/`.eval()` mode control for optimizer is required to ensure smoothed weights are captured at the right time
# Related discussion on Lightning timing hooks:
# https://github.com/Lightning-AI/pytorch-lightning/discussions/19759
# [docs]
class ScheduleFreeLightningModule(NequIPLightningModule):
    """NequIP LightningModule using Facebook's Schedule-Free optimizer.

    This module wraps the model's optimizer in one of Facebook's Schedule-Free
    variants and switches the optimizer between its ``train()`` and ``eval()``
    modes at the start of the corresponding Lightning epochs.

    See: https://github.com/facebookresearch/schedule_free

    Args:
        optimizer (Dict[str, Any]): Dictionary that must include a ``_target_``
            corresponding to one of the Schedule-Free optimizers and other
            keyword arguments compatible with the Schedule-Free variants.
        **kwargs: Forwarded unchanged to ``NequIPLightningModule.__init__``.

    Raises:
        MisconfigurationException: If ``optimizer["_target_"]`` is missing,
            is not a string, or does not end with one of the supported
            Schedule-Free optimizer class names.
    """

    def __init__(self, optimizer: Dict[str, Any], **kwargs):
        # A sorted tuple keeps the error message deterministic across runs and
        # can be handed directly to ``str.endswith``.
        valid_targets = (
            "AdamWScheduleFree",
            "RAdamScheduleFree",
            "SGDScheduleFree",
        )
        target = optimizer.get("_target_")
        # The isinstance check rejects a missing (None) target and any
        # non-string value with the intended configuration error instead of an
        # AttributeError from ``.endswith``; an empty string fails ``endswith``.
        if not isinstance(target, str) or not target.endswith(valid_targets):
            raise MisconfigurationException(
                f"Invalid optimizer: expected Schedule-Free optimizer (_target_ ending with one of {valid_targets}), "
                f"but found '{target}'"
            )
        # Will be used to lazily restore optimizer state in evaluation_model.
        # Assigned before the parent constructor runs, as in the original ordering.
        self._schedulefree_state_dict: Dict[str, Any] = {}
        super().__init__(optimizer=optimizer, **kwargs)

    # NOTE: the Lightning hooks below keep an empty docstring, as in the
    # original (presumably to keep inherited hook docs out of the rendered
    # API reference — confirm before replacing them with real docstrings).

    # Lightning Hook
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """"""
        # Schedule-Free optimizers require .eval() to expose smoothed weights.
        # Store the optimizer's state_dict under a dedicated checkpoint key so
        # that `evaluation_model` can later rebuild the optimizer, restore this
        # state and switch it to eval mode (used for packaging).
        opt = self.optimizers()
        if opt is not None:
            checkpoint["schedulefree_optimizer_state_dict"] = opt.state_dict()

    # Lightning Hook
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """"""
        # Extract our custom optimizer state for later lazy loading. Checkpoints
        # without the key leave the current (possibly empty) state untouched.
        state = checkpoint.get("schedulefree_optimizer_state_dict")
        if state is not None:
            logger.info(
                "Storing Schedule-Free optimizer state from checkpoint for lazy loading."
            )
            self._schedulefree_state_dict = state

    # NequIP-Specific Override for Packaging
    @property
    def evaluation_model(self) -> torch.nn.Module:
        """Return ``self.model`` after switching a freshly built optimizer to eval mode.

        A new optimizer is obtained from ``configure_optimizers()``; if an
        optimizer state was captured by ``on_load_checkpoint`` it is loaded
        into that optimizer before ``eval()`` is called on it.

        Returns:
            torch.nn.Module: ``self.model``.
        """
        # This is used during packaging to get the smoothed evaluation weights.
        # NOTE(review): every access rebuilds the optimizer, reloads the stored
        # state and calls eval() again — verify that repeated access is safe
        # for the Schedule-Free optimizer in use.
        logger.info("Loading Schedule-Free optimizer weights for evaluation.")
        # getattr with a default tolerates instances on which __init__ never
        # assigned the attribute.
        prev_state_dict = getattr(self, "_schedulefree_state_dict", None)
        # NOTE(review): assumes configure_optimizers() returns the optimizer
        # object itself (not a dict/tuple) — confirm against the base class.
        opt = self.configure_optimizers()
        if prev_state_dict:
            opt.load_state_dict(prev_state_dict)
        # Set optimizer to evaluation mode for smoothed weights
        opt.eval()
        return self.model

    # Lightning Hook
    def on_train_epoch_start(self) -> None:
        """"""
        # Ensures fast weights are used during training
        self.optimizers().train()

    # Lightning Hook
    def on_validation_epoch_start(self) -> None:
        """"""
        # Ensures smoothed weights are used for validation
        self.optimizers().eval()

    # Lightning Hook
    def on_test_epoch_start(self) -> None:
        """"""
        # Ensures smoothed weights are used during testing
        self.optimizers().eval()

    # Lightning Hook
    def on_predict_epoch_start(self) -> None:
        """"""
        # Ensures smoothed weights are used during prediction/inference
        self.optimizers().eval()