nequip.train.callbacks
- class nequip.train.callbacks.SoftAdapt(beta: float, interval: str, frequency: int, eps: float = 1e-08)
Adaptively modify loss coefficients over a training run using the SoftAdapt scheme.
Note that the implementation here differs from the original SoftAdapt scheme (which tends toward 1:1:1 loss coefficient ratios) in that the coefficient updates are weighted by the input loss coefficients (see PR #515).
Warning
SoftAdapt requires that every component of the loss function contributes to the loss, i.e., that its `coeff` in the `MetricsManager` is not `None`.
Warning
Restarting training with SoftAdapt while using a differently configured loss function is dangerous, because SoftAdapt's state loaded from the checkpoint will be ill-suited to the new loss function.
Example usage in config where the loss coefficients are updated every 5 epochs:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.SoftAdapt
    beta: 1.1
    interval: epoch
    frequency: 5
```
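To make the weighting idea concrete, the following is a minimal, hypothetical sketch of a SoftAdapt-style update in plain Python. It is not the NequIP implementation: the function name `softadapt_coefficients`, the use of relative loss changes, and the exact way the input coefficients enter the update are assumptions for illustration. It computes each loss component's rate of change between two measurements, applies a softmax scaled by `beta`, weights the result by the input loss coefficients (the deviation from the original scheme noted above), and renormalizes so the coefficients sum to 1.

```python
import math
from typing import Dict


def softadapt_coefficients(
    prev_losses: Dict[str, float],
    curr_losses: Dict[str, float],
    input_coeffs: Dict[str, float],
    beta: float = 1.1,
    eps: float = 1e-8,
) -> Dict[str, float]:
    """Hypothetical SoftAdapt-style update (illustrative, not NequIP's code)."""
    # Relative rate of change of each loss component; eps guards against division by zero.
    rates = {
        name: (curr_losses[name] - prev_losses[name]) / (abs(prev_losses[name]) + eps)
        for name in curr_losses
    }
    # Softmax over beta * rate; subtracting the max keeps exp() numerically stable.
    shift = max(beta * r for r in rates.values())
    soft = {name: math.exp(beta * r - shift) for name, r in rates.items()}
    # Weight by the input coefficients (assumed form of the weighted variant).
    weighted = {name: input_coeffs[name] * soft[name] for name in soft}
    # Normalize so the coefficients sum to 1.
    total = sum(weighted.values())
    return {name: w / total for name, w in weighted.items()}


# The forces loss decreases more slowly than the energy loss, so it gains weight.
new = softadapt_coefficients(
    prev_losses={"per_atom_energy_mse": 1.0, "forces_mse": 1.0},
    curr_losses={"per_atom_energy_mse": 0.5, "forces_mse": 0.9},
    input_coeffs={"per_atom_energy_mse": 1.0, "forces_mse": 1.0},
)
print(new)
```

With a positive `beta`, components whose loss is improving slowly (or getting worse) receive larger coefficients; when all components change at the same rate, the sketch simply returns the normalized input coefficients.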
- class nequip.train.callbacks.LossCoefficientScheduler(schedule: Dict[int, Dict[str, float]])
Schedule loss coefficients during training.
The `LossCoefficientScheduler` takes a single argument, `schedule`, which is a `Dict[int, Dict[str, float]]` where the keys are the epochs at which the loss coefficients change and the values are dictionaries mapping loss metric names (corresponding to how the loss was configured) to their coefficients.
When the trainer's epoch counter matches one of the keys, the loss coefficients are set to the corresponding values (the coefficients for each loss term).
The coefficients will be normalized to sum to 1, in line with the convention of `MetricsManager`.
Example usage in config where there are two loss contributions:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.LossCoefficientScheduler
    schedule:
      100:
        per_atom_energy_mse: 1.0
        forces_mse: 5.0
      200:
        per_atom_energy_mse: 5.0
        forces_mse: 1.0
```
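The schedule lookup and the sum-to-1 normalization can be illustrated with a small helper. `coeffs_for_epoch` is a hypothetical name, not a NequIP function: it returns the normalized coefficients when the epoch matches a schedule key and `None` otherwise (meaning the current coefficients are kept). For the config above, the 1.0:5.0 ratio at epoch 100 becomes 1/6 and 5/6 after normalization.

```python
from typing import Dict, Optional


def coeffs_for_epoch(
    schedule: Dict[int, Dict[str, float]], epoch: int
) -> Optional[Dict[str, float]]:
    """Hypothetical helper: new (normalized) loss coefficients at `epoch`, or None."""
    if epoch not in schedule:
        # No schedule entry for this epoch: keep the current coefficients.
        return None
    raw = schedule[epoch]
    total = sum(raw.values())
    # Normalize so the coefficients sum to 1, following the MetricsManager convention.
    return {name: c / total for name, c in raw.items()}


schedule = {
    100: {"per_atom_energy_mse": 1.0, "forces_mse": 5.0},
    200: {"per_atom_energy_mse": 5.0, "forces_mse": 1.0},
}
print(coeffs_for_epoch(schedule, 100))  # energy ~0.167, forces ~0.833
print(coeffs_for_epoch(schedule, 150))  # None
```

Because of the normalization, only the ratios in each schedule entry matter: `{1.0, 5.0}` and `{2.0, 10.0}` produce the same coefficients.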
- class nequip.train.callbacks.LinearLossCoefficientScheduler(final_coeffs: Dict[str, float], transition_epochs: int, start_epoch: int = 0)
Linearly schedule loss coefficients during training.
The `LinearLossCoefficientScheduler` linearly interpolates loss coefficients from their current values at `start_epoch` to the specified `final_coeffs` over `transition_epochs` epochs.
This callback is stateful and captures the loss coefficients at `start_epoch` for interpolation.
Note
This callback is currently in beta testing. Please report any unexpected behavior or issues.
Example usage in config to transition to energy:force:stress = 1:1:1 over 200 epochs starting at epoch 100 (from whatever values the coefficients have at epoch 100):
```yaml
callbacks:
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 1.0
      stress_mse: 1.0
    start_epoch: 100
    transition_epochs: 200
```
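The interpolation can be sketched as follows. `linear_coeffs` is a hypothetical helper, not the callback's code: it assumes the fraction of the transition completed is `(epoch - start_epoch) / transition_epochs`, clamped to [0, 1], and applied to the coefficients captured at `start_epoch`. The sketch does not renormalize the interpolated values.

```python
from typing import Dict


def linear_coeffs(
    start_coeffs: Dict[str, float],
    final_coeffs: Dict[str, float],
    epoch: int,
    start_epoch: int,
    transition_epochs: int,
) -> Dict[str, float]:
    """Hypothetical linear interpolation between captured and final coefficients."""
    # Fraction of the transition completed, clamped to [0, 1].
    t = (epoch - start_epoch) / transition_epochs
    t = min(max(t, 0.0), 1.0)
    return {
        name: start_coeffs[name] + t * (final_coeffs[name] - start_coeffs[name])
        for name in final_coeffs
    }


captured = {"per_atom_energy_mse": 1.0, "forces_mse": 5.0, "stress_mse": 1.0}
final = {"per_atom_energy_mse": 1.0, "forces_mse": 1.0, "stress_mse": 1.0}
# Halfway through the 200-epoch transition that starts at epoch 100:
print(linear_coeffs(captured, final, epoch=200, start_epoch=100, transition_epochs=200))
# {'per_atom_energy_mse': 1.0, 'forces_mse': 3.0, 'stress_mse': 1.0}
```

Before `start_epoch` the captured values are returned unchanged, and after `start_epoch + transition_epochs` the result stays at `final_coeffs`.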
Multiple `LinearLossCoefficientScheduler` callbacks can be composed for multi-stage scheduling:
```yaml
callbacks:
  # First transition: current -> 1:5:1 from epoch 50-150
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 5.0
      stress_mse: 1.0
    start_epoch: 50
    transition_epochs: 100
  # Second transition: current -> 1:1:1 from epoch 200-400
  - _target_: nequip.train.callbacks.LinearLossCoefficientScheduler
    final_coeffs:
      per_atom_energy_mse: 1.0
      forces_mse: 1.0
      stress_mse: 1.0
    start_epoch: 200
    transition_epochs: 200
```
Warning
When composing multiple schedulers, ensure that their epoch ranges do not overlap. No safety checks are performed to validate scheduler composition. Additionally, callback execution order is not guaranteed, so training protocols should not rely on a specific callback execution order.
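Since no safety checks are performed, it can help to validate a multi-stage configuration before training. The helper below, `stages_overlap`, is hypothetical and not part of NequIP: it treats each stage as the half-open epoch range `[start_epoch, start_epoch + transition_epochs)` and reports whether any two ranges overlap.

```python
from typing import List, Tuple


def stages_overlap(stages: List[Tuple[int, int]]) -> bool:
    """Hypothetical check: True if any (start_epoch, transition_epochs) ranges overlap."""
    # Convert each stage to a half-open range [start, end) and sort by start epoch.
    ranges = sorted((start, start + length) for start, length in stages)
    # After sorting, any overlap implies an overlap between neighboring ranges.
    return any(
        next_start < prev_end
        for (_, prev_end), (next_start, _) in zip(ranges, ranges[1:])
    )


# The two-stage example: epochs 50-150 and 200-400 do not overlap.
print(stages_overlap([(50, 100), (200, 200)]))  # False
print(stages_overlap([(50, 100), (120, 50)]))   # True
```

Using half-open ranges means a stage may start exactly where the previous one ends without being flagged.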
- class nequip.train.callbacks.LossCoefficientMonitor(interval: str, frequency: int)
Monitor and log loss coefficients during training.
Example usage in config to log loss coefficients every 5 epochs:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.LossCoefficientMonitor
    interval: epoch
    frequency: 5
```
- class nequip.train.callbacks.TestTimeXYZFileWriter(out_file: str, output_fields_from_original_dataset: List[str] | None = [], extra_fields: List[str] = [], chemical_symbols: List[str] | None = None)
Writes model outputs to an `xyz` file.
Users must provide an `out_file` that does not contain an extension. The actual output file will take the form `{out_file}_dataset{idx}.xyz`, where `idx` is the dataset index (`0` for a single test set, with further indices when there are multiple test sets).
To incorporate original dataset fields in the `xyz` file to simplify analysis, users may provide `output_fields_from_original_dataset`. Such fields will have the prefix `original_dataset_` in the `xyz` file.
To obtain correct chemical species information, users must provide `chemical_symbols` in an order consistent with the model's `type_names`.
Example usage in config to write predictions and the original dataset's `total_energy` and `forces` to an `xyz` file:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.TestTimeXYZFileWriter
    out_file: ${hydra:runtime.output_dir}/test
    output_fields_from_original_dataset: [total_energy, forces]
    chemical_symbols: ${chemical_symbols}
```
- Parameters:
  - out_file (str) – path to the output file (must NOT contain a `.xyz` or `.extxyz` extension)
  - output_fields_from_original_dataset (List[str]) – values from the original dataset to save in the `out_file`
  - extra_fields (List[str]) – extra fields to save in addition to ASE's default fields
  - chemical_symbols (List[str]) – chemical symbols in the same order as the model's `type_names`
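The file-naming and field-prefix conventions of `TestTimeXYZFileWriter` can be sketched with two hypothetical helpers; `xyz_output_path` and `prefixed_original_fields` are illustrative names, not NequIP functions. Rejecting an `out_file` that ends in `.xyz` or `.extxyz` is an assumption about how one might enforce the documented requirement.

```python
import os
from typing import List


def xyz_output_path(out_file: str, dataset_idx: int) -> str:
    """Hypothetical: build the per-dataset output path `{out_file}_dataset{idx}.xyz`."""
    # The documented requirement is that out_file carries no .xyz/.extxyz extension.
    if os.path.splitext(out_file)[1] in (".xyz", ".extxyz"):
        raise ValueError(f"out_file must not have an extension: {out_file!r}")
    return f"{out_file}_dataset{dataset_idx}.xyz"


def prefixed_original_fields(fields: List[str]) -> List[str]:
    """Hypothetical: names under which original-dataset fields appear in the xyz file."""
    return [f"original_dataset_{name}" for name in fields]


print(xyz_output_path("outputs/test", 0))  # outputs/test_dataset0.xyz
print(prefixed_original_fields(["total_energy", "forces"]))
```

The prefix keeps the reference values (for example `original_dataset_forces`) distinct from the model's predicted fields when both are stored in the same frame.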
- class nequip.train.callbacks.WandbWatch(log_freq: int, log: str = 'gradients', log_graph: bool = False)
Monitor and log weights and gradients during training with PyTorch Lightning's `WandbLogger`.
This class provides a way to call `wandb.watch` (https://docs.wandb.ai/ref/python/watch/) when using a `WandbLogger`, for monitoring weights and gradients over the course of training.
- class nequip.train.callbacks.TF32Scheduler(schedule: Dict[int, bool])
Schedule TF32 precision during training.
The `TF32Scheduler` takes a single argument, `schedule`, which is a `Dict[int, bool]` where the keys are the epochs at which the TF32 setting changes and the values are:
- `True`: enable TF32 (faster but less precise)
- `False`: disable TF32 (slower but more precise)
Basic example to enable TF32 for all training:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.TF32Scheduler
    schedule:
      0: true  # Enable TF32 throughout training
```
Dynamic scheduling example for two-stage training:
```yaml
callbacks:
  - _target_: nequip.train.callbacks.TF32Scheduler
    schedule:
      0: true     # Start with TF32 enabled
      100: false  # Disable TF32 at epoch 100
      200: true   # Re-enable TF32 at epoch 200
```
Note
The schedule must start at epoch 0. The initial setting will be applied at the beginning of training.
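One way to read the schedule is that the setting in effect at a given epoch is the value of the largest key less than or equal to that epoch. The helper below, `tf32_enabled_at`, is hypothetical and not a NequIP function: it implements that reading and enforces the requirement that the schedule start at epoch 0. It only computes the flag and does not toggle any PyTorch backend setting.

```python
from typing import Dict


def tf32_enabled_at(schedule: Dict[int, bool], epoch: int) -> bool:
    """Hypothetical: TF32 setting in effect at `epoch` under a step-wise schedule."""
    if 0 not in schedule:
        raise ValueError("schedule must start at epoch 0")
    # The most recent change at or before this epoch determines the setting.
    last_change = max(e for e in schedule if e <= epoch)
    return schedule[last_change]


schedule = {0: True, 100: False, 200: True}
print([tf32_enabled_at(schedule, e) for e in (0, 99, 100, 150, 200, 250)])
# [True, True, False, False, True, True]
```

Requiring a key at epoch 0 guarantees that every non-negative epoch has a defined setting, which is why the lookup never has to fall back to a default.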
Note
This callback is currently in beta testing. Please report any unexpected behavior or issues.