Source code for nequip.data.dataset.utils
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
import torch
from torch.utils.data import Dataset
from typing import Dict, Union
[docs]
class SubsetByRandomSlice(torch.utils.data.Subset):
"""Subset of dataset by slicing a random permutation of the dataset.
Args:
dataset (Dataset): :class:`torch.utils.data.Dataset` to get subset of
start (int): starting index for the slice
length (int): number of samples to slice from ``start``
seed (int): seed for reproducibility of the random permutation of indices
"""
def __init__(
self,
dataset: Dataset,
start: int,
length: int,
seed: int,
):
data_len = len(dataset)
assert length <= data_len, (
f"Unable to get a subset (length {length}) larger than the size of the dataset (length {data_len}) provided"
)
generator = torch.Generator().manual_seed(seed)
indices = torch.randperm(len(dataset), generator=generator)
indices = indices[slice(start, start + length)]
super().__init__(dataset, indices)
[docs]
def RandomSplitAndIndexDataset(
dataset: torch.utils.data.Dataset,
split_dict: Dict[str, Union[int, float]],
dataset_key: str,
seed: int,
) -> torch.utils.data.Dataset:
"""
Args:
dataset (Dataset): the base dataset that is to be split
split_dict (Dict): dictionary with signature ``{name_of_subset: num_data/frac_data}`` where ``num_data`` must sum up to the size of the given dataset or ``frac_data`` must sum up to 1
dataset_key (str): name of the data subset to return
seed (int) : seed for reproducible splits
"""
# make a new generator from the seed every time a split is done -- reproducible splits as long as seed is the same
generator = torch.Generator().manual_seed(seed)
# API based on dicts (instead of lists and indices) makes it easier to keep track of what each dataset entry is
subset_names = list(split_dict.keys())
lengths = [split_dict[name] for name in subset_names]
# torch.utils.data.random_split will error out if the splits don't make sense, e.g. don't sum up to num_data or 1
# => no need to do safety checks on our part (though the error only appears when a split is first attempted)
splits = torch.utils.data.random_split(dataset, lengths, generator=generator)
return splits[subset_names.index(dataset_key)]