NequIPLightningModule
- class nequip.train.NequIPLightningModule(model: Dict, num_datasets: Dict[str, int], optimizer: Dict | None = None, lr_scheduler: Dict | None = None, loss: Dict | None = None, train_metrics: Dict | None = None, val_metrics: Dict | None = None, test_metrics: Dict | None = None, info_dict: Dict | None = None)
LightningModule for training, validating, testing and predicting with models constructed in the NequIP ecosystem.
Data
The NequIPLightningModule supports a single train dataset, but multiple val and test datasets.
Run Types and Metrics
For train runs, users must provide loss and val_metrics. The loss is computed on the training dataset to train the model, and requires each metric to have a corresponding coefficient that is used to form a weighted_sum. This weighted_sum is the loss function that is minimized over the course of training. The val_metrics are computed on the validation dataset(s) for monitoring. Additionally, users may provide train_metrics to monitor metrics on the training dataset.
For val runs, users must provide val_metrics.
For test runs, users must provide test_metrics.
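The weighted_sum described above can be pictured with a minimal Python sketch. It is not nequip's metrics implementation: the function name weighted_sum_loss and the metric names E_MSE and F_MSE are hypothetical, and the sketch does not reproduce any coefficient normalization the library may apply. It only shows that every loss metric needs a coefficient and that the scalar being minimized is the coefficient-weighted sum.

```python
from typing import Dict


def weighted_sum_loss(metric_values: Dict[str, float], coeffs: Dict[str, float]) -> float:
    """Hypothetical helper: combine per-metric loss values into one scalar.

    Every metric must have a coefficient; the result is sum_i coeff_i * metric_i.
    """
    missing = sorted(set(metric_values) - set(coeffs))
    if missing:
        raise ValueError(f"no coefficient given for metric(s): {missing}")
    return sum(coeffs[name] * value for name, value in metric_values.items())


# Hypothetical per-batch values of an energy and a force loss metric.
batch_metrics = {"E_MSE": 0.04, "F_MSE": 0.10}
loss_coeffs = {"E_MSE": 1.0, "F_MSE": 5.0}
print(weighted_sum_loss(batch_metrics, loss_coeffs))  # 1.0 * 0.04 + 5.0 * 0.10 = 0.54
```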
Logging Conventions
Logging is performed for the train, val, and test datasets.
- During train runs, logging occurs at each batch step and at each epoch. There is only one training set, so no data_idx is used in the logging.
- For val and test runs, logging only occurs at each validation or testing epoch, i.e. one pass over the entirety of each validation/testing dataset. There can be multiple validation and testing sets, so a zero-based data_idx is used in the logging.
Logging Format
The / character is used as a delimiter to exploit the automatic grouping functionality of most loggers. Logged metrics have the form train_{loss/metric}_{step/epoch}/{metric_name} and {val/test}{data_idx}_epoch/{metric_name}, for example train_loss_step/force_MSE, train_metric_epoch/E_MAE, and val0_epoch/F_RMSE.
Note that this may have implications for how one sets the parameters of the ModelCheckpoint callback: if the name of a metric is used in the checkpoint file's name, the / will cause a directory to be created when a file is desired instead.
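The naming convention above can be written out as a small Python sketch. The helper names train_log_key, eval_log_key and filename_safe are hypothetical and not part of nequip; the sketch only reproduces the documented key formats and shows one way (replacing the /) to turn a metric key into a flat file-name fragment.

```python
def train_log_key(kind: str, interval: str, metric_name: str) -> str:
    """Build train_{loss/metric}_{step/epoch}/{metric_name}."""
    if kind not in ("loss", "metric") or interval not in ("step", "epoch"):
        raise ValueError("kind must be 'loss'/'metric', interval 'step'/'epoch'")
    return f"train_{kind}_{interval}/{metric_name}"


def eval_log_key(stage: str, data_idx: int, metric_name: str) -> str:
    """Build {val/test}{data_idx}_epoch/{metric_name}; data_idx is zero-based."""
    if stage not in ("val", "test") or data_idx < 0:
        raise ValueError("stage must be 'val'/'test' and data_idx >= 0")
    return f"{stage}{data_idx}_epoch/{metric_name}"


def filename_safe(key: str) -> str:
    """Replace the '/' group delimiter so the key cannot create a subdirectory."""
    return key.replace("/", "_")


print(train_log_key("loss", "step", "force_MSE"))       # train_loss_step/force_MSE
print(eval_log_key("val", 0, "F_RMSE"))                 # val0_epoch/F_RMSE
print(filename_safe(eval_log_key("val", 0, "F_RMSE")))  # val0_epoch_F_RMSE
```

PyTorch Lightning's ModelCheckpoint also provides auto_insert_metric_name=False for this situation: with it, a template such as epoch={epoch}-val0_F_RMSE={val0_epoch/F_RMSE:.4f} inserts only the metric value, so the slash in the metric name does not end up in the path.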
- class nequip.train.EMALightningModule(ema_decay: float = 0.999, **kwargs)
An exponential moving average (EMA) of the model weights is maintained, and validation and test metrics are computed with the EMA weights. If EMA is used, models loaded from checkpoint files (except during restarts) will always use the EMA weights. Specifically, whenever a model is prepared for inference, the EMA weights are used when available; this includes loading a model for inference, nequip-compile, and nequip-package.
Note: EMA requires check_val_every_n_epoch to be 1 (the default).
- Parameters:
ema_decay (float) – decay constant for the exponential moving average (EMA) of model weights (default: 0.999)
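The EMA update itself is a one-line recurrence. The sketch below applies it to a plain dictionary of scalar weights rather than real tensors. It assumes one EMA update per optimizer step and no decay warm-up, and the function name ema_update is hypothetical, so it illustrates the rule rather than nequip's implementation.

```python
from typing import Dict


def ema_update(ema_weights: Dict[str, float], weights: Dict[str, float],
               ema_decay: float = 0.999) -> None:
    """Hypothetical in-place EMA step: ema <- decay * ema + (1 - decay) * current."""
    for name, value in weights.items():
        ema_weights[name] = ema_decay * ema_weights[name] + (1.0 - ema_decay) * value


# Start the average at the initial weights, then let training move w to 1.0.
ema = {"w": 0.0}
current = {"w": 1.0}
for _ in range(3):
    ema_update(ema, current, ema_decay=0.5)
print(ema["w"])  # 0.875 (0.5, then 0.75, then 0.875)
```

With the default ema_decay=0.999, each update moves the average only 0.1% of the way toward the current weights, so it effectively averages over roughly the last 1/(1 - 0.999) = 1000 updates.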
- class nequip.train.ConFIGLightningModule(gradient_clip_val: float | None = None, gradient_clip_algorithm: str | None = None, lsqr: bool = True, norm_eps: float = 1e-08, **kwargs)
Conflict-free inverse gradient (ConFIG) approach to multitask learning. See https://arxiv.org/abs/2408.11104.
The arguments for this class are the same as those of NequIPLightningModule (plus the parameters listed below), but the loss coefficients take on a different meaning: they are now the "b" in the "Ax=b" linear solve (see paper).
Set lsqr=False to use the pseudo-inverse of the gradient matrix to determine the update direction (instead of the default least squares method), as certain devices (e.g. ROCm) may not be able to perform the (underdetermined) least squares solve.
Note
Only ReduceLROnPlateau works with this class. The following warning may be safely ignored:
The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.
Note
LR schedulers won't be able to monitor training metrics when using this class. This should not be a problem, since LR schedulers should usually monitor validation metrics.
Note
To use gradient clipping in training, gradient_clip_val must be provided to this training module and not to Trainer, as automatic gradient clipping is not supported for manual optimization in PyTorch Lightning.
- Parameters:
gradient_clip_val (Union[int, float, None]) – gradient clipping value (default: None, which disables gradient clipping)
gradient_clip_algorithm (Optional[str]) – "value" to clip by value, or "norm" to clip by norm (default: "norm")
lsqr (bool) – whether to use a least squares solve to determine the best update direction (default: True)
norm_eps (float) – small value to avoid division by zero during normalization (default: 1e-8)
- class nequip.train.EMAConFIGLightningModule(**kwargs)
Composition of ConFIGLightningModule and EMALightningModule.
- Parameters:
gradient_clip_val (Union[int, float, None]) – gradient clipping value (default: None, which disables gradient clipping)
gradient_clip_algorithm (Optional[str]) – "value" to clip by value, or "norm" to clip by norm (default: "norm")
lsqr (bool) – whether to use a least squares solve to determine the best update direction (default: True)
norm_eps (float) – small value to avoid division by zero during normalization (default: 1e-8)
ema_decay (float) – decay constant for the exponential moving average (EMA) of model weights (default: 0.999)
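One way to picture this composition is cooperative multiple inheritance, where each class consumes its own keyword arguments and forwards the rest through **kwargs. The sketch below uses hypothetical stand-in classes (BaseModuleSketch and the others), not nequip's source, and the base-class order in EMAConFIGModuleSketch is an assumption; it only shows how a single **kwargs signature can accept every parameter listed above.

```python
from typing import Any, Optional


class BaseModuleSketch:
    """Stand-in for NequIPLightningModule; keeps the remaining arguments."""

    def __init__(self, **kwargs: Any) -> None:
        self.base_kwargs = kwargs  # e.g. model, loss, val_metrics, ...


class EMAModuleSketch(BaseModuleSketch):
    def __init__(self, ema_decay: float = 0.999, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.ema_decay = ema_decay


class ConFIGModuleSketch(BaseModuleSketch):
    def __init__(self, gradient_clip_val: Optional[float] = None,
                 gradient_clip_algorithm: Optional[str] = None,
                 lsqr: bool = True, norm_eps: float = 1e-8, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.lsqr = lsqr
        self.norm_eps = norm_eps


class EMAConFIGModuleSketch(ConFIGModuleSketch, EMAModuleSketch):
    """MRO: ConFIG -> EMA -> Base, so each __init__ strips its own keywords."""


m = EMAConFIGModuleSketch(ema_decay=0.99, lsqr=False, model="placeholder")
print(m.ema_decay, m.lsqr, m.base_kwargs)  # 0.99 False {'model': 'placeholder'}
```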
- class nequip.train.ScheduleFreeLightningModule(optimizer: Dict[str, Any], **kwargs)
NequIP LightningModule using Facebook’s Schedule-Free optimizer.
This module wraps the model’s optimizer in one of Facebook’s Schedule-Free variants. See: https://github.com/facebookresearch/schedule_free
- Parameters:
optimizer (Dict[str, Any]) – dictionary that must include a _target_ key corresponding to one of the Schedule-Free optimizers, along with any other keyword arguments accepted by that Schedule-Free variant.
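The optimizer argument is a dictionary whose _target_ names the Schedule-Free class to build, with the remaining keys passed to it as keyword arguments. The sketch below shows such a dictionary and a hypothetical check_schedule_free helper that validates the _target_. The import path schedulefree.AdamWScheduleFree, the warmup_steps keyword and the list of variant names are assumptions based on the schedulefree package, and nequip's own validation may differ.

```python
from typing import Any, Dict

# Assumed Schedule-Free class names from the schedulefree package; check the
# installed version for the variants it actually provides.
SCHEDULE_FREE_CLASSES = ("AdamWScheduleFree", "SGDScheduleFree", "RAdamScheduleFree")


def check_schedule_free(optimizer: Dict[str, Any]) -> str:
    """Hypothetical check that optimizer['_target_'] names a Schedule-Free class."""
    target = optimizer.get("_target_")
    if not isinstance(target, str):
        raise ValueError("optimizer dict must include a '_target_' string")
    class_name = target.rsplit(".", 1)[-1]
    if class_name not in SCHEDULE_FREE_CLASSES:
        raise ValueError(f"{target!r} is not a Schedule-Free optimizer")
    return class_name


optimizer = {
    "_target_": "schedulefree.AdamWScheduleFree",  # class to instantiate
    "lr": 1e-3,                                    # other kwargs go to that class
    "warmup_steps": 1000,
}
print(check_schedule_free(optimizer))  # AdamWScheduleFree
```

The schedulefree README also notes that its optimizers must be switched with optimizer.train() and optimizer.eval() around training and evaluation; this sketch does not model that step.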