-
Tim Tobias Bauerle authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
load_dataset.py 1.79 KiB
from typing import Optional
from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders
from edml.dataset_utils.mnist.mnist import mnist_dataloaders
from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test
from edml.helpers.data_partitioning import DataPartitioner
from edml.helpers.types import DatasetDataLoaders
def get_dataloaders(
    name: str,
    batch_size: int,
    data_partitioner: Optional[DataPartitioner] = None,
) -> DatasetDataLoaders:
    """
    Returns the :class:`DataLoader`s for the given dataset name. In total, the function should return three
    :class:`DataLoader` instances for training, validation and testing.

    Args:
        name (str): The name of the dataset to create the :class:`DataLoader`s for. Currently supported values are
            `mnist`, `ptbxl`, `cifar10` and `cifar100`.
        batch_size (int): The batch size.
        data_partitioner (DataPartitioner, optional): A custom data partitioner to use for splitting the data. Defaults
            to `None`.

    Returns:
        DatasetDataLoaders: Whatever the dataset-specific factory function returns for the given arguments.

    Raises:
        ValueError: If the dataset name is unknown.

    Notes:
        If the data partitioner is not set explicitly, the data should be split randomly (this is left to the
        dataset-specific factory functions).
    """
    # Every supported dataset is served by a factory with the same call signature
    # (batch size positionally, partitioner as a keyword), so a lookup table replaces
    # a chain of near-identical if/elif branches. To add your own dataset, register
    # its factory function here under the name callers should pass.
    factories = {
        "mnist": mnist_dataloaders,
        "ptbxl": ptb_xl_train_val_test,
        "cifar10": cifar10_dataloaders,
        "cifar100": cifar100_dataloaders,
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"Dataset {name} not known.")
    return factory(batch_size, data_partitioner=data_partitioner)