Config

The config file has five main sections – run, data, trainer, training_module, and global_options. These top-level config entries must always be present. Before going into what each section entails, users are advised to take note of OmegaConf’s variable interpolation utilities, which can be a useful tool for managing the configuration of training and testing runs.
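
As a point of orientation, a config skeleton with the five required top-level sections might look like the sketch below. The section bodies here are illustrative placeholders rather than a working configuration, and the ${...} syntax is OmegaConf variable interpolation, which lets one value reference another defined elsewhere in the config.

# Skeleton of the five required top-level sections (bodies are placeholders).
run: [train, test]

data:
  # ... DataModule configuration (see the data section below) ...

trainer:
  # ... lightning.Trainer configuration (see the trainer section below) ...

training_module:
  # ... NequIPLightningModule configuration (see the training_module section below) ...

global_options:
  seed: 123
  allow_tf32: false

# OmegaConf interpolation example: anywhere in this config, ${global_options.seed}
# resolves to 123, which helps keep repeated values consistent.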

run

run allows users to specify the run types, of which there are four – train (which requires a train dataset and at least one val dataset), val (which requires val dataset(s)), test (which requires test dataset(s)), and predict (which requires predict dataset(s)).

Users can specify one or more of these run types in the config. A common usage pattern is to perform training, followed immediately by testing (using the best model checkpoint).

run: [train, test]

If one seeks to check how the untrained model performs on the validation and test datasets before training, and then assess the trained model’s performance, one can use the following.

run: [val, test, train, val, test]

NOTE: the test run type replaces the role that nequip-evaluate had in pre-0.7.0 versions of the nequip package.

data

data specifies the DataModule object to be used. Users are directed to the API page of nequip.data.datamodule for the DataModule classes supported by nequip. Custom datamodules that subclass nequip.data.datamodule.NequIPDataModule can also be used.
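
As an illustrative sketch only, assuming the Hydra-style _target_ field is used to name the class to instantiate, a data section could look roughly like the following. The argument names belong to the chosen DataModule class and should be taken from its API page; the seed argument shown here is a hypothetical example.

data:
  _target_: nequip.data.datamodule.NequIPDataModule  # or a subclass / custom datamodule
  seed: ${global_options.seed}   # hypothetical argument name, reusing the global seed via interpolation
  # ... dataset and dataloader arguments of the chosen DataModule class go here ...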

trainer

The trainer section is used to instantiate a lightning.Trainer object. To understand how to configure it, users are directed to lightning.Trainer’s documentation page; the sections on trainer flags and its API are especially important.

NOTE: it is in the lightning.Trainer that users can specify callbacks used to influence the course of training. This includes the very important ModelCheckpoint callback, which should be configured to save checkpoint files in the way the user prefers. nequip’s own callbacks can also be used here.

NOTE: it is also here that users specify the logger, e.g. TensorBoard or Weights & Biases.
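
The following is only a hedged sketch of a trainer section. It assumes the Hydra-style _target_ field for instantiation; the class paths (lightning.Trainer, lightning.pytorch.callbacks.ModelCheckpoint, lightning.pytorch.loggers.TensorBoardLogger) are real lightning classes, but the flag values, output paths, and monitored metric name are placeholders (the monitor value in particular must match a metric name actually logged by the training_module).

trainer:
  _target_: lightning.Trainer          # assumed Hydra-style instantiation of the Trainer
  max_epochs: 100                      # illustrative lightning.Trainer flag
  check_val_every_n_epoch: 1           # illustrative lightning.Trainer flag
  callbacks:
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: results/checkpoints     # illustrative output location
      monitor: val_loss                # placeholder; must match a metric logged by the training_module
      save_last: true
  logger:
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: results/logs             # illustrative output location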

training_module

training_module defines the NequIPLightningModule (or one of its subclasses). Users are directed to its API page to learn how to configure it. It is here that the following are defined (a configuration sketch follows the list):

  • the model

  • the loss and metrics

  • the optimizer and lr_scheduler
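
The following is a rough, hedged sketch of how these pieces could be laid out, again assuming Hydra-style _target_ instantiation. The _target_ path for the training module and the nesting of the loss, metrics, optimizer, and lr_scheduler entries are illustrative; the NequIPLightningModule API page is the authoritative reference for the exact argument names.

training_module:
  _target_: nequip.train.NequIPLightningModule   # illustrative path; subclasses can be used instead
  model:
    # ... model configuration ...
  loss:
    # ... loss term configuration ...
  metrics:
    # ... metrics configuration ...
  optimizer:
    _target_: torch.optim.Adam    # torch optimizer; learning rate shown for illustration
    lr: 0.01
  lr_scheduler:
    # ... learning rate scheduler configuration ...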

global_options

For now, global_options is used to specify the following (a brief sketch follows the list):

  • seed, the global seed (in addition to the data seed and model seed)

  • allow_tf32, which controls whether TensorFloat-32 is used
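
A minimal sketch of this section, with illustrative values:

global_options:
  seed: 123          # global seed
  allow_tf32: false  # do not use TensorFloat-32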