load_dataset.py
from typing import Optional

from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders
from edml.dataset_utils.mnist.mnist import mnist_dataloaders
from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test
from edml.helpers.data_partitioning import DataPartitioner
from edml.helpers.types import DatasetDataLoaders


def get_dataloaders(
    name: str,
    batch_size: int,
    data_partitioner: Optional[DataPartitioner] = None,
) -> DatasetDataLoaders:
    """
    Returns the :class:`DataLoader`s for the given dataset name. In total, the function should return three
    :class:`DataLoader` instances for training, validation and testing.

    Args:
        name (str): The name of the dataset to create the :class:`DataLoader`s for. Currently supported values are
            `mnist`, `ptbxl` and `cifar100`.
        batch_size (int): The batch size.
        data_partitioner (DataPartitioner, optional): A custom data partitioner to use for splitting the data. Defaults
            to `None`.

    Raises:
        ValueError: If the dataset name is unknown.

    Notes:
        If the data partitioner is not set explicitly, the data should be split randomly.
    """

    # To add your own datasets, you can simply introduce a new name check that returns
    # the appropriate data loaders.
    if name == "mnist":
        return mnist_dataloaders(batch_size, data_partitioner=data_partitioner)
    elif name == "ptbxl":
        return ptb_xl_train_val_test(batch_size, data_partitioner=data_partitioner)
    elif name == "cifar10":
        return cifar10_dataloaders(batch_size, data_partitioner=data_partitioner)
    elif name == "cifar100":
        return cifar100_dataloaders(batch_size, data_partitioner=data_partitioner)
    else:
        raise ValueError(f"Dataset {name} not known.")
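

# Minimal usage sketch (not part of the original module): fetch the three
# loaders for MNIST and inspect one batch. It assumes that DatasetDataLoaders
# unpacks into (train, validation, test) loaders, that each batch is an
# (inputs, targets) pair, and that the MNIST data is available to the edml
# dataset utilities; dataset name and batch size are illustrative only.
if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloaders("mnist", batch_size=32)
    # Pull a single batch to sanity-check shapes (assumed (inputs, targets) structure).
    images, labels = next(iter(train_loader))
    print(f"First training batch: images {tuple(images.shape)}, labels {tuple(labels.shape)}")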