# Loss and Metrics

Loss functions and metrics are configured by specifying a field (e.g. `total_energy`, `forces`, [etc.](../../api/data_fields.rst)) and an error quantity to calculate for it (e.g. {class}`~nequip.train.MeanSquaredError`, {class}`~nequip.train.MeanAbsoluteError`, etc.). Both are set up through {class}`~nequip.train.MetricsManager` objects in the [`training_module`](config.md#training_module) section of the config.

The loss function determines what the model optimizes during training. Metrics are used to monitor training progress and to condition training behavior (early stopping, learning rate scheduling, etc.).

## Units

All loss components and metrics are in the physical units associated with the dataset. For example, if the dataset uses force units of eV/Å, a force mean-squared error (MSE) has units of (eV/Å)², whereas a force root-mean-squared error (RMSE) or mean absolute error (MAE) is in eV/Å.

## Simplified Wrappers

Most users should use the simplified wrapper classes for common force field training scenarios. These wrappers automatically configure the appropriate metrics for you:

**For Loss Functions:**

- {class}`~nequip.train.EnergyForceLoss`
- {class}`~nequip.train.EnergyForceStressLoss`
- {class}`~nequip.train.EnergyOnlyLoss` (for energy-only datasets)

**For Validation/Test Metrics:**

- {class}`~nequip.train.EnergyForceMetrics`
- {class}`~nequip.train.EnergyForceStressMetrics`
- {class}`~nequip.train.EnergyOnlyMetrics` (for energy-only datasets)

Each wrapper creates metrics with predetermined names, so the names logged during training are not always obvious from the config. To see exactly which metric names a wrapper produces, refer to its API documentation in the [`nequip.train` metrics API reference](../../api/metrics.rst).

## Coefficients and Weighted Sum

Users can set coefficients (`coeff`) for each loss or metric term, from which a `weighted_sum` metric is computed.

1. For **loss functions**, `weighted_sum` is the actual loss value used for backpropagation.
2. For **validation/test metrics**, `weighted_sum` is a single scalar that balances multiple quantities and can be used to condition checkpointing, early stopping, learning rate scheduling, etc.

Coefficients are automatically normalized to sum to 1. For example:

```yaml
coeffs:
  total_energy: 3.0
  forces: 1.0
```

becomes internally: `total_energy: 0.75, forces: 0.25`.

The `weighted_sum` is then calculated with the normalized coefficients:

```
weighted_sum = (coeff_1 * metric_1) + (coeff_2 * metric_2) + ...
```

Coefficients only affect the `weighted_sum` calculation. The individual metrics (e.g., `total_energy_rmse`, `forces_rmse`) are logged with their actual computed values, unmodified by coefficients.

Metrics with `coeff: null` (or omitted from `coeffs`) are still computed and logged, but excluded from `weighted_sum`:

```yaml
coeffs:
  total_energy_rmse: 1.0   # included in weighted_sum
  forces_rmse: 1.0         # included in weighted_sum
  total_energy_mae: null   # computed but not in weighted_sum
  forces_mae: null         # computed but not in weighted_sum
```

Here's an example showing how to set up metrics and use `weighted_sum` for monitoring:

```yaml
# Define the monitored metric once for consistency
monitored_metric: val0_epoch/weighted_sum

training_module:
  _target_: nequip.train.EMALightningModule

  # Loss function
  loss:
    _target_: nequip.train.EnergyForceLoss
    coeffs:
      total_energy: 1.0
      forces: 1.0

  # Validation metrics - weighted_sum will be used for monitoring
  val_metrics:
    _target_: nequip.train.EnergyForceMetrics
    coeffs:
      total_energy_rmse: 1.0
      forces_rmse: 1.0
      total_energy_mae: null   # logged but not in weighted_sum
      forces_mae: null         # logged but not in weighted_sum

  # Learning rate scheduler using the monitored metric
  lr_scheduler:
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.6
      patience: 5
    monitor: ${monitored_metric}

trainer:
  _target_: lightning.Trainer
  callbacks:
    # Early stopping using the monitored metric
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      monitor: ${monitored_metric}
      patience: 20
      min_delta: 1e-3

    # Model checkpointing using the monitored metric
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: ${monitored_metric}
      filename: best
```

## Per-Species Force Loss Coefficients

For systems with very heterogeneous species, it can help to emphasize the force errors on some atom types over others. {class}`~nequip.train.EnergyForceLoss` and {class}`~nequip.train.EnergyForceStressLoss` accept an optional `per_type_forces_coeffs` dict for this purpose. These coefficients only control how the loss is aggregated; they are not model parameters.

When supplied, the forces loss is computed per atom type and combined as a *weighted mean* `sum(c_i * mse_i) / sum(c_i)`, where `c_i` and `mse_i` are the coefficient and force MSE of type `i`, instead of the default equal mean over types. Equal coefficients reproduce the default behavior exactly.

```yaml
loss:
  _target_: nequip.train.EnergyForceLoss
  per_atom_energy: true
  coeffs:
    total_energy: 1.0
    forces: 1.0
  # Emphasize H force errors relative to heavier species
  per_type_forces_coeffs:
    H: 5.0
    O: 1.0
    P: 1.0
    Cs: 0.01
```

The dict must contain a strictly positive coefficient for every type in `type_names` (no missing keys, no zeros, no negatives). To de-emphasize a species, give it a small positive coefficient rather than `0`, since a force field with zero training signal on a species is rarely intended. Only the forces term is affected; per-structure terms (`total_energy`, `stress`) ignore this argument.

Per-type breakdowns logged during training (e.g. `forces_mse_H`, `forces_mse_O`) are the raw per-species MSEs; the coefficients are applied only to the aggregated `forces_mse`.

### Logging per-species force MAE and RMSE

The loss only emits per-species MSE breakdowns (`forces_mse_H`, `forces_mse_O`, etc.).
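As a rough illustration of how these raw per-species MSEs relate to the aggregated `forces_mse`, the plain-Python sketch below groups toy force errors by type, computes per-type MSE and RMSE, and aggregates them both as an equal mean and as the weighted mean `sum(c_i * mse_i) / sum(c_i)`. It is not NequIP's implementation: the helpers `group_errors_by_type`, `check_coeffs`, and `aggregate` are hypothetical, the data is made up, and averaging each type's error over the x, y, z force components is an assumption.

```python
import math


def group_errors_by_type(errors, types):
    """Collect force-error components (pred - ref) for each atom type.

    errors: one (ex, ey, ez) tuple per atom, in the dataset's force units
    types:  one type name per atom, aligned with `errors`
    """
    groups = {}
    for err, type_name in zip(errors, types):
        groups.setdefault(type_name, []).extend(err)
    return groups


def mse(values):
    """Mean of squared components (assumption: averaged over x, y, z components)."""
    return sum(v * v for v in values) / len(values)


def check_coeffs(coeffs, type_names):
    """Mirror the documented rule: every type present, every coefficient > 0."""
    missing = set(type_names) - set(coeffs)
    if missing:
        raise ValueError(f"missing coefficients for types: {sorted(missing)}")
    non_positive = {t: c for t, c in coeffs.items() if c <= 0}
    if non_positive:
        raise ValueError(f"coefficients must be strictly positive: {non_positive}")


def aggregate(per_type, coeffs=None):
    """Equal mean over types by default, else sum(c_i * m_i) / sum(c_i)."""
    if coeffs is None:
        coeffs = {t: 1.0 for t in per_type}
    total = sum(coeffs[t] for t in per_type)
    return sum(coeffs[t] * m for t, m in per_type.items()) / total


# Toy data (hypothetical): three atoms, force errors in eV/Å
types = ["H", "O", "H"]
errors = [(0.1, 0.0, -0.1), (0.2, 0.2, 0.0), (0.0, 0.1, 0.1)]

groups = group_errors_by_type(errors, types)
per_type_mse = {t: mse(v) for t, v in groups.items()}  # raw, like forces_mse_H
per_type_rmse = {t: math.sqrt(m) for t, m in per_type_mse.items()}  # back in eV/Å

coeffs = {"H": 5.0, "O": 1.0}
check_coeffs(coeffs, type_names=["H", "O"])

print(aggregate(per_type_mse))          # equal mean over types, ≈ 0.0167
print(aggregate(per_type_mse, coeffs))  # weighted toward H, ≈ 0.01
```

The raw per-type values never see the coefficients; only `aggregate` does, which matches the distinction between `forces_mse_H` and `forces_mse` described above. Passing equal coefficients (e.g. `{"H": 2.0, "O": 2.0}`) returns the same value as the default equal mean.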
To also monitor per-species MAE and RMSE in the same units as the forces (e.g. eV/Å), the metrics wrapper accepts an `extra_metrics` list:

```yaml
val_metrics:
  _target_: nequip.train.EnergyForceMetrics
  coeffs:
    per_atom_energy_mae: 1.0
    forces_mae: 1.0
  # Per-species force breakdown (observation-only; omit `coeff` so they
  # don't enter `weighted_sum`).
  extra_metrics:
    - name: per_type_forces_mae
      field: forces
      metric:
        _target_: nequip.train.MeanAbsoluteError
      per_type: true
    - name: per_type_forces_rmse
      field: forces
      metric:
        _target_: nequip.train.RootMeanSquaredError
      per_type: true

# Reuse for train / test so all three log the same breakdown:
train_metrics: ${training_module.val_metrics}
test_metrics: ${training_module.val_metrics}
```

This logs:

- per-species values: `per_type_forces_mae_H`, `per_type_forces_mae_O`, `per_type_forces_mae_P`, `per_type_forces_mae_Cs`, and analogously for `_rmse`;
- the equal-mean aggregate over species: `per_type_forces_mae` and `per_type_forces_rmse`.

To include a per-species metric in the monitored `weighted_sum` (e.g. to early-stop on per-species RMSE), give the entry a non-null `coeff`. The snippet above omits `coeff`, leaving these as observation-only.

## Advanced Usage: Custom MetricsManager

For scenarios not covered by the simplified wrappers, you can use the full {class}`~nequip.train.MetricsManager` directly. Technical details and advanced examples are provided in the [`nequip.train.MetricsManager` API documentation](../../api/metrics.rst).

Common advanced use cases include:

- Custom field modifiers beyond {class}`~nequip.data.PerAtomModifier`
- Per-type metrics (separate metrics for each atom type, optionally combined as a weighted mean via `per_type_coeffs`)
- Custom metric types (e.g., {class}`~nequip.train.HuberLoss`, {class}`~nequip.train.StratifiedHuberForceLoss`)
- Handling datasets with missing labels (using `ignore_nan: true`)
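The missing-label case in the last item can be pictured with a small plain-Python sketch. It assumes, as the option name suggests, that reference labels stored as NaN are masked out before the error is averaged; the helper `mae` and the data are hypothetical, and the exact masking semantics are documented in the `MetricsManager` API reference, not here.

```python
import math


def mae(pred, ref, ignore_nan=False):
    """Mean absolute error; optionally skip entries whose reference label is NaN."""
    pairs = list(zip(pred, ref))
    if ignore_nan:
        # Assumed behavior: drop unlabeled entries before averaging
        pairs = [(p, r) for p, r in pairs if not math.isnan(r)]
    if not pairs:
        return float("nan")  # nothing left to average
    return sum(abs(p - r) for p, r in pairs) / len(pairs)


# Hypothetical energies in eV; the second structure has no label
pred = [-10.0, -12.5, -9.0]
ref = [-10.2, float("nan"), -9.4]

print(mae(pred, ref))                   # nan: one missing label poisons the mean
print(mae(pred, ref, ignore_nan=True))  # ≈ 0.3, averaged over the two labeled structures
```

Without masking, a single NaN label propagates through the sum and makes the whole metric NaN, which is why datasets with missing labels need explicit handling.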