# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
import lightning
from lightning.pytorch.callbacks import Callback
import ase
from nequip.data import AtomicDataDict, to_ase
from nequip.data import (
_register_field_prefix,
register_fields,
_NODE_FIELDS,
_EDGE_FIELDS,
_GRAPH_FIELDS,
)
from nequip.train import NequIPLightningModule
from typing import List, Dict, Union, Optional
class XYZFileWriter(Callback):
"""Writes model outputs to an ``xyz`` file.
Users must provide an ``out_file`` that does not contain an extension. The actual output file will take
the form ``{out_file}_dataset{idx}[_epoch{epoch}].xyz`` where ``idx`` is the dataset index (would be ``0`` for a single
validation set but varies depending on number of validation sets) and ``epoch`` is the epoch when the file is produced.
To incorporate original dataset fields in the ``xyz`` file to simplify analysis, users may provide
``output_fields_from_original_dataset``. Such fields will have the prefix ``original_dataset_`` in the ``xyz`` file.
To obtain correct chemical species information, users must provide ``chemical_species`` in an order consistent with
the model's ``type_names``.
To activate the option to save to a different file every epoch, users should set ``separate_file_per_epoch`` true.
Args:
out_file (str): path to output file (must NOT contain ``.xyz`` or ``.extxyz`` extension)
output_fields_from_original_dataset (List[str]): values from the original dataset to save in the ``out_file``
extra_fields (List[str]): extra fields to save in addition to ASE's default fields
chemical_species (List[str]): chemical species in the same order as model's ``type_names``
"""
def __init__(
self,
out_file: str,
output_fields_from_original_dataset: Optional[List[str]] = [],
extra_fields: List[str] = [],
chemical_symbols: Optional[List[str]] = None,
):
assert not (out_file.endswith(".xyz") or out_file.endswith(".extxyz"))
self.out_file = out_file
assert all(
[
field in (_NODE_FIELDS | _EDGE_FIELDS | _GRAPH_FIELDS)
for field in output_fields_from_original_dataset
]
)
# special case total_energy (nequip's convention) vs energy (ase's convention)
self.output_fields_from_original_dataset = []
for field in output_fields_from_original_dataset:
if field == "total_energy":
self.output_fields_from_original_dataset.append("energy")
register_fields(graph_fields=["original_dataset_energy"])
else:
self.output_fields_from_original_dataset.append(field)
_register_field_prefix("original_dataset_")
self.extra_fields = [
"original_dataset_" + field
for field in self.output_fields_from_original_dataset
] + extra_fields
self.chemical_symbols = chemical_symbols
# Could be overwritten by children:
self.separate_file_per_epoch = False
# To be overridden by children
self.prefix = None
def _batch_end(
self,
trainer: lightning.Trainer,
outputs: Dict[str, Union[torch.Tensor, AtomicDataDict.Type]],
batch: AtomicDataDict.Type,
dataloader_idx=0,
):
with torch.no_grad():
output_out = outputs[f"{self.prefix}_{dataloader_idx}_output"].copy()
for field in self.output_fields_from_original_dataset:
# special case total_energy (nequip's convention) vs energy (ase's convention)
if field == "energy":
output_out["original_dataset_energy"] = batch["total_energy"]
else:
output_out["original_dataset_" + field] = batch[field]
# !! EXTREMELY IMPORTANT -- special handling of PBC key if present !!
# ASE data inputs would possess it to be used at data preprocessing time (i.e. neighborlist construction)
# but it won't be passed through the model, so we get it from `batch`
if AtomicDataDict.PBC_KEY in batch:
output_out[AtomicDataDict.PBC_KEY] = batch[AtomicDataDict.PBC_KEY]
# Determine the file
if self.separate_file_per_epoch:
out_path = (
self.out_file
+ f"_dataset{dataloader_idx}_epoch{trainer.current_epoch}.xyz"
)
else:
out_path = self.out_file + f"_dataset{dataloader_idx}.xyz"
# append to the file
ase.io.write(
out_path,
to_ase(
output_out,
chemical_symbols=self.chemical_symbols,
extra_fields=self.extra_fields,
),
format="extxyz",
append=True,
)
del output_out
[docs]
class TestTimeXYZFileWriter(XYZFileWriter):
"""XYZFileWriter designed for saving Test Time Predictions
Users must provide an ``out_file`` that does not contain an extension. The actual output file will take
the form ``{out_file}_dataset{idx}[_epoch{epoch}].xyz`` where ``idx`` is the dataset index (would be ``0`` for a single
validation set but varies depending on number of validation sets) and ``epoch`` is the epoch when the file is produced.
To incorporate original dataset fields in the ``xyz`` file to simplify analysis, users may provide
``output_fields_from_original_dataset``. Such fields will have the prefix ``original_dataset_`` in the ``xyz`` file.
To obtain correct chemical species information, users must provide ``chemical_species`` in an order consistent with
the model's ``type_names``.
To activate the option to save to a different file every epoch, users should set ``separate_file_per_epoch`` true.
Args:
out_file (str): path to output file (must NOT contain ``.xyz`` or ``.extxyz`` extension)
output_fields_from_original_dataset (List[str]): values from the original dataset to save in the ``out_file``
extra_fields (List[str]): extra fields to save in addition to ASE's default fields
chemical_species (List[str]): chemical species in the same order as model's ``type_names``
Example usage in config to write predictions and original dataset ``total_energy`` and ``forces`` to an ``xyz`` file:
.. code-block:: yaml
callbacks:
- _target_: nequip.train.callbacks.TestTimeXYZFileWriter
out_file: ${hydra:runtime.output_dir}/test
output_fields_from_original_dataset: [total_energy, forces]
chemical_symbols: ${chemical_symbols}
"""
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.prefix = "test"
def on_test_batch_end(
self,
trainer: lightning.Trainer,
pl_module: NequIPLightningModule,
outputs: Dict[str, Union[torch.Tensor, AtomicDataDict.Type]],
batch: AtomicDataDict.Type,
batch_idx: int,
dataloader_idx=0,
):
""""""
self._batch_end(
trainer=trainer,
outputs=outputs,
batch=batch,
dataloader_idx=dataloader_idx,
)
[docs]
class ValTimeXYZFileWriter(XYZFileWriter):
"""XYZFileWriter designed for saving Val Time Predictions
Users must provide an ``out_file`` that does not contain an extension. The actual output file will take
the form ``{out_file}_dataset{idx}[_epoch{epoch}].xyz`` where ``idx`` is the dataset index (would be ``0`` for a single
validation set but varies depending on number of validation sets) and ``epoch`` is the epoch when the file is produced.
To incorporate original dataset fields in the ``xyz`` file to simplify analysis, users may provide
``output_fields_from_original_dataset``. Such fields will have the prefix ``original_dataset_`` in the ``xyz`` file.
To obtain correct chemical species information, users must provide ``chemical_species`` in an order consistent with
the model's ``type_names``.
To activate the option to save to a different file every epoch, users should set ``separate_file_per_epoch`` true.
Args:
out_file (str): path to output file (must NOT contain ``.xyz`` or ``.extxyz`` extension)
output_fields_from_original_dataset (List[str]): values from the original dataset to save in the ``out_file``
extra_fields (List[str]): extra fields to save in addition to ASE's default fields
chemical_species (List[str]): chemical species in the same order as model's ``type_names``
separate_file_per_epoch (bool): if True, write outputs to a separate file per epoch (Useful for ``Train`` run types with ValTimeXYZFileWriter)
every_n_epochs (int): if nonzero, only call on epoch multiples of this variable
Example usage in config to write predictions and original dataset ``total_energy`` and ``forces`` to an ``xyz`` file:
.. code-block:: yaml
callbacks:
- _target_: nequip.train.callbacks.ValTimeXYZFileWriter
out_file: ${hydra:runtime.output_dir}/val
output_fields_from_original_dataset: [total_energy, forces]
chemical_symbols: ${chemical_symbols}
separate_file_per_epoch: true
every_n_epochs: 5
"""
def __init__(
self,
separate_file_per_epoch: bool = False,
every_n_epochs: int = 1,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.prefix = "val"
if every_n_epochs <= 0:
raise ValueError("every_n_epochs must be > 0")
self.every_n_epochs = every_n_epochs
self.separate_file_per_epoch = separate_file_per_epoch
def on_validation_batch_end(
self,
trainer: lightning.Trainer,
pl_module: NequIPLightningModule,
outputs: Dict[str, Union[torch.Tensor, AtomicDataDict.Type]],
batch: AtomicDataDict.Type,
batch_idx: int,
dataloader_idx=0,
):
""""""
if not (trainer.current_epoch % self.every_n_epochs):
self._batch_end(
trainer=trainer,
outputs=outputs,
batch=batch,
dataloader_idx=dataloader_idx,
)