Dataset Statistics

The CommonDataStatisticsManager below is suitable for most common force field training scenarios.

nequip.data.CommonDataStatisticsManager(dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)

DataStatisticsManager wrapper that implements common dataset statistics.

The dataset statistics computed by this wrapper are num_neighbors_mean, per_atom_energy_mean, forces_rms, and per_type_forces_rms; each can be referenced through variable interpolation in the model section of the config file.

For example:

training_module:
  _target_: nequip.train.EMALightningModule

  # other `EMALightningModule` arguments

model:
  _target_: nequip.model.NequIPGNNModel

  # other model hyperparameters
  avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
  per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
  per_type_energy_scales: ${training_data_stats:forces_rms}
  # or alternatively the per-type forces RMS
  # per_type_energy_scales: ${training_data_stats:per_type_forces_rms}
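
To make the four quantities concrete, here is a minimal pure-Python sketch that computes them for two toy frames. It is not nequip code: frames are plain dicts loosely modeled on AtomicDataDict, num_neighbors is a hypothetical precomputed per-atom key, and the sketch assumes that forces_rms pools every Cartesian force component and that per_atom_energy_mean averages total_energy / num_atoms over frames.

```python
import math
from collections import defaultdict

# Toy frames: plain dicts loosely modeled on AtomicDataDict (not real nequip data).
# "num_neighbors" is a hypothetical precomputed key: neighbors per atom.
frames = [
    {
        "atom_types": [0, 1],
        "num_neighbors": [1, 1],
        "total_energy": -4.0,
        "forces": [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],
    },
    {
        "atom_types": [0, 0, 1],
        "num_neighbors": [2, 2, 2],
        "total_energy": -9.0,
        "forces": [[0.0, 2.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]],
    },
]


def rms(values):
    return math.sqrt(sum(v * v for v in values) / len(values))


def common_statistics(frames, type_names):
    neighbors = [n for f in frames for n in f["num_neighbors"]]
    # Per-atom energy of each frame: total_energy / number of atoms.
    per_atom = [f["total_energy"] / len(f["atom_types"]) for f in frames]
    # Assumption: RMS over every Cartesian force component, pooled across atoms.
    components = [c for f in frames for vec in f["forces"] for c in vec]
    by_type = defaultdict(list)
    for f in frames:
        for atom_type, vec in zip(f["atom_types"], f["forces"]):
            by_type[atom_type].extend(vec)
    return {
        "num_neighbors_mean": sum(neighbors) / len(neighbors),
        "per_atom_energy_mean": sum(per_atom) / len(per_atom),
        "forces_rms": rms(components),
        "per_type_forces_rms": [rms(by_type[i]) for i in range(len(type_names))],
    }


stats = common_statistics(frames, type_names=["H", "O"])
```

In an actual training run, the manager computes these values and the config picks them up through the ${training_data_stats:...} interpolations shown above.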

The following can be used for energy-only datasets (without forces):

nequip.data.EnergyOnlyDataStatisticsManager(dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)

DataStatisticsManager wrapper for energy-only datasets.

This manager computes statistics for datasets that contain energies but no forces. The dataset statistics computed are num_neighbors_mean, per_atom_energy_mean, per_atom_energy_std, and total_energy_std; each can be referenced through variable interpolation in the model section of the config file.

For example:

training_module:
  _target_: nequip.train.EMALightningModule

  # other `EMALightningModule` arguments

model:
  _target_: nequip.model.NequIPGNNModel
  do_derivatives: false

  # other model hyperparameters
  avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
  per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
  per_type_energy_scales: ${training_data_stats:total_energy_std}
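
The energy-only statistics can be sketched in the same spirit with Python's statistics module. This assumes per_atom_energy_mean and per_atom_energy_std are taken over frames of total_energy / num_atoms, and that the standard deviations use the n - 1 denominator, matching the unbiased=True default of StandardDeviation documented below. The frames are made-up (num_atoms, total_energy) pairs.

```python
import statistics

# Hypothetical energy-only frames as (num_atoms, total_energy) pairs.
frames = [(2, -4.0), (3, -9.0), (4, -10.0)]

per_atom = [energy / n_atoms for n_atoms, energy in frames]  # -2.0, -3.0, -2.5
totals = [energy for _, energy in frames]

energy_stats = {
    "per_atom_energy_mean": statistics.fmean(per_atom),
    # statistics.stdev divides the squared deviations by n - 1 (Bessel's correction).
    "per_atom_energy_std": statistics.stdev(per_atom),
    "total_energy_std": statistics.stdev(totals),
}
```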

Users who want to configure their own custom dataset statistics can use the following API.

For example, the following configures the full DataStatisticsManager to behave equivalently to CommonDataStatisticsManager:

stats_manager:
  _target_: nequip.data.DataStatisticsManager
  type_names: ${model_type_names}
  metrics:
    - name: num_neighbors_mean
      field:
        _target_: nequip.data.NumNeighbors
      metric:
        _target_: nequip.data.Mean
    - name: per_atom_energy_mean
      field:
        _target_: nequip.data.PerAtomModifier
        field: total_energy
      metric:
        _target_: nequip.data.Mean
    - name: forces_rms
      field: forces
      metric:
        _target_: nequip.data.RootMeanSquare
    - name: per_type_forces_rms
      per_type: true
      field: forces
      metric:
        _target_: nequip.data.RootMeanSquare

class nequip.data.DataStatisticsManager(metrics: List[Dict[str, float | str | Dict[str, str | Callable] | Metric]], dataloader_kwargs: Dict[str, Any] | None = None, type_names: List[str] | None = None)

Manages nequip metrics that can be applied to AtomicDataDict to compute dataset statistics.

The main input argument metrics is a list of dictionaries, where each dictionary contains the following keys:

Mandatory keys:

  • field refers to the quantity of interest for metric computation. It takes one of two forms:

    • a str for a nequip defined field (e.g. total_energy, forces, stress), or

    • a Callable that performs some additional operations before returning a torch.Tensor for metric computation (e.g. PerAtomModifier).

  • metric is a nequip data metric object (a subclass of torchmetrics.Metric).

Optional keys:

  • per_type is a bool (defaults to False if not provided). If True, node fields (such as forces) will have their metrics computed separately for each node type based on the type_names argument.

  • ignore_nan is a bool (defaults to False if not provided). Set it to True if the underlying target data is expected to contain NaN entries.

  • name is the name that the metric is logged as. Default names are used if not provided, but it is recommended for users to set custom names for clarity and control.

Parameters:
  • metrics (list) – list of dictionaries with keys field, metric, per_type, ignore_nan, and name

  • dataloader_kwargs (dict) – arguments of torch.utils.data.DataLoader for dataset statistics computation (ideally, batch_size should be as large as possible without running out of memory)

  • type_names (list) – required for per_type metrics. It must match the type_names argument of the model; using variable interpolation in the config file is advisable to keep the two consistent.
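
The semantics of the metrics keys (field as a string or a callable, per_type, ignore_nan, and name) can be illustrated with a small stand-in. TinyStatsManager, RunningMean, RunningRMS, flatten, and per_atom_energy are all hypothetical names written for this sketch; they mimic the documented behavior on plain Python lists rather than reproducing nequip's implementation, and the metric objects only need update(), compute(), and reset().

```python
import copy
import math
from collections import defaultdict


class RunningMean:
    """Stand-in metric exposing the update/compute/reset interface."""
    def __init__(self):
        self.reset()
    def reset(self):
        self.total, self.count = 0.0, 0
    def update(self, values):
        self.total += sum(values)
        self.count += len(values)
    def compute(self):
        return self.total / self.count


class RunningRMS(RunningMean):
    def update(self, values):
        super().update([x * x for x in values])
    def compute(self):
        return math.sqrt(super().compute())


def flatten(value):
    """Flatten a scalar, vector, or nested list into a flat list of floats."""
    if isinstance(value, list):
        return [x for item in value for x in flatten(item)]
    return [value]


def per_atom_energy(data):
    # Callable field in the spirit of PerAtomModifier: total energy / atom count.
    return data["total_energy"] / len(data["atom_types"])


class TinyStatsManager:
    """Hypothetical, simplified analogue of DataStatisticsManager."""
    def __init__(self, metrics, type_names=None):
        self.entries = []
        for i, spec in enumerate(metrics):
            per_type = spec.get("per_type", False)
            if per_type and type_names is None:
                raise ValueError("type_names is required for per_type metrics")
            copies = len(type_names) if per_type else 1
            self.entries.append({
                "name": spec.get("name", f"metric_{i}"),  # made-up default name
                "field": spec["field"],
                "per_type": per_type,
                "ignore_nan": spec.get("ignore_nan", False),
                # One independent accumulator per atom type, or a single one.
                "metrics": [copy.deepcopy(spec["metric"]) for _ in range(copies)],
            })

    def reset(self):
        for entry in self.entries:
            for metric in entry["metrics"]:
                metric.reset()

    def get_statistics(self, data_source):
        for data in data_source:
            for entry in self.entries:
                field = entry["field"]
                values = field(data) if callable(field) else data[field]
                if entry["per_type"]:
                    # Node field: one value or vector per atom, grouped by type.
                    groups = defaultdict(list)
                    for atom_type, value in zip(data["atom_types"], values):
                        groups[atom_type].extend(flatten(value))
                else:
                    groups = {0: flatten(values)}
                for idx, group in groups.items():
                    if entry["ignore_nan"]:
                        group = [x for x in group if not math.isnan(x)]
                    entry["metrics"][idx].update(group)
        return {
            e["name"]: [m.compute() for m in e["metrics"]] if e["per_type"]
            else e["metrics"][0].compute()
            for e in self.entries
        }


frames = [
    {"atom_types": [0, 1], "total_energy": -4.0,
     "forces": [[1.0, 0.0, float("nan")], [-1.0, 0.0, 0.0]]},
    {"atom_types": [0, 0, 1], "total_energy": -9.0,
     "forces": [[0.0, 2.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]]},
]
manager = TinyStatsManager(
    metrics=[
        {"name": "per_atom_energy_mean", "field": per_atom_energy,
         "metric": RunningMean()},
        {"name": "per_type_forces_rms", "field": "forces", "per_type": True,
         "ignore_nan": True, "metric": RunningRMS()},
    ],
    type_names=["H", "O"],
)
manager.reset()
stats = manager.get_statistics(frames)
```

Deep-copying the metric gives each atom type its own accumulator, which is what per-type statistics require; the NaN force component is dropped only because ignore_nan is set.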

reset()

Resets accumulated statistics.

get_statistics(data_source: Iterable[Dict[str, Tensor]])

Remember to call reset() before calling this method so that previously accumulated statistics are cleared.

Parameters:
  • data_source (Iterable[AtomicDataDict]) – iterable data source
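
Assuming, as the reset() note above implies, that statistics accumulate across calls until reset() is invoked, the hypothetical RunningMeanStats accumulator below shows why the reset matters when the same manager is reused on a second data source.

```python
class RunningMeanStats:
    """Hypothetical accumulator illustrating the reset()/get_statistics() contract."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def get_statistics(self, data_source):
        # Adds to whatever state is already present; nothing is cleared here.
        for batch in data_source:
            self.total += sum(batch)
            self.count += len(batch)
        return self.total / self.count


acc = RunningMeanStats()
first = acc.get_statistics([[1.0, 2.0], [3.0]])  # mean of 1, 2, 3
stale = acc.get_statistics([[10.0]])             # still includes 1, 2, 3
acc.reset()
fresh = acc.get_statistics([[10.0]])             # only the new data
```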

class nequip.data.Mean(**kwargs)

Mean computed in a running fashion.

class nequip.data.MeanAbsolute(**kwargs)

Mean of absolute value computed in a running fashion.

class nequip.data.RootMeanSquare(**kwargs)

Root mean square computed in a running fashion.
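
Mean, MeanAbsolute, and RootMeanSquare differ only in what is summed before dividing by the running entry count. The sketch below uses a hypothetical RunningStat helper in plain Python (not the torchmetrics-based classes); it keeps just a sum and a count, so the result does not depend on how the data is split into batches.

```python
import math


class RunningStat:
    """Hypothetical helper: running sum of transformed values plus an entry count."""

    def __init__(self, transform=lambda x: x, finalize=lambda m: m):
        self.transform, self.finalize = transform, finalize
        self.total, self.count = 0.0, 0

    def update(self, batch):
        self.total += sum(self.transform(x) for x in batch)
        self.count += len(batch)

    def compute(self):
        return self.finalize(self.total / self.count)


mean = RunningStat()                                               # like Mean
mean_abs = RunningStat(transform=abs)                              # like MeanAbsolute
rms = RunningStat(transform=lambda x: x * x, finalize=math.sqrt)   # like RootMeanSquare

for batch in ([3.0, -4.0], [0.0]):
    for stat in (mean, mean_abs, rms):
        stat.update(batch)
```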

class nequip.data.StandardDeviation(squared: bool = False, unbiased=True, **kwargs)

Standard deviation computed in a running fashion with Welford’s online algorithm.

Parameters:
  • squared (bool) – if True, returns variance, else returns standard deviation

  • unbiased (bool) – whether to use the unbiased estimate for standard deviation
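
Welford's algorithm updates a running mean and a running sum of squared deviations (often called M2) one value at a time, which avoids the cancellation error of the naive sum-of-squares formula. The WelfordStd class below is a hypothetical element-by-element sketch mirroring the squared and unbiased parameters; nequip's StandardDeviation works on tensors and may merge whole batches instead.

```python
import math


class WelfordStd:
    """Hypothetical element-wise sketch of Welford's online algorithm."""

    def __init__(self, squared=False, unbiased=True):
        self.squared, self.unbiased = squared, unbiased
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, values):
        for x in values:
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            # Second factor uses the *updated* mean; this is Welford's M2 update.
            self.m2 += delta * (x - self.mean)

    def compute(self):
        dof = self.n - 1 if self.unbiased else self.n  # Bessel's correction if unbiased
        variance = self.m2 / dof
        return variance if self.squared else math.sqrt(variance)


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
population_std = WelfordStd(unbiased=False)
sample_variance = WelfordStd(squared=True)
for batch in (data[:3], data[3:]):  # batches are streamed one after another
    population_std.update(batch)
    sample_variance.update(batch)
```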

class nequip.data.Min(abs: bool = False, **kwargs)

Smallest entry seen.

Parameters:
  • abs (bool) – whether to use absolute values

class nequip.data.Max(abs: bool = False, **kwargs)

Largest entry seen.

Parameters:
  • abs (bool) – whether to use absolute values
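
Min and Max only need to remember the most extreme value seen so far. The hypothetical RunningExtreme class below sketches both, with an abs flag mirroring the documented parameter; it operates on plain Python lists rather than tensors.

```python
class RunningExtreme:
    """Hypothetical sketch of a Min/Max metric over streamed batches."""

    def __init__(self, reduce, abs=False):
        # `abs` mirrors the documented parameter; it shadows the builtin only here.
        self.reduce, self.use_abs = reduce, abs
        self.value = None

    def update(self, batch):
        values = [abs(x) for x in batch] if self.use_abs else list(batch)
        candidate = self.reduce(values)
        # Keep only the most extreme value seen so far.
        self.value = candidate if self.value is None else self.reduce([self.value, candidate])

    def compute(self):
        return self.value


stats = {
    "min": RunningExtreme(min),
    "max": RunningExtreme(max),
    "min_abs": RunningExtreme(min, abs=True),
    "max_abs": RunningExtreme(max, abs=True),
}
for batch in ([-7.0, 2.0], [5.0, -1.0]):
    for metric in stats.values():
        metric.update(batch)
```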

class nequip.data.Count(**kwargs)

Total number of entries.