Commit 0fe7fef6, authored 9 months ago by Tim Tobias Bauerle

Added CIFAR-10 dataset

parent e13c460a
No related branches found. No related tags found.
Part of 2 merge requests: !18 "Merge in main" and !17 "Added CIFAR-10 dataset".
Showing 3 changed files with 63 additions and 22 deletions:

edml/config/dataset/cifar10.yaml    (+3, −0)
edml/dataset_utils/cifar/cifar.py   (+56, −20)
edml/helpers/load_dataset.py        (+4, −2)
edml/config/dataset/cifar10.yaml (new file, +3 −0)

name: cifar10
average_setting: micro
num_classes: 10
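For orientation, here is a minimal sketch (not part of this commit) of how such a dataset config could feed the dataloader dispatch updated in edml/helpers/load_dataset.py below; the literal config path and the exact call signature of get_dataloaders are assumptions, since only the function body is visible in the diff.

# Hypothetical sketch: read the new dataset config and route on its "name" field,
# the same key that get_dataloaders() switches on in edml/helpers/load_dataset.py.
import yaml

from edml.helpers.load_dataset import get_dataloaders

with open("edml/config/dataset/cifar10.yaml") as f:
    cfg = yaml.safe_load(f)
# cfg == {"name": "cifar10", "average_setting": "micro", "num_classes": 10}

# "cifar10" routes to cifar10_dataloaders() via the new branch in load_dataset.py.
train_loader, val_loader, test_loader = get_dataloaders(cfg["name"], batch_size=64)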
edml/dataset_utils/cifar100/cifar100.py → edml/dataset_utils/cifar/cifar.py (renamed, +56 −20)
@@ -7,7 +7,39 @@ from torchvision import transforms, datasets
 from edml.helpers.data_partitioning import DataPartitioner


-def _load_transformed_data():
+def _load_cifar100(train_transform, test_transform):
+    train_data = datasets.CIFAR100(
+        os.path.join(os.path.dirname(__file__), "../../../data"),
+        train=True,
+        download=True,
+        transform=train_transform,
+    )
+    test_data = datasets.CIFAR100(
+        os.path.join(os.path.dirname(__file__), "../../../data"),
+        train=False,
+        download=True,
+        transform=test_transform,
+    )
+    return train_data, test_data
+
+
+def _load_cifar10(train_transform, test_transform):
+    train_data = datasets.CIFAR10(
+        os.path.join(os.path.dirname(__file__), "../../../data"),
+        train=True,
+        download=True,
+        transform=train_transform,
+    )
+    test_data = datasets.CIFAR10(
+        os.path.join(os.path.dirname(__file__), "../../../data"),
+        train=False,
+        download=True,
+        transform=test_transform,
+    )
+    return train_data, test_data
+
+
+def _get_transforms():
     # transformations from https://github.com/akamaster/pytorch_resnet_cifar10
     # However, in this repository the test data is used for validation
     # Here, we use the test data for testing only and split the training data into train and validation data (90%/10%) as in the original ResNet paper
@@ -25,28 +57,16 @@ def _load_transformed_data():
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         ]
     )
-    train_data = datasets.CIFAR100(
-        os.path.join(os.path.dirname(__file__), "../../../data"),
-        train=True,
-        download=True,
-        transform=train_transform,
-    )
-    test_data = datasets.CIFAR100(
-        os.path.join(os.path.dirname(__file__), "../../../data"),
-        train=False,
-        download=True,
-        transform=test_transform,
-    )
-    return train_data, test_data
+    return train_transform, test_transform


-def cifar_dataloaders(
-    batch_size: int,
-    split: Tuple[float, float] = (0.9, 0.1),
-    data_partitioner: Optional[DataPartitioner] = None,
-):
+def _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100):
     """
     Returns the train, validation and test dataloaders for the CIFAR100 dataset
     """
-    train_data, test_data = _load_transformed_data()
+    train_transform, test_transform = _get_transforms()
+    if dataset == 100:
+        train_data, test_data = _load_cifar100(train_transform, test_transform)
+    else:
+        train_data, test_data = _load_cifar10(train_transform, test_transform)
     # partition data for device
     if data_partitioner is not None:
         train_data = data_partitioner.partition(train_data)
@@ -57,3 +77,19 @@ def cifar_dataloaders(
         DataLoader(val, batch_size=batch_size),
         DataLoader(test_data, batch_size=batch_size),
     )
+
+
+def cifar100_dataloaders(
+    batch_size: int,
+    split: Tuple[float, float] = (0.9, 0.1),
+    data_partitioner: Optional[DataPartitioner] = None,
+):
+    return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100)
+
+
+def cifar10_dataloaders(
+    batch_size: int,
+    split: Tuple[float, float] = (0.9, 0.1),
+    data_partitioner: Optional[DataPartitioner] = None,
+):
+    return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=10)
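A brief usage sketch (not part of the diff) of the new public wrappers; the three-element return order (train, validation, test) follows the docstring and return statement of _cifar_dataloaders, and the printed shapes assume standard 32×32 RGB CIFAR images with the example batch size of 64.

# Hypothetical usage of the wrappers added above; batch_size is an example value.
from edml.dataset_utils.cifar.cifar import cifar10_dataloaders

# 90%/10% train/validation split by default; the test set is kept separate.
train_loader, val_loader, test_loader = cifar10_dataloaders(batch_size=64)

images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([64, 3, 32, 32])
print(labels.shape)  # expected: torch.Size([64])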
edml/helpers/load_dataset.py (+4 −2)
 from typing import Optional

-from edml.dataset_utils.cifar100.cifar100 import cifar_dataloaders
+from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders
 from edml.dataset_utils.mnist.mnist import mnist_dataloaders
 from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test
 from edml.helpers.data_partitioning import DataPartitioner
@@ -36,7 +36,9 @@ def get_dataloaders(
         return mnist_dataloaders(batch_size, data_partitioner=data_partitioner)
     elif name == "ptbxl":
         return ptb_xl_train_val_test(batch_size, data_partitioner=data_partitioner)
+    elif name == "cifar10":
+        return cifar10_dataloaders(batch_size, data_partitioner=data_partitioner)
     elif name == "cifar100":
-        return cifar_dataloaders(batch_size, data_partitioner=data_partitioner)
+        return cifar100_dataloaders(batch_size, data_partitioner=data_partitioner)
     else:
         raise ValueError(f"Dataset {name} not known.")
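For completeness, a hedged example of the updated dispatch; the full parameter list of get_dataloaders is not visible in this hunk, so passing the dataset name positionally and batch_size as a keyword is an assumption.

# Hypothetical call into the updated dispatcher; only the name branches appear in the diff.
from edml.helpers.load_dataset import get_dataloaders

loaders_cifar10 = get_dataloaders("cifar10", batch_size=64)    # new branch added in this commit
loaders_cifar100 = get_dataloaders("cifar100", batch_size=64)  # now backed by cifar100_dataloaders
# get_dataloaders("svhn", batch_size=64)  # would raise ValueError: Dataset svhn not known.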