NequIPLightningModule

class nequip.train.NequIPLightningModule(model: Dict, num_datasets: Dict[str, int], optimizer: Dict | None = None, lr_scheduler: Dict | None = None, loss: Dict | None = None, train_metrics: Dict | None = None, val_metrics: Dict | None = None, test_metrics: Dict | None = None, info_dict: Dict | None = None)

LightningModule for training, validating, testing and predicting with models constructed in the NequIP ecosystem.

Data

The NequIPLightningModule supports a single train dataset, but multiple val and test datasets.

Run Types and Metrics

  • For train runs, users must provide loss and val_metrics. The loss is computed on the training dataset to train the model. Each of its metrics must have a corresponding coefficient, and these coefficients are used to form a weighted_sum. This weighted_sum is the loss function minimized over the course of training. val_metrics are computed on the validation dataset(s) for monitoring. Optionally, users may also provide train_metrics to monitor metrics on the training dataset.

  • For val runs, users must provide val_metrics.

  • For test runs, users must provide test_metrics.
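The weighted_sum described for train runs can be illustrated with a short sketch. It assumes a plain (unnormalized) sum of coefficient times metric value; whether NequIP additionally normalizes the coefficients is not stated here. The function name weighted_sum and the metric names E_MSE and F_MSE are hypothetical, not NequIP's API.

```python
from typing import Dict


def weighted_sum(metrics: Dict[str, float], coeffs: Dict[str, float]) -> float:
    """Combine per-metric loss values into one scalar (illustrative sketch only).

    Assumes an unnormalized sum of coeff * value; not NequIP's implementation.
    """
    missing = set(metrics) - set(coeffs)
    if missing:
        # every metric that enters the loss needs a coefficient
        raise ValueError(f"missing coefficients for: {sorted(missing)}")
    return sum(coeffs[name] * value for name, value in metrics.items())


# hypothetical per-batch metric values and loss coefficients
batch_metrics = {"E_MSE": 0.02, "F_MSE": 0.5}
loss_coeffs = {"E_MSE": 1.0, "F_MSE": 1.0}
print(weighted_sum(batch_metrics, loss_coeffs))
```

Raising on a missing coefficient mirrors the requirement above that every loss metric has one.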

Logging Conventions

Logging is performed for the train, val, and test datasets.

During train runs,
  • logging occurs at each batch step and at each epoch,

  • there is only one training set, so no data_idx is used in the logging.

For val and test runs,
  • logging only occurs at each validation or testing epoch, i.e. one pass over the entirety of each validation/testing dataset,

  • there can be multiple validation and testing sets, so a zero-based data_idx index is used in the logging.

Logging Format
  • / is used as a delimiter to exploit the automatic grouping functionality of most loggers. Logged metrics have the form train_{loss|metric}_{step|epoch}/{metric_name} and {val|test}{data_idx}_epoch/{metric_name}, for example train_loss_step/force_MSE, train_metric_epoch/E_MAE, and val0_epoch/F_RMSE.

  • This has implications for how the parameters of the ModelCheckpoint callback are set: if a metric name is used in a checkpoint file name, the / will cause a directory to be created where a file was intended.
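A small sketch of the naming convention and of the checkpoint file name pitfall. The helpers train_key and eval_key are hypothetical and only reproduce the key format described above. Replacing / with _ is one possible workaround, not a documented NequIP option.

```python
from pathlib import PurePosixPath


def train_key(kind: str, when: str, metric: str) -> str:
    # kind is "loss" or "metric"; when is "step" or "epoch"
    return f"train_{kind}_{when}/{metric}"


def eval_key(split: str, data_idx: int, metric: str) -> str:
    # split is "val" or "test"; these runs only log at the epoch level
    return f"{split}{data_idx}_epoch/{metric}"


print(train_key("loss", "step", "force_MSE"))  # train_loss_step/force_MSE
print(eval_key("val", 0, "F_RMSE"))            # val0_epoch/F_RMSE

# Interpolating such a key into a checkpoint file name introduces a subdirectory:
name = f"best-{eval_key('val', 0, 'F_RMSE')}.ckpt"
print(PurePosixPath(name).parts)  # ('best-val0_epoch', 'F_RMSE.ckpt')

# One workaround: replace the delimiter before using the key in a file name.
safe_name = name.replace("/", "_")
print(safe_name)  # best-val0_epoch_F_RMSE.ckpt
```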

class nequip.train.EMALightningModule(ema_decay: float = 0.999, **kwargs)

An exponential moving average (EMA) of the model weights is maintained, and validation and test metrics are computed with the EMA weights. If EMA is used, models loaded from checkpoint files (except during restarts) will always be the model with EMA weights. Specifically, whenever a model is prepared for inference, the EMA weights are used when available; this includes loading a model for inference, nequip-compile, and nequip-package.

Note: EMA requires check_val_every_n_epoch to be 1 (the default).

Parameters:

ema_decay (float) – decay constant for the exponential moving average (EMA) of model weights (default 0.999)
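The role of ema_decay can be seen in the standard EMA update, ema ← decay · ema + (1 − decay) · weights. The sketch below applies it to a plain list of floats. It assumes this textbook form; NequIP's actual implementation (for example, any warmup or update frequency) may differ, and ema_update is a hypothetical name.

```python
from typing import List


def ema_update(ema: List[float], weights: List[float], decay: float = 0.999) -> List[float]:
    """One standard EMA step: ema <- decay * ema + (1 - decay) * weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


weights = [1.0, -2.0]
ema = list(weights)  # assume the EMA starts as a copy of the initial weights
for _ in range(3):
    weights = [w - 0.1 for w in weights]  # stand-in for an optimizer step
    ema = ema_update(ema, weights, decay=0.9)
print(weights, ema)  # the EMA lags behind the raw weights
```

A decay close to 1 (such as the default 0.999) makes the averaged weights change slowly, which smooths out step-to-step noise in the raw weights.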

class nequip.train.ConFIGLightningModule(gradient_clip_val: float | None = None, gradient_clip_algorithm: str | None = None, lsqr: bool = True, norm_eps: float = 1e-08, **kwargs)

Conflict-free inverse gradient (ConFIG) approach to multitask learning. See https://arxiv.org/abs/2408.11104.

This class accepts the same arguments as NequIPLightningModule, in addition to the parameters listed below, but the loss coefficients take on a different meaning: they are now the “b” in the “Ax=b” linear solve (see paper).

Set lsqr=False to use the pseudo-inverse of the gradient matrix to determine the update direction instead of the default least squares solve, since the (underdetermined) least squares solve may not be supported on some platforms (e.g. ROCm).
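A NumPy sketch of the ConFIG update rule from the linked paper. Rows of A are unit-normalized per-task gradients, b holds the loss coefficients (all ones by default), and the minimum-norm solution x of Ax=b gives the update direction. The final scaling by the summed projections follows the paper. This is an illustration of the technique, not NequIP's code, which operates on PyTorch gradients; config_update is a hypothetical name.

```python
from typing import Optional

import numpy as np


def config_update(grads: np.ndarray, b: Optional[np.ndarray] = None,
                  lsqr: bool = True, norm_eps: float = 1e-8) -> np.ndarray:
    """Combine per-task gradients (rows of `grads`) into one conflict-free update."""
    m = grads.shape[0]
    if b is None:
        b = np.ones(m)  # equal weighting; the loss coefficients would go here
    # A: each row is a unit-normalized task gradient
    A = grads / (np.linalg.norm(grads, axis=1, keepdims=True) + norm_eps)
    # Underdetermined solve of A x = b (few tasks, many parameters)
    if lsqr:
        x = np.linalg.lstsq(A, b, rcond=None)[0]  # minimum-norm least-squares solution
    else:
        x = np.linalg.pinv(A) @ b  # pseudo-inverse alternative
    g_u = x / (np.linalg.norm(x) + norm_eps)  # unit update direction
    # Scale by the summed projection of all task gradients onto g_u
    return float(np.sum(grads @ g_u)) * g_u


g = np.array([[1.0, 0.0, 0.0],
              [-0.5, 1.0, 0.0]])  # two conflicting task gradients (g1 . g2 < 0)
update = config_update(g)
print(g @ update)  # both entries positive: neither task's loss increases to first order
```

For a full-row-rank A, the least squares and pseudo-inverse branches give the same minimum-norm x, which is why lsqr=False works as a drop-in fallback.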

Note

Only ReduceLROnPlateau works with this class. The following warning may be safely ignored: “The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.”

Note

With this class, LR schedulers cannot monitor training metrics. This should not be a problem, since LR schedulers usually monitor validation metrics.

Note

To use gradient clipping during training, gradient_clip_val must be passed to this training module rather than to the Trainer, because PyTorch Lightning does not support automatic gradient clipping with manual optimization.

Parameters:
  • gradient_clip_val (Union[int, float, None]) – gradient clipping value (default: None, which disables gradient clipping)

  • gradient_clip_algorithm (Optional[str]) – "value" to clip by value, or "norm" to clip by norm (default: "norm")

  • lsqr (bool) – whether to use least squares solve for determining best update direction (default: True)

  • norm_eps (float) – small value to avoid division by zero during normalization (default: 1e-8)
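The two gradient_clip_algorithm options can be illustrated with NumPy arrays standing in for parameter gradients. The sketch follows the usual semantics of torch.nn.utils.clip_grad_value_ and clip_grad_norm_ (the latter also adds a small epsilon to the total norm). It is an assumption that NequIP's manual clipping behaves the same way, and clip_gradients is a hypothetical helper.

```python
from typing import List, Optional

import numpy as np


def clip_gradients(grads: List[np.ndarray], clip_val: Optional[float],
                   algorithm: str = "norm") -> List[np.ndarray]:
    """Return clipped copies of a list of gradient arrays."""
    if clip_val is None:
        return [g.copy() for g in grads]  # None disables clipping
    if algorithm == "value":
        # clamp every element into [-clip_val, clip_val]
        return [np.clip(g, -clip_val, clip_val) for g in grads]
    if algorithm == "norm":
        # rescale all gradients jointly so their global L2 norm is at most clip_val
        total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
        scale = min(1.0, clip_val / (total + 1e-6))
        return [g * scale for g in grads]
    raise ValueError(f"unknown gradient_clip_algorithm: {algorithm!r}")


grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # global norm 5
print(clip_gradients(grads, 1.0, "value"))  # elementwise clamp
print(clip_gradients(grads, 1.0, "norm"))   # rescaled to global norm ~1
```

Clipping by norm preserves the direction of the overall gradient, while clipping by value can change it.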

class nequip.train.EMAConFIGLightningModule(**kwargs)

Composition of ConFIGLightningModule and EMALightningModule, combining ConFIG multitask gradient updates with an EMA of the model weights.

Parameters:
  • gradient_clip_val (Union[int, float, None]) – gradient clipping value (default: None, which disables gradient clipping)

  • gradient_clip_algorithm (Optional[str]) – "value" to clip by value, or "norm" to clip by norm (default: "norm")

  • lsqr (bool) – whether to use least squares solve for determining best update direction (default: True)

  • norm_eps (float) – small value to avoid division by zero during normalization (default: 1e-8)

  • ema_decay (float) – decay constant for the exponential moving average (EMA) of model weights (default 0.999)
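Because the signature only takes **kwargs, the parameters of both parents must be routed to the right place. The toy classes below illustrate one way this can work, using cooperative multiple inheritance in which each parent consumes its own keyword arguments and forwards the rest. All class names are hypothetical, gradient clipping arguments are omitted for brevity, and whether NequIP composes the two modules this way is an assumption.

```python
class ToyBase:
    """Stand-in for NequIPLightningModule (hypothetical)."""

    def __init__(self, **kwargs):
        self.base_kwargs = kwargs  # shared arguments such as model, loss, ...


class ToyEMA(ToyBase):
    def __init__(self, ema_decay: float = 0.999, **kwargs):
        super().__init__(**kwargs)  # forward everything except ema_decay
        self.ema_decay = ema_decay


class ToyConFIG(ToyBase):
    def __init__(self, lsqr: bool = True, norm_eps: float = 1e-8, **kwargs):
        super().__init__(**kwargs)  # forward everything except ConFIG arguments
        self.lsqr = lsqr
        self.norm_eps = norm_eps


class ToyEMAConFIG(ToyConFIG, ToyEMA):
    """MRO: ToyEMAConFIG -> ToyConFIG -> ToyEMA -> ToyBase."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


m = ToyEMAConFIG(ema_decay=0.99, lsqr=False, model={"name": "toy"})
print(m.ema_decay, m.lsqr, m.base_kwargs)  # 0.99 False {'model': {'name': 'toy'}}
```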

class nequip.train.ScheduleFreeLightningModule(optimizer: Dict[str, Any], **kwargs)

NequIP LightningModule using Facebook’s Schedule-Free optimizer.

This module wraps the model’s optimizer in one of Facebook’s Schedule-Free variants. See: https://github.com/facebookresearch/schedule_free

Parameters:

optimizer (Dict[str, Any]) – dictionary that must include a _target_ corresponding to one of the Schedule-Free optimizers, along with any other keyword arguments accepted by that Schedule-Free variant.
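A sketch of what the optimizer dictionary might look like, written as a Python dict. It assumes the Hydra-style _target_ convention (a dotted import path) and uses schedulefree.AdamWScheduleFree, one of the classes in the schedulefree package; which variants NequIP accepts, and the lr value, are illustrative assumptions. The helper split_target is hypothetical and only separates the target from the remaining keyword arguments.

```python
from typing import Any, Dict, Tuple

optimizer: Dict[str, Any] = {
    "_target_": "schedulefree.AdamWScheduleFree",  # assumed Schedule-Free variant
    "lr": 0.01,  # forwarded to the optimizer as a keyword argument
}


def split_target(cfg: Dict[str, Any]) -> Tuple[str, str, Dict[str, Any]]:
    """Split a _target_-style config into (module, class name, kwargs)."""
    if "_target_" not in cfg:
        raise KeyError("optimizer config must include a _target_")
    module, _, name = cfg["_target_"].rpartition(".")
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return module, name, kwargs


print(split_target(optimizer))  # ('schedulefree', 'AdamWScheduleFree', {'lr': 0.01})
```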