# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import copy
from typing import Any, Callable, Dict, List, Optional, Union

import lightning
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, ListConfig, OmegaConf

from .. import AtomicDataDict
from nequip.utils.logger import RankedLogger
# Module-level logger used throughout this file.
# NOTE(review): ``rank_zero_only=True`` presumably restricts emission to the rank-0
# process in distributed runs -- confirm against ``nequip.utils.logger.RankedLogger``.
logger = RankedLogger(__name__, rank_zero_only=True)
class NequIPDataModule(lightning.LightningDataModule):
    """
    Sanity checking is only performed at runtime -- ensure that the correct datasets are provided for the intended runs,
    which can be ``train``, ``val``, ``test``, and/or ``predict``.

    - ``train`` runs require ``train_dataset`` and ``val_dataset``
    - ``val`` runs require ``val_dataset``
    - ``test`` runs require ``test_dataset``
    - ``predict`` runs require ``predict_dataset``

    One can explicitly specify which ``train``, ``val``, ``test``, ``predict`` datasets to use, or randomly split a dataset to be used for any of those tasks with the ``split_dataset`` argument. These options are not mutually exclusive, e.g. if a single ``test_dataset`` is provided, and ``split_dataset`` is used to get another test set, there will now be two test sets (indexed by ``0`` and ``1``) used for testing. More generally, if ``test_dataset`` is a list of ``n`` test datasets (with indices ``0``, ``1``, ..., ``n - 1``) and ``split_dataset`` contributes ``m`` further test sets, there will be a total of ``n + m`` test sets, with the ones from ``test_dataset`` taking indices ``0``, ``1``, ..., ``n - 1`` and the ones from ``split_dataset`` taking indices ``n``, ``n + 1``, ..., ``n + m - 1``.

    Args:
        seed (int): data seed for reproducibility
        train_dataset (Dict/List[Dict]): training dataset(s); ``None`` (default) means none are provided
        val_dataset (Dict/List[Dict]): validation dataset(s) (can provide multiple datasets in a list); ``None`` (default) means none are provided
        test_dataset (Dict/List[Dict]): test dataset(s) (can provide multiple datasets in a list); ``None`` (default) means none are provided
        predict_dataset (Dict/List[Dict]): prediction dataset(s) (can provide multiple datasets in a list); ``None`` (default) means none are provided
        split_dataset (Dict/List[Dict]): dictionary with a ``dataset`` key, which defines the dataset and the keys ``train``, ``val``, ``test``, ``predict`` which represent the subsets to split ``dataset`` into and are either ``int`` s that sum up to the size of ``dataset`` or ``float`` s that sum up to 1 (at least 2, but not necessarily all of ``train``, ``val``, ``test``, ``predict`` must be provided if this option is used)
        train_dataloader (Dict): training ``DataLoader`` configuration dictionary (``None`` is treated as ``{}``)
        val_dataloader (Dict): validation ``DataLoader`` configuration dictionary (``None`` is treated as ``{}``)
        test_dataloader (Dict): testing ``DataLoader`` configuration dictionary (``None`` is treated as ``{}``)
        predict_dataloader (Dict): prediction ``DataLoader`` configuration dictionary (``None`` is treated as ``{}``)
        stats_manager (Dict): dictionary that can be instantiated into a :class:`~nequip.data.DataStatisticsManager` object
    """

    def __init__(
        self,
        seed: int,
        train_dataset: Optional[Union[Dict, List]] = None,
        val_dataset: Optional[Union[Dict, List]] = None,
        test_dataset: Optional[Union[Dict, List]] = None,
        predict_dataset: Optional[Union[Dict, List]] = None,
        split_dataset: Optional[Union[Dict, List]] = None,
        train_dataloader: Optional[Dict] = None,
        val_dataloader: Optional[Dict] = None,
        test_dataloader: Optional[Dict] = None,
        predict_dataloader: Optional[Dict] = None,
        stats_manager: Optional[Dict] = None,
    ):
        super().__init__()

        # internal logic follows lists in order of train, val, test, predict, split
        # == first normalize all dataset configs to (deep-copied) lists ==
        dconfigs = []
        for dconfig in [
            train_dataset,
            val_dataset,
            test_dataset,
            predict_dataset,
            split_dataset,
        ]:
            # ``None`` (the default) means no dataset of this type is provided
            if dconfig is None:
                dconfig = []
            # convert to primitives as later logic is based on types
            if isinstance(dconfig, (DictConfig, ListConfig)):
                dconfig = OmegaConf.to_container(dconfig, resolve=True)
            assert isinstance(
                dconfig, (dict, list)
            ), f"dataset configs must be a dict or a list of dicts, but found {type(dconfig)}"
            # deep copies avoid mutating configs that may be used outside (should be relatively cheap)
            dconfig = copy.deepcopy(dconfig)
            dconfigs.append(dconfig if isinstance(dconfig, list) else [dconfig])

        # == account for split datasets ==
        dataset_type_map = ["train", "val", "test", "predict"]  # index matches dconfigs
        # loop over datasets to split
        for split_config in dconfigs[4]:
            split_dict = split_config.copy()
            dataset_to_split = split_dict.pop("dataset")
            unknown_types = [key for key in split_dict if key not in dataset_type_map]
            assert (
                not unknown_types
            ), f"`split_dataset` keys must be `dataset` or one of {dataset_type_map}, but found {unknown_types}"
            for dataset_type in split_dict:
                # each requested subset is appended after the explicitly provided datasets of that type
                dconfigs[dataset_type_map.index(dataset_type)].append(
                    {
                        "_target_": "nequip.data.dataset.RandomSplitAndIndexDataset",
                        "dataset": dataset_to_split,
                        "split_dict": split_dict,
                        "dataset_key": dataset_type,
                        "seed": seed,
                    }
                )

        # == set dataset configs (which are all List[Dict] by this point) ==
        self.train_dataset_config = dconfigs[0]
        self.val_dataset_config = dconfigs[1]
        self.test_dataset_config = dconfigs[2]
        self.predict_dataset_config = dconfigs[3]

        # == keep track of number of each dataset in order of train, val, test, predict ==
        # for communicating to NequIPLightningModule during runs to create enough MetricsManagers to match the number of dataloaders
        self.num_datasets = {
            "train": len(self.train_dataset_config),
            "val": len(self.val_dataset_config),
            "test": len(self.test_dataset_config),
            "predict": len(self.predict_dataset_config),
        }
        logger.info(
            "Found {} training dataset(s), {} validation dataset(s), {} test dataset(s), and {} predict dataset(s).".format(
                self.num_datasets["train"],
                self.num_datasets["val"],
                self.num_datasets["test"],
                self.num_datasets["predict"],
            )
        )

        # == reproducibility params ==
        self.seed = seed
        # distinguish train generator and generators for other datasets
        # to control reproducibility of training runs
        # generators for val/test/predict aren't important since we're only interested in their accumulated metrics
        self.train_generator_state = (
            torch.Generator().manual_seed(self.seed).get_state()
        )
        self.generator_state = torch.Generator().manual_seed(self.seed).get_state()

        # == dataloader ==
        for varname, dataloader_dict in zip(
            ["train", "val", "test", "predict"],
            [train_dataloader, val_dataloader, test_dataloader, predict_dataloader],
        ):
            if dataloader_dict is None:
                dataloader_dict = {}
            # convert to a fresh primitive dict to be safe against mutating the caller's config
            if isinstance(dataloader_dict, DictConfig):
                dataloader_dict = OmegaConf.to_container(dataloader_dict, resolve=True)
            else:
                dataloader_dict = dict(dataloader_dict)
            # provide a default just in case
            if "_target_" not in dataloader_dict:
                dataloader_dict["_target_"] = "torch.utils.data.DataLoader"
            # ``dataset`` and ``generator`` are injected by ``_get_dloader``
            assert (
                "dataset" not in dataloader_dict
            ), f"`dataset` must not be set in the `{varname}` dataloader config"
            assert (
                "generator" not in dataloader_dict
            ), f"`generator` must not be set in the `{varname}` dataloader config"
            if "collate_fn" not in dataloader_dict:
                # Allow collate_fn to be overridden by a function wrapping
                # AtomicDataDict.batched_from_list, but default to it.
                dataloader_dict["collate_fn"] = {
                    "_target_": "nequip.data.datamodule._base_datamodule._default_collate_fn_factory"
                }
            setattr(self, f"{varname}_dataloader_config", dataloader_dict)

        # == data statistics manager ==
        self.stats_manager_cfg = stats_manager

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """"""
        self.train_generator_state = state_dict["train_generator_state"]
        self.generator_state = state_dict["generator_state"]
        for varname in ["train", "val", "test", "predict"]:
            for i in range(self.num_datasets[varname]):
                dataloader_sd = state_dict.get(f"_{varname}_dataloader_{i}", {})
                # stash the state dict; it is applied by `_maybe_load_dataloader_state_dict`
                # once the corresponding dataloader is created
                setattr(self, f"_{varname}_dataloader_state_dict_{i}", dataloader_sd)

    def state_dict(self) -> Dict[str, Any]:
        """"""
        # live generators only exist between `setup` and `teardown` (and `train_generator`
        # only for the "fit" stage), so fall back to the stored states when they are absent
        train_generator = getattr(self, "train_generator", None)
        if train_generator is not None:
            train_generator_state = train_generator.get_state()
        else:
            train_generator_state = self.train_generator_state
        generator = getattr(self, "generator", None)
        if generator is not None:
            generator_state = generator.get_state()
        else:
            generator_state = self.generator_state

        sd = {
            "train_generator_state": train_generator_state,
            "generator_state": generator_state,
        }
        for varname in ["train", "val", "test", "predict"]:
            dloaders = getattr(self, f"_{varname}_dataloader", None)
            for i in range(self.num_datasets[varname]):
                key = f"_{varname}_dataloader_{i}"
                # avoid truthiness checks on the dataloader itself: `bool(DataLoader)`
                # falls back to `__len__`, which can raise for iterable datasets
                dloader = dloaders[i] if dloaders else None
                dloader_state_fn = getattr(dloader, "state_dict", None)
                if callable(dloader_state_fn):
                    sd[key] = dloader_state_fn()
                else:
                    # dataloader not created yet, or user has not specified a restartable dataloader
                    sd[key] = {}
        return sd

    def setup(self, stage: str) -> None:
        """"""
        # (re)create the val/test/predict generator from its stored state
        self.generator = torch.Generator().manual_seed(self.seed)
        self.generator.set_state(self.generator_state)
        if stage == "fit":
            # requires both "train" and "val" datasets
            if len(self.train_dataset_config) == 0:
                raise RuntimeError("No train dataset provided -- unable to do training")
            else:
                self.train_dataset = instantiate(self.train_dataset_config)
            if len(self.val_dataset_config) == 0:
                raise RuntimeError("No val dataset provided -- unable to do training")
            else:
                self.val_dataset = instantiate(self.val_dataset_config)
            # set train generator
            self.train_generator = torch.Generator().manual_seed(self.seed)
            self.train_generator.set_state(self.train_generator_state)
        elif stage == "validate":
            if len(self.val_dataset_config) == 0:
                raise RuntimeError("No val dataset provided -- unable to do validation")
            else:
                self.val_dataset = instantiate(self.val_dataset_config)
        elif stage == "test":
            if len(self.test_dataset_config) == 0:
                raise RuntimeError("No test dataset provided -- unable to do testing")
            else:
                self.test_dataset = instantiate(self.test_dataset_config)
        elif stage == "predict":
            if len(self.predict_dataset_config) == 0:
                raise RuntimeError("No predict dataset provided -- unable to predict")
            else:
                self.predict_dataset = instantiate(self.predict_dataset_config)

    def teardown(self, stage: str):
        """"""
        # persist generator states before dropping the live generators so that a
        # later `setup` (or `state_dict`) resumes from where they left off
        if hasattr(self, "generator"):
            self.generator_state = self.generator.get_state()
            del self.generator
        if stage == "fit":
            if hasattr(self, "train_dataset"):
                del self.train_dataset
            if hasattr(self, "val_dataset"):
                del self.val_dataset
            if hasattr(self, "train_generator"):
                self.train_generator_state = self.train_generator.get_state()
                del self.train_generator
            if hasattr(self, "_train_dataloader"):
                del self._train_dataloader
            if hasattr(self, "_val_dataloader"):
                del self._val_dataloader
        elif stage == "validate":
            if hasattr(self, "val_dataset"):
                del self.val_dataset
            if hasattr(self, "_val_dataloader"):
                del self._val_dataloader
        elif stage == "test":
            if hasattr(self, "test_dataset"):
                del self.test_dataset
            if hasattr(self, "_test_dataloader"):
                del self._test_dataloader
        elif stage == "predict":
            if hasattr(self, "predict_dataset"):
                del self.predict_dataset
            if hasattr(self, "_predict_dataloader"):
                del self._predict_dataloader

    def train_dataloader(self):
        """"""
        if hasattr(self, "_train_dataloader"):
            return self._train_dataloader[0]
        # must only return single train dataloader for now
        # see https://lightning.ai/docs/pytorch/stable/data/iterables.html#multiple-dataloaders
        self._train_dataloader = self._get_dloader(
            self.train_dataset, self.train_generator, self.train_dataloader_config
        )
        self._maybe_load_dataloader_state_dict(self._train_dataloader, "train")
        return self._train_dataloader[0]

    def val_dataloader(self):
        """"""
        if hasattr(self, "_val_dataloader"):
            return self._val_dataloader
        self._val_dataloader = self._get_dloader(
            self.val_dataset, self.generator, self.val_dataloader_config
        )
        self._maybe_load_dataloader_state_dict(self._val_dataloader, "val")
        return self._val_dataloader

    def test_dataloader(self):
        """"""
        if hasattr(self, "_test_dataloader"):
            return self._test_dataloader
        self._test_dataloader = self._get_dloader(
            self.test_dataset, self.generator, self.test_dataloader_config
        )
        self._maybe_load_dataloader_state_dict(self._test_dataloader, "test")
        return self._test_dataloader

    def predict_dataloader(self):
        """"""
        if hasattr(self, "_predict_dataloader"):
            return self._predict_dataloader
        # we don't expect this method to be used but it's here for consistency
        logger.warning(
            "predict_dataloader() is not expected to be used in typical workflows"
        )
        self._predict_dataloader = self._get_dloader(
            self.predict_dataset, self.generator, self.predict_dataloader_config
        )
        self._maybe_load_dataloader_state_dict(self._predict_dataloader, "predict")
        return self._predict_dataloader

    def _maybe_load_dataloader_state_dict(self, dloader, varname):
        # Apply state dicts stashed by `load_state_dict` to dataloaders that expose a
        # callable `load_state_dict` (same capability check as in `state_dict`).
        for i in range(self.num_datasets[varname]):
            attr_name = f"_{varname}_dataloader_state_dict_{i}"
            load_fn = getattr(dloader[i], "load_state_dict", None)
            if callable(load_fn) and hasattr(self, attr_name):
                load_fn(getattr(self, attr_name))

    def _get_dloader(self, datasets, generator, dataloader_dict):
        # Instantiate one dataloader per dataset from `dataloader_dict`, injecting the
        # dataset and the shared generator as keyword arguments.
        if "_target_" not in dataloader_dict:
            raise RuntimeError(
                f"`_target_` is missing from the dataloader dict: {dataloader_dict}"
            )
        return [
            instantiate(
                dataloader_dict,
                dataset=dataset,
                generator=generator,
            )
            for dataset in datasets
        ]

    def get_statistics(self, dataset: str = "train", dataset_idx: int = 0):
        """
        Compute statistics of the dataset.

        Args:
            dataset (str) : ``train``, ``val``, ``test``, or ``predict``
            dataset_idx (int): dataset index (there can be multiple ``val``, ``test``, ``predict`` datasets)

        Returns:
            the statistics computed by the configured stats manager, or ``{}`` if no ``stats_manager`` was configured
        """
        if self.stats_manager_cfg is None:
            return {}
        task_map = {
            "train": "fit",
            "val": "validate",
            "test": "test",
            "predict": "predict",
        }
        # validate before doing any (potentially expensive) instantiation
        assert (
            dataset in task_map
        ), f"`dataset` must be one of {list(task_map)}, but found `{dataset}`"
        stage = task_map[dataset]
        stats_manager = instantiate(self.stats_manager_cfg)  # , _recursive_=False)

        try:
            self.prepare_data()
            self.setup(stage=stage)
            # get dataloader, using dataloader_kwargs from the appropriate dataset.
            # stats manager can override options if it wants to, like batch size.
            dataloader_dict = getattr(self, dataset + "_dataloader_config").copy()
            if stats_manager.dataloader_kwargs is not None:
                dataloader_dict.update(stats_manager.dataloader_kwargs)
            # only build the dataloader for the requested dataset
            target_dataset = list(getattr(self, dataset + "_dataset"))[dataset_idx]
            dloader = self._get_dloader(
                [target_dataset], self.generator, dataloader_dict
            )[0]
            stats_dict = stats_manager.get_statistics(dloader)
        finally:
            # always release datasets/dataloaders and persist generator states
            self.teardown(stage=stage)
        return stats_dict
def _default_collate_fn_factory() -> Callable:
    """Return the default ``collate_fn`` for the datamodule's dataloaders.

    Hydra's ``instantiate`` calls the configured ``_target_``. Pointing ``_target_``
    at this zero-argument factory therefore yields the function
    ``AtomicDataDict.batched_from_list`` itself, rather than the result of calling it.

    Returns:
        ``AtomicDataDict.batched_from_list``
    """
    return AtomicDataDict.batched_from_list