nequip.train.callbacks

class nequip.train.callbacks.SoftAdapt(beta: float, interval: str, frequency: int, eps: float = 1e-08)

Adaptively modify loss coefficients over a training run using the SoftAdapt scheme.

Note that this implementation differs from the original SoftAdapt scheme, which tends toward 1:1:1 loss coefficient ratios: here, the coefficient updates are weighted by the input loss coefficients (see PR #515).
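To make this difference concrete, here is a minimal sketch of a coefficient-weighted SoftAdapt update. It is not NequIP's implementation. The function name weighted_softadapt, the use of each loss's relative change as its "rate", and the final normalization are assumptions for illustration; beta and eps play the same roles as the callback parameters of the same names. With unit input coefficients the update reduces to a plain softmax, whose fixed point (all losses changing at the same rate) is a 1:1:1 ratio. Weighting by the input coefficients moves that fixed point to the input ratio instead.

```python
import math
from typing import Dict


def weighted_softadapt(
    prev_losses: Dict[str, float],
    curr_losses: Dict[str, float],
    input_coeffs: Dict[str, float],
    beta: float = 1.1,
    eps: float = 1e-8,
) -> Dict[str, float]:
    """Hypothetical sketch of a coefficient-weighted SoftAdapt update (not NequIP's code)."""
    # Assumed rate definition: relative change of each loss since the last update.
    # eps keeps the division safe if a previous loss value is exactly zero.
    rates = {
        name: (curr_losses[name] - prev_losses[name]) / (abs(prev_losses[name]) + eps)
        for name in curr_losses
    }
    # Softmax is shift-invariant, so subtracting the largest rate avoids exp() overflow.
    shift = max(rates.values())
    # Each softmax term is scaled by the user-supplied input coefficient.
    scores = {
        name: input_coeffs[name] * math.exp(beta * (rate - shift))
        for name, rate in rates.items()
    }
    total = sum(scores.values())
    # Normalize so the updated coefficients sum to 1.
    return {name: score / total for name, score in scores.items()}


coeffs = {"per_atom_energy_mse": 1.0, "forces_mse": 4.0}
prev = {"per_atom_energy_mse": 1.0, "forces_mse": 1.0}

# Both losses fall at the same rate: the 1:4 input ratio is reproduced (0.2 and 0.8).
print(weighted_softadapt(prev, {"per_atom_energy_mse": 0.5, "forces_mse": 0.5}, coeffs))

# Forces decrease more slowly, so their share rises above the 0.8 input share.
print(weighted_softadapt(prev, {"per_atom_energy_mse": 0.5, "forces_mse": 0.9}, coeffs))
```

A loss that stalls or grows receives more weight, which is the core SoftAdapt idea; beta controls how sharply the softmax reacts to differences in rates.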

Warning

SoftAdapt requires that every component of the loss function contributes to the loss, i.e., that its coeff in the MetricsManager is not None.

Warning

Restarting training with SoftAdapt using a differently configured loss function is dangerous, because the SoftAdapt state loaded from the checkpoint will be ill-suited to the new loss function.

Example usage in config where the loss coefficients are updated every 5 epochs:

callbacks:
  - _target_: nequip.train.callbacks.SoftAdapt
    beta: 1.1
    interval: epoch
    frequency: 5
Parameters:
  • beta (float) – SoftAdapt hyperparameter (see the original SoftAdapt paper)

  • interval (str) – batch or epoch

  • frequency (int) – number of intervals between loss coefficient updates

  • eps (float) – small value to avoid division by zero

class nequip.train.callbacks.LossCoefficientScheduler(schedule: Dict[int, Dict[str, float]])

Schedule loss coefficients during training.

The LossCoefficientScheduler takes a single argument schedule, which is a Dict[int, Dict[str, float]] where the keys are the epochs at which the loss coefficients change and the values are dictionaries mapping loss metric names (corresponding to how the loss was configured) to their coefficients.

When the trainer’s epoch counter matches one of the keys, the loss coefficients are set to the corresponding values.

The coefficients are normalized to sum to 1, in line with the MetricsManager convention.

Example usage in config where there are two loss contributions:

callbacks:
  - _target_: nequip.train.callbacks.LossCoefficientScheduler
    schedule:
      100:
        per_atom_energy_mse: 1.0
        forces_mse: 5.0
      200:
        per_atom_energy_mse: 5.0
        forces_mse: 1.0
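The epoch matching and normalization described above can be sketched as follows. maybe_update_coeffs is a hypothetical helper, not part of NequIP, and the starting 1:1 coefficients are made up for illustration; only the schedule itself is taken from the config example.

```python
from typing import Dict


def maybe_update_coeffs(
    epoch: int,
    schedule: Dict[int, Dict[str, float]],
    current: Dict[str, float],
) -> Dict[str, float]:
    """Hypothetical helper: coefficients in effect after checking this epoch."""
    if epoch not in schedule:
        # No scheduled change at this epoch, so the current coefficients persist.
        return current
    new = schedule[epoch]
    total = sum(new.values())
    # Normalize to sum to 1, following the MetricsManager convention.
    return {name: value / total for name, value in new.items()}


# Same schedule as the config example above.
schedule = {
    100: {"per_atom_energy_mse": 1.0, "forces_mse": 5.0},
    200: {"per_atom_energy_mse": 5.0, "forces_mse": 1.0},
}
coeffs = {"per_atom_energy_mse": 0.5, "forces_mse": 0.5}  # made-up starting values
for epoch in range(150):
    coeffs = maybe_update_coeffs(epoch, schedule, coeffs)
print(coeffs)  # energy 1/6, forces 5/6 after the epoch-100 change
```

Because only matching epochs trigger a change, the 1:5 ratio set at epoch 100 persists until epoch 200.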
Parameters:
  • schedule (Dict[int, Dict[str, float]]) – map of epoch to loss coefficient dictionary

class nequip.train.callbacks.LinearLossCoefficientScheduler(final_coeffs: Dict[str, float], transition_epochs: int, start_epoch: int = 0)

Linearly schedule loss coefficients during training.

The LinearLossCoefficientScheduler linearly interpolates loss coefficients from the current values at start_epoch to the specified final_coeffs over transition_epochs epochs.

This callback is stateful and captures the loss coefficients at start_epoch for interpolation.

Note

This callback is currently in beta testing. Please report any unexpected behavior or issues.

Example usage in config to transition to energy:force:stress = 1:1:1 over 200 epochs starting at epoch 100 (from whatever coefficients they were originally at):

callbacks:
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 1.0
      stress_mse: 1.0
    start_epoch: 100
    transition_epochs: 200
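A minimal sketch of the interpolation, assuming the coefficients follow the straight line c(e) = (1 - t) c_start + t c_final with t = (e - start_epoch) / transition_epochs, and that the scheduler leaves the coefficients untouched outside [start_epoch, start_epoch + transition_epochs]. LinearCoeffSketch is a hypothetical stand-in, not the callback's actual code, and it does not renormalize the interpolated values. The initial 1:5:1 coefficients are made up.

```python
from typing import Dict, Optional


class LinearCoeffSketch:
    """Hypothetical stand-in for linear coefficient scheduling (not NequIP's code)."""

    def __init__(self, final_coeffs: Dict[str, float], transition_epochs: int, start_epoch: int = 0):
        if transition_epochs <= 0:
            raise ValueError("transition_epochs must be positive")
        self.final_coeffs = final_coeffs
        self.transition_epochs = transition_epochs
        self.start_epoch = start_epoch
        # State: the coefficients captured when the transition begins.
        self.initial_coeffs: Optional[Dict[str, float]] = None

    def coeffs_at(self, epoch: int, current: Dict[str, float]) -> Dict[str, float]:
        end_epoch = self.start_epoch + self.transition_epochs
        if epoch < self.start_epoch or epoch > end_epoch:
            # Outside this scheduler's window: leave the coefficients alone.
            return current
        if self.initial_coeffs is None:
            # Capture whatever the coefficients are when the transition starts.
            self.initial_coeffs = dict(current)
        # Fraction of the transition completed: 0 at start_epoch, 1 at end_epoch.
        t = (epoch - self.start_epoch) / self.transition_epochs
        return {
            name: (1.0 - t) * self.initial_coeffs[name] + t * target
            for name, target in self.final_coeffs.items()
        }


# Mirrors the config example: move to 1:1:1 over 200 epochs starting at epoch 100.
sched = LinearCoeffSketch(
    {"per_atom_energy_mse": 1.0, "forces_mse": 1.0, "stress_mse": 1.0},
    transition_epochs=200,
    start_epoch=100,
)
coeffs = {"per_atom_energy_mse": 1.0, "forces_mse": 5.0, "stress_mse": 1.0}
for epoch in range(0, 201):
    coeffs = sched.coeffs_at(epoch, coeffs)
print(coeffs["forces_mse"])  # halfway through the transition: 3.0
```

Leaving coefficients alone once the window closes is what would let a later scheduler take over in a multi-stage setup.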

Multiple LinearLossCoefficientScheduler callbacks can be composed for multi-stage scheduling:

callbacks:
  # First transition: current -> 1:5:1 from epoch 50-150
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 5.0
      stress_mse: 1.0
    start_epoch: 50
    transition_epochs: 100
  # Second transition: current -> 1:1:1 from epoch 200-400
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 1.0
      stress_mse: 1.0
    start_epoch: 200
    transition_epochs: 200

Warning

When composing multiple schedulers, ensure their epoch ranges do not overlap. No safety checks are performed to validate scheduler composition. Additionally, callback execution order is not guaranteed and training protocols should not rely on specific callback execution orders.
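Since no safety checks are performed, a user-side check along these lines can catch overlapping windows before training starts. find_overlaps is hypothetical and not part of NequIP. It treats each scheduler's window as the closed range [start_epoch, start_epoch + transition_epochs], a conservative assumption about exactly where a transition begins and ends.

```python
from typing import List, Tuple


def find_overlaps(windows: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Hypothetical user-side check (NequIP performs no such validation).

    Each window is (start_epoch, transition_epochs). Returns index pairs of
    schedulers whose closed epoch ranges [start, start + transition] intersect.
    """
    spans = [(start, start + length) for start, length in windows]
    overlaps = []
    for i in range(len(spans)):
        for j in range(i + 1, len(spans)):
            (a_start, a_end), (b_start, b_end) = spans[i], spans[j]
            # Two closed ranges intersect when each starts no later than the other ends.
            if a_start <= b_end and b_start <= a_end:
                overlaps.append((i, j))
    return overlaps


# The two-stage example above: epochs 50-150 and 200-400 do not overlap.
print(find_overlaps([(50, 100), (200, 200)]))  # []
# A second transition starting at epoch 120 would collide with the first.
print(find_overlaps([(50, 100), (120, 50)]))  # [(0, 1)]
```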

Parameters:
  • final_coeffs (Dict[str, float]) – target loss coefficient dictionary

  • start_epoch (int) – epoch at which to start the transition (default: 0)

  • transition_epochs (int) – number of epochs over which to transition

class nequip.train.callbacks.LossCoefficientMonitor(interval: str, frequency: int)

Monitor and log loss coefficients during training.

Example usage in config to log loss coefficients every 5 epochs:

callbacks:
  - _target_: nequip.train.callbacks.LossCoefficientMonitor
    interval: epoch
    frequency: 5
Parameters:
  • interval (str) – batch or epoch

  • frequency (int) – number of intervals between each instance of loss coefficient logging

class nequip.train.callbacks.TestTimeXYZFileWriter(out_file: str, output_fields_from_original_dataset: List[str] | None = [], extra_fields: List[str] = [], chemical_symbols: List[str] | None = None)

Writes model outputs to an xyz file.

Users must provide an out_file without an extension. The actual output file takes the form {out_file}_dataset{idx}.xyz, where idx is the dataset index (0 for a single test set; with multiple test sets, each test set gets its own index).
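The naming rule can be sketched as below. xyz_output_paths is a hypothetical helper that mirrors the documented {out_file}_dataset{idx}.xyz pattern; it assumes test sets are indexed consecutively from 0 and is not part of NequIP.

```python
import os
from typing import List


def xyz_output_paths(out_file: str, num_test_sets: int) -> List[str]:
    """Hypothetical helper mirroring the documented {out_file}_dataset{idx}.xyz pattern."""
    # Documented contract: out_file must not already carry an xyz extension.
    _, ext = os.path.splitext(out_file)
    if ext.lower() in (".xyz", ".extxyz"):
        raise ValueError(f"out_file must not include an extension, got {out_file!r}")
    # Assumption: one file per test set, indexed from 0.
    return [f"{out_file}_dataset{idx}.xyz" for idx in range(num_test_sets)]


print(xyz_output_paths("outputs/test", 2))
# ['outputs/test_dataset0.xyz', 'outputs/test_dataset1.xyz']
```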

To simplify analysis, users may include fields from the original dataset in the xyz file by providing output_fields_from_original_dataset. These fields are written with the prefix original_dataset_.

To obtain correct chemical species information, users must provide chemical_symbols in an order consistent with the model’s type_names.
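Why the order matters can be sketched with two hypothetical helpers: one maps per-atom type indices to symbols by position in chemical_symbols, and one applies the original_dataset_ prefix. The assumption that symbols are looked up purely by type index, as well as the water molecule and energy value, are for illustration only and do not come from NequIP's code.

```python
from typing import Dict, List


def symbols_from_types(atom_types: List[int], chemical_symbols: List[str]) -> List[str]:
    # Assumption: entry i of chemical_symbols names the species of model type i
    # (i.e. it follows the same order as the model's type_names).
    return [chemical_symbols[t] for t in atom_types]


def prefix_original_fields(fields: Dict[str, object]) -> Dict[str, object]:
    # Reference values copied from the dataset get the original_dataset_ prefix.
    return {f"original_dataset_{name}": value for name, value in fields.items()}


type_names = ["H", "O"]  # hypothetical model type order
atom_types = [1, 0, 0]   # per-atom type indices for a water molecule
print(symbols_from_types(atom_types, type_names))   # ['O', 'H', 'H']
print(symbols_from_types(atom_types, ["O", "H"]))   # ['H', 'O', 'O'], wrong species
print(prefix_original_fields({"total_energy": -14.2}))  # made-up energy value
```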

Example usage in config to write predictions and original dataset total_energy and forces to an xyz file:

callbacks:
  - _target_: nequip.train.callbacks.TestTimeXYZFileWriter
    out_file: ${hydra:runtime.output_dir}/test
    output_fields_from_original_dataset: [total_energy, forces]
    chemical_symbols: ${chemical_symbols}
Parameters:
  • out_file (str) – path to output file (must NOT contain .xyz or .extxyz extension)

  • output_fields_from_original_dataset (List[str]) – values from the original dataset to save in the out_file

  • extra_fields (List[str]) – extra fields to save in addition to ASE’s default fields

  • chemical_symbols (List[str]) – chemical symbols in the same order as the model’s type_names

class nequip.train.callbacks.WandbWatch(log_freq: int, log: str = 'gradients', log_graph: bool = False)

Monitor and log weights and gradients during training with PyTorch Lightning’s WandbLogger.

This class provides a way to call wandb’s watch function (https://docs.wandb.ai/ref/python/watch/) when using a WandbLogger, so that weights and gradients can be monitored over the course of training.

Parameters:
  • log_freq (int) – frequency (in batches) to log gradients and parameters

  • log (str) – specifies whether to log "gradients", "parameters", or "all"

  • log_graph (bool) – whether to log the model’s computational graph

class nequip.train.callbacks.TF32Scheduler(schedule: Dict[int, bool])

Schedule TF32 precision during training.

The TF32Scheduler takes a single argument schedule, which is a Dict[int, bool] where the keys are the epochs at which TF32 changes and the values are:

  • True: Enable TF32 (faster but less precise)

  • False: Disable TF32 (slower but more precise)

Basic example to enable TF32 for all training:

callbacks:
  - _target_: nequip.train.callbacks.TF32Scheduler
    schedule:
      0: true      # Enable TF32 throughout training

Dynamic scheduling example for two-stage training:

callbacks:
  - _target_: nequip.train.callbacks.TF32Scheduler
    schedule:
      0: true      # Start with TF32 enabled
      100: false   # Disable TF32 at epoch 100
      200: true    # Re-enable TF32 at epoch 200

Note

The schedule must include an entry for epoch 0; this initial setting is applied at the beginning of training.
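The sketch below shows how such a schedule resolves to an on/off setting at any epoch, assuming the most recent entry at or before that epoch wins. tf32_enabled_at is a hypothetical helper, not part of NequIP. In PyTorch, TF32 is commonly toggled through torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32; which switch TF32Scheduler actually flips is an assumption not covered here, so the sketch stays framework-free.

```python
from typing import Dict


def tf32_enabled_at(epoch: int, schedule: Dict[int, bool]) -> bool:
    """Hypothetical helper: TF32 setting in effect at a given epoch."""
    if 0 not in schedule:
        raise ValueError("schedule must include an entry for epoch 0")
    # Assumption: the most recent scheduled change at or before this epoch wins.
    last_change = max(e for e in schedule if e <= epoch)
    return schedule[last_change]


# The two-stage example above.
schedule = {0: True, 100: False, 200: True}
print([tf32_enabled_at(e, schedule) for e in (0, 50, 100, 150, 200, 250)])
# [True, True, False, False, True, True]
```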

Note

This callback is currently in beta testing. Please report any unexpected behavior or issues.

Parameters:
  • schedule (Dict[int, bool]) – map of epoch to TF32 enabled/disabled