diff --git a/edml/config/dataset/cifar10.yaml b/edml/config/dataset/cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b203a6b8036bd6446e8ad9f4f5ecbb4ea9d86bbc --- /dev/null +++ b/edml/config/dataset/cifar10.yaml @@ -0,0 +1,3 @@ +name: cifar10 +average_setting: micro +num_classes: 10 diff --git a/edml/dataset_utils/cifar100/cifar100.py b/edml/dataset_utils/cifar/cifar.py similarity index 61% rename from edml/dataset_utils/cifar100/cifar100.py rename to edml/dataset_utils/cifar/cifar.py index 099a40ba46386f920495dc00554a378f5f29a431..49bfdf8b18604f227bf338fcdb78c29c9c74bcb4 100644 --- a/edml/dataset_utils/cifar100/cifar100.py +++ b/edml/dataset_utils/cifar/cifar.py @@ -7,7 +7,39 @@ from torchvision import transforms, datasets from edml.helpers.data_partitioning import DataPartitioner -def _load_transformed_data(): +def _load_cifar100(train_transform, test_transform): + train_data = datasets.CIFAR100( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=True, + download=True, + transform=train_transform, + ) + test_data = datasets.CIFAR100( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=False, + download=True, + transform=test_transform, + ) + return train_data, test_data + + +def _load_cifar10(train_transform, test_transform): + train_data = datasets.CIFAR10( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=True, + download=True, + transform=train_transform, + ) + test_data = datasets.CIFAR10( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=False, + download=True, + transform=test_transform, + ) + return train_data, test_data + + +def _get_transforms(): # transformations from https://github.com/akamaster/pytorch_resnet_cifar10 # However, in this repository the test data is used for validation # Here, we use the test data for testing only and split the training data into train and validation data (90%/10%) as in the original resnet paper @@ -25,28 +57,16 @@ def _load_transformed_data(): transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ] ) - train_data = datasets.CIFAR100( - os.path.join(os.path.dirname(__file__), "../../../data"), - train=True, - download=True, - transform=train_transform, - ) - test_data = datasets.CIFAR100( - os.path.join(os.path.dirname(__file__), "../../../data"), - train=False, - download=True, - transform=test_transform, - ) - return train_data, test_data + return train_transform, test_transform -def cifar_dataloaders( - batch_size: int, - split: Tuple[float, float] = (0.9, 0.1), - data_partitioner: Optional[DataPartitioner] = None, -): +def _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100): """Returns the train, validation and test dataloaders for the CIFAR100 dataset""" - train_data, test_data = _load_transformed_data() + train_transform, test_transform = _get_transforms() + if dataset == 100: + train_data, test_data = _load_cifar100(train_transform, test_transform) + else: + train_data, test_data = _load_cifar10(train_transform, test_transform) # partition data for device if data_partitioner is not None: train_data = data_partitioner.partition(train_data) @@ -57,3 +77,19 @@ def cifar_dataloaders( DataLoader(val, batch_size=batch_size), DataLoader(test_data, batch_size=batch_size), ) + + +def cifar100_dataloaders( + batch_size: int, + split: Tuple[float, float] = (0.9, 0.1), + data_partitioner: Optional[DataPartitioner] = None, +): + return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100) + + +def cifar10_dataloaders( + batch_size: int, + split: Tuple[float, float] = (0.9, 0.1), + data_partitioner: Optional[DataPartitioner] = None, +): + return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=10) diff --git a/edml/helpers/load_dataset.py b/edml/helpers/load_dataset.py index 1355d61d2b250346d2b7188b9b5bb33dd4e64456..242bfed6c38c182281edcf19fbc2cf732db7ba34 100644 --- a/edml/helpers/load_dataset.py +++ b/edml/helpers/load_dataset.py @@ -1,6 +1,6 @@ from typing import Optional -from edml.dataset_utils.cifar100.cifar100 import cifar_dataloaders +from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders from edml.dataset_utils.mnist.mnist import mnist_dataloaders from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test from edml.helpers.data_partitioning import DataPartitioner @@ -36,7 +36,9 @@ def get_dataloaders( return mnist_dataloaders(batch_size, data_partitioner=data_partitioner) elif name == "ptbxl": return ptb_xl_train_val_test(batch_size, data_partitioner=data_partitioner) + elif name == "cifar10": + return cifar10_dataloaders(batch_size, data_partitioner=data_partitioner) elif name == "cifar100": - return cifar_dataloaders(batch_size, data_partitioner=data_partitioner) + return cifar100_dataloaders(batch_size, data_partitioner=data_partitioner) else: raise ValueError(f"Dataset {name} not known.")