Skip to content
Snippets Groups Projects
Commit 0fe7fef6 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Added CIFAR-10 dataset

parent e13c460a
No related branches found
No related tags found
2 merge requests!18Merge in main,!17Added CIFAR-10 dataset
# Dataset configuration for CIFAR-10.
name: cifar10
# Averaging mode for multi-class metrics — presumably passed to metric
# computation (e.g. micro-averaged precision/recall); confirm against consumer.
average_setting: micro
# CIFAR-10 has 10 target classes.
num_classes: 10
......@@ -7,7 +7,39 @@ from torchvision import transforms, datasets
from edml.helpers.data_partitioning import DataPartitioner
def _load_transformed_data():
def _load_cifar100(train_transform, test_transform):
    """Load the CIFAR-100 train and test datasets, downloading if needed.

    Args:
        train_transform: torchvision transform applied to training samples.
        test_transform: torchvision transform applied to test samples.

    Returns:
        A ``(train_data, test_data)`` pair of ``datasets.CIFAR100`` instances.
    """
    # Both splits live under the shared repository-level data directory.
    data_root = os.path.join(os.path.dirname(__file__), "../../../data")
    train_data = datasets.CIFAR100(
        data_root, train=True, download=True, transform=train_transform
    )
    test_data = datasets.CIFAR100(
        data_root, train=False, download=True, transform=test_transform
    )
    return train_data, test_data
def _load_cifar10(train_transform, test_transform):
    """Load the CIFAR-10 train and test datasets, downloading if needed.

    Args:
        train_transform: torchvision transform applied to training samples.
        test_transform: torchvision transform applied to test samples.

    Returns:
        A ``(train_data, test_data)`` pair of ``datasets.CIFAR10`` instances.
    """
    # Both splits live under the shared repository-level data directory.
    data_root = os.path.join(os.path.dirname(__file__), "../../../data")
    train_data = datasets.CIFAR10(
        data_root, train=True, download=True, transform=train_transform
    )
    test_data = datasets.CIFAR10(
        data_root, train=False, download=True, transform=test_transform
    )
    return train_data, test_data
def _get_transforms():
# transformations from https://github.com/akamaster/pytorch_resnet_cifar10
# However, in this repository the test data is used for validation
# Here, we use the test data for testing only and split the training data into train and validation data (90%/10%) as in the original resnet paper
......@@ -25,28 +57,16 @@ def _load_transformed_data():
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
train_data = datasets.CIFAR100(
os.path.join(os.path.dirname(__file__), "../../../data"),
train=True,
download=True,
transform=train_transform,
)
test_data = datasets.CIFAR100(
os.path.join(os.path.dirname(__file__), "../../../data"),
train=False,
download=True,
transform=test_transform,
)
return train_data, test_data
return train_transform, test_transform
def cifar_dataloaders(
batch_size: int,
split: Tuple[float, float] = (0.9, 0.1),
data_partitioner: Optional[DataPartitioner] = None,
):
def _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100):
"""Returns the train, validation and test dataloaders for the CIFAR100 dataset"""
train_data, test_data = _load_transformed_data()
train_transform, test_transform = _get_transforms()
if dataset == 100:
train_data, test_data = _load_cifar100(train_transform, test_transform)
else:
train_data, test_data = _load_cifar10(train_transform, test_transform)
# partition data for device
if data_partitioner is not None:
train_data = data_partitioner.partition(train_data)
......@@ -57,3 +77,19 @@ def cifar_dataloaders(
DataLoader(val, batch_size=batch_size),
DataLoader(test_data, batch_size=batch_size),
)
def cifar100_dataloaders(
    batch_size: int,
    split: Tuple[float, float] = (0.9, 0.1),
    data_partitioner: Optional[DataPartitioner] = None,
):
    """Return train, validation and test dataloaders for CIFAR-100.

    Args:
        batch_size: Batch size for all three dataloaders.
        split: (train, validation) fractions used to split the training set.
        data_partitioner: Optional partitioner selecting this device's share
            of the training data.

    Returns:
        The ``(train, val, test)`` dataloaders from ``_cifar_dataloaders``.
    """
    return _cifar_dataloaders(
        batch_size,
        data_partitioner=data_partitioner,
        split=split,
        dataset=100,
    )
def cifar10_dataloaders(
    batch_size: int,
    split: Tuple[float, float] = (0.9, 0.1),
    data_partitioner: Optional[DataPartitioner] = None,
):
    """Return train, validation and test dataloaders for CIFAR-10.

    Args:
        batch_size: Batch size for all three dataloaders.
        split: (train, validation) fractions used to split the training set.
        data_partitioner: Optional partitioner selecting this device's share
            of the training data.

    Returns:
        The ``(train, val, test)`` dataloaders from ``_cifar_dataloaders``.
    """
    return _cifar_dataloaders(
        batch_size,
        data_partitioner=data_partitioner,
        split=split,
        dataset=10,
    )
from typing import Optional
from edml.dataset_utils.cifar100.cifar100 import cifar_dataloaders
from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders
from edml.dataset_utils.mnist.mnist import mnist_dataloaders
from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test
from edml.helpers.data_partitioning import DataPartitioner
......@@ -36,7 +36,9 @@ def get_dataloaders(
return mnist_dataloaders(batch_size, data_partitioner=data_partitioner)
elif name == "ptbxl":
return ptb_xl_train_val_test(batch_size, data_partitioner=data_partitioner)
elif name == "cifar10":
return cifar10_dataloaders(batch_size, data_partitioner=data_partitioner)
elif name == "cifar100":
return cifar_dataloaders(batch_size, data_partitioner=data_partitioner)
return cifar100_dataloaders(batch_size, data_partitioner=data_partitioner)
else:
raise ValueError(f"Dataset {name} not known.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment