Dataset Statistics
The following CommonDataStatisticsManager is suitable for most common force field training scenarios.
- nequip.data.CommonDataStatisticsManager(dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)
DataStatisticsManager wrapper that implements common dataset statistics.

The dataset statistics computed by this wrapper include num_neighbors_mean, per_atom_energy_mean, forces_rms, and per_type_forces_rms, which are variables that can be interpolated in the model section of the config file.

For example:
training_module:
  _target_: nequip.train.EMALightningModule

  # other `EMALightningModule` arguments

  model:
    _target_: nequip.model.NequIPGNNModel

    # other model hyperparameters
    avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
    per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
    per_type_energy_scales: ${training_data_stats:forces_rms}
    # or alternatively the per-type forces RMS
    # per_type_energy_scales: ${training_data_stats:per_type_forces_rms}
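To make the meaning of two of these statistics concrete, here is a minimal pure-Python sketch on a hypothetical toy dataset. It is not the nequip implementation: the `frames` structure and both helper functions are illustrative assumptions, and it assumes that per_atom_energy_mean averages the per-frame quantity (total energy divided by atom count) and that forces_rms is taken over individual Cartesian force components.

```python
import math

# Hypothetical toy dataset (not a nequip data structure): each frame holds
# a total energy and one force vector per atom.
frames = [
    {"total_energy": -6.0, "forces": [(1.0, 0.0, 0.0), (-1.0, 0.0, 0.0)]},
    {"total_energy": -12.0,
     "forces": [(0.0, 2.0, 0.0), (0.0, -2.0, 0.0), (0.0, 0.0, 0.0)]},
]


def per_atom_energy_mean(frames):
    """Mean over frames of (total energy / number of atoms in the frame)."""
    per_atom = [f["total_energy"] / len(f["forces"]) for f in frames]
    return sum(per_atom) / len(per_atom)


def forces_rms(frames):
    """Root mean square over every force component in the dataset."""
    components = [c for f in frames for vec in f["forces"] for c in vec]
    return math.sqrt(sum(c * c for c in components) / len(components))


energy_shift = per_atom_energy_mean(frames)  # would feed per_type_energy_shifts
energy_scale = forces_rms(frames)            # would feed per_type_energy_scales
```

In the config above, values of this kind are what the `${training_data_stats:...}` interpolations would supply to the model.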
The following can be used for energy-only datasets (without forces):
- nequip.data.EnergyOnlyDataStatisticsManager(dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)
DataStatisticsManager wrapper for energy-only datasets.

This manager computes statistics for datasets that only contain energies and no forces. The dataset statistics computed include num_neighbors_mean, per_atom_energy_mean, per_atom_energy_std, and total_energy_std, which are variables that can be interpolated in the model section of the config file.

For example:
training_module:
  _target_: nequip.train.EMALightningModule

  # other `EMALightningModule` arguments

  model:
    _target_: nequip.model.NequIPGNNModel
    do_derivatives: false

    # other model hyperparameters
    avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
    per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
    per_type_energy_scales: ${training_data_stats:total_energy_std}
For users who seek to configure their own custom dataset statistics, the following API is offered.
As an example, the full DataStatisticsManager can be configured to behave equivalently to CommonDataStatisticsManager as follows:
stats_manager:
  _target_: nequip.data.DataStatisticsManager
  type_names: ${model_type_names}
  metrics:
    - name: num_neighbors_mean
      field:
        _target_: nequip.data.NumNeighbors
      metric:
        _target_: nequip.data.Mean
    - name: per_atom_energy_mean
      field:
        _target_: nequip.data.PerAtomModifier
        field: total_energy
      metric:
        _target_: nequip.data.Mean
    - name: forces_rms
      field: forces
      metric:
        _target_: nequip.data.RootMeanSquare
    - name: per_type_forces_rms
      per_type: true
      field: forces
      metric:
        _target_: nequip.data.RootMeanSquare
- class nequip.data.DataStatisticsManager(metrics: List[Dict[str, float | str | Dict[str, str | Callable] | Metric]], dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)
Manages nequip metrics that can be applied to AtomicDataDict to compute dataset statistics.

The main input argument metrics is a list of dictionaries, where each dictionary contains the following keys:

Mandatory keys:
- field refers to the quantity of interest for metric computation. It has two formats:
  - a str for a nequip defined field (e.g. total_energy, forces, stress), or
  - a Callable that performs some additional operations before returning a torch.Tensor for metric computation (e.g. PerAtomModifier).
- metric is a nequip data metric object (a subclass of torchmetrics.Metric).

Optional keys:
- per_type is a bool (defaults to False if not provided). If True, node fields (such as forces) will have their metrics computed separately for each node type based on the type_names argument.
- ignore_nan is a bool (defaults to False if not provided). This should be set to true if one expects the underlying target data to contain NaN entries.
- name is the name that the metric is logged as. Default names are used if not provided, but it is recommended for users to set custom names for clarity and control.
- Parameters:
metrics (list) – list of dictionaries with keys field, metric, per_type, ignore_nan, and name
dataloader_kwargs (dict) – arguments of torch.utils.data.DataLoader for dataset statistics computation (ideally, the batch_size should be as large as possible without triggering OOM)
type_names (list) – required for per_type metrics (this must match the type_names argument of the model; it is advisable to use variable interpolation in the config file to make sure they are consistent)
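The per_type and ignore_nan options described above can be illustrated with a small pure-Python sketch. This is an assumption-laden illustration rather than nequip's implementation: the function `per_type_rms`, its arguments, and the toy data are all hypothetical, and it assumes that each node carries an integer type index into type_names and that ignoring NaN means skipping NaN components.

```python
import math
from collections import defaultdict


def per_type_rms(values, type_indices, type_names, ignore_nan=False):
    """RMS of per-node vectors, accumulated separately for each node type.

    values: one vector (tuple of floats) per node.
    type_indices: integer type index per node, indexing into type_names.
    """
    sq_sum = defaultdict(float)
    count = defaultdict(int)
    for vec, t in zip(values, type_indices):
        for c in vec:
            if ignore_nan and math.isnan(c):
                continue  # skip NaN components instead of letting them propagate
            sq_sum[t] += c * c
            count[t] += 1
    return {
        name: math.sqrt(sq_sum[i] / count[i])
        for i, name in enumerate(type_names)
        if count[i] > 0
    }


# Toy example: two nodes of type "H" and one node of type "O".
forces = [(3.0, 0.0, 0.0), (0.0, 4.0, 0.0), (1.0, 1.0, 1.0)]
types = [0, 0, 1]
result = per_type_rms(forces, types, ["H", "O"])
```

Because the result is keyed by the order of type_names, a mismatch with the model's type_names would silently assign statistics to the wrong species, which is why keeping the two consistent through interpolation is recommended.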
- class nequip.data.MeanAbsolute(**kwargs)
Mean of absolute value computed in a running fashion.
- class nequip.data.StandardDeviation(squared: bool = False, unbiased=True, **kwargs)
Standard deviation computed in a running fashion with Welford’s online algorithm.
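For readers unfamiliar with Welford's online algorithm, the following is a minimal pure-Python sketch of the idea: the mean and the sum of squared deviations are updated one sample at a time, so the dataset never needs to be held in memory. The class `RunningStd` is hypothetical and is not nequip's code. Treating `unbiased` as Bessel's correction (dividing by n - 1) and `squared` as "return the variance instead of the standard deviation" are assumptions about the parameter semantics.

```python
import math


class RunningStd:
    """Running standard deviation via Welford's online algorithm (sketch)."""

    def __init__(self, squared=False, unbiased=True):
        self.n = 0        # number of samples seen so far
        self.mean = 0.0   # running mean
        self.m2 = 0.0     # running sum of squared deviations from the mean
        self.squared = squared
        self.unbiased = unbiased

    def update(self, batch):
        """Fold a batch of scalar samples into the running state."""
        for x in batch:
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            # uses the old and the new mean, which keeps the update stable
            self.m2 += delta * (x - self.mean)

    def compute(self):
        """Return the variance (squared=True) or the standard deviation."""
        denom = self.n - 1 if self.unbiased else self.n
        var = self.m2 / denom
        return var if self.squared else math.sqrt(var)


# Usage: feed the data in two batches, as a dataloader would.
stat = RunningStd(unbiased=False)
stat.update([2.0, 4.0, 4.0, 4.0])
stat.update([5.0, 5.0, 7.0, 9.0])
population_std = stat.compute()
```

The result does not depend on how the samples are split into batches, which is the property that makes a running formulation convenient for dataset statistics computed through a dataloader.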
- class nequip.data.Min(abs: bool = False, **kwargs)
Smallest entry seen.
- Parameters:
abs (bool) – whether to use absolute values