Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
Swarm Split Learning
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
INDA_ML
Swarm Split Learning
Commits
be776bcc
Commit
be776bcc
authored
8 months ago
by
Tim Tobias Bauerle
Browse files
Options
Downloads
Patches
Plain Diff
Experiment postprocessing and plotting for simulated parallelism
parent
f1692f50
No related branches found
No related tags found
1 merge request
!23
Simulating parallel execution
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
results/result_generation.ipynb
+64
-35
64 additions, 35 deletions
results/result_generation.ipynb
results/result_generation.py
+193
-30
193 additions, 30 deletions
results/result_generation.py
with
257 additions
and
65 deletions
results/result_generation.ipynb
+
64
−
35
View file @
be776bcc
...
...
@@ -29,30 +29,15 @@
"source": [
"df_base_dir = \"./dataframes\"\n",
"projects_with_model = [\n",
" (\"greedy_vs_smart_ecg-non-iid_RESULT\", \"tcn\"),\n",
" (\"greedy_vs_smart_cifar100_resnet20_RESULT\", \"resnet20\"),\n",
" (\"greedy_vs_smart_ecg-iid_RESULT\", \"tcn\"),\n",
" (\"greedy_vs_smart_PTBXL_equal_devices_RESULT\", \"tcn\"),\n",
" (\"greedy_vs_smart_PTBXL_unequal_processors_RESULT\", \"tcn\"),\n",
" (\"greedy_vs_smart_MNIST_unequal_processors_RESULT\", \"simple_conv\"),\n",
" (\"greedy_vs_smart_MNIST_unequal_batteries_unequal_partition_RESULT\", \"simple_conv\"),\n",
" (\"greedy_vs_smart_MNIST_equal_devices_RESULT\", \"simple_conv\"),\n",
" (\"greedy_vs_smart_MNIST_unequal_batteries_RESULT\", \"simple_conv\"),\n",
" (\"fed_vs_split_MNIST_limited_batteries_RESULT\", \"simple_conv\"),\n",
" (\"fed_vs_split_MNIST_unlimited_batteries_RESULT\", \"simple_conv\"),\n",
" (\"fed_vs_split_PTBXL_limited_batteries_RESULT\", \"tcn\"),\n",
" (\"fed_vs_split_PTBXL_unlimited_batteries_RESULT\", \"tcn\"),\n",
" (\"fed_vs_split_cifar100_unlimited_batteries_RESULT\", \"resnet20\"),\n",
" (\"fed_vs_split_CIFAR100_limited_batteries_RESULT\", \"resnet20\"),\n",
" (\"fed_vs_split_50_devices_RESULT\", \"resnet110\"),\n",
" (\"greedy_vs_smart_CIFAR100_equal_devices_RESULT\", \"resnet20\"),\n",
" (\"greedy_vs_smart_CIFAR100_unequal_processors_RESULT\", \"resnet20\"),\n",
" (\"5_devices_unlimited_new\", \"resnet110\"),\n",
" (\"50_devices_unlimited_new\", \"resnet110\"),\n",
" (\"controller_comparison\", \"resnet110\")\n",
"]"
],
"metadata": {
"collapsed": false
},
"id": "
6695251b9af7ea4b
"
"id": "
5b81c8c9ba4b483d
"
},
{
"cell_type": "code",
...
...
@@ -61,18 +46,33 @@
"source": [
"for project_name, _ in projects_with_model:\n",
" save_dataframes(project_name=project_name, strategies=[\n",
" \"swarm_seq\",\n",
" \"fed\",\n",
" \"swarm_max\",\n",
" \"swarm_rand\",\n",
" \"swarm_smart\",\n",
" \"split\"\n",
" #\"swarm_seq\",\n",
" #\"fed\",\n",
" #\"swarm_max\",\n",
" #\"swarm_rand\",\n",
" #\"swarm_smart\",\n",
" #\"split\",\n",
" #\"psl_rand_\",\n",
" #\"psl_sequential_\",\n",
" #\"psl_max_batteries_\",\n",
" #\"swarm_rand_\",\n",
" #\"swarm_sequential_\",\n",
" #\"swarm_max_batteries_\",\n",
" \"psl_sequential__\",\n",
" \"fed___\",\n",
" \"split___\",\n",
" \"swarm_sequential__\",\n",
" \"swarm_max_battery__\",\n",
" \"swarm_smart__\",\n",
" \"psl_sequential_static_at_resnet_decoderpth\",\n",
" \"psl_sequential__resnet_decoderpth\",\n",
" \"psl_sequential_static_at_\",\n",
" ])"
],
"metadata": {
"collapsed": false
},
"id": "
b07913828b33ffcc
"
"id": "
118f1ed9e7537718
"
},
{
"cell_type": "markdown",
...
...
@@ -82,7 +82,7 @@
"metadata": {
"collapsed": false
},
"id": "
d8269abd823cdcc7
"
"id": "
bbc47124f3c80f1c
"
},
{
"cell_type": "code",
...
...
@@ -92,28 +92,45 @@
"# Required for total number of FLOPs computation\n",
"model_flops = {\n",
" \"resnet20\": 41498880,\n",
" \"resnet20_ae\": 45758720,\n",
" \"resnet110\": 258136320,\n",
" \"resnet110_ae\": 262396160,\n",
" \"tcn\": 27240000,\n",
" \"simple_conv\": 16621560\n",
"}"
"}\n",
"\n",
"client_model_flops = {\n",
" \"resnet20\": 15171584,\n",
" \"resnet20_ae\": 19005440,\n",
" \"resnet110\": 88408064,\n",
" \"resnet110_ae\": 92241920,\n",
"}\n",
"\n",
"server_model_flops = {\n",
" \"resnet20\": 26327296,\n",
" \"resnet20_ae\": 26753280,\n",
" \"resnet110\": 169728256,\n",
" \"resnet110_ae\": 170154240,\n",
"}\n",
"experiment_batch_size = 64"
],
"metadata": {
"collapsed": false
},
"id": "
828bcb4737b21c6d
"
"id": "
6e4e3bd0198fe7e7
"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"plots_base_path
=
\"./plots\"\n",
"metrics_base_path
=
\"./metrics\""
"plots_base_path
=
\"./plots\"\n",
"metrics_base_path
=
\"./metrics\""
],
"metadata": {
"collapsed": false
},
"id": "
b13f9e0e98b7ac5b
"
"id": "
ede70693af668f55
"
},
{
"cell_type": "code",
...
...
@@ -126,14 +143,16 @@
" print(\" loading data from disk\")\n",
" dataframes = load_dataframes(proj_name, df_base_dir)\n",
" print(\" generating metrics\")\n",
" generate_metric_files(dataframes, proj_name, model_flops[model_name])\n",
" generate_metric_files(dataframes, proj_name, model_flops[model_name], client_model_flops[model_name],\n",
" # TODO distinguish AE\n",
" base_path=metrics_base_path, batch_size=experiment_batch_size)\n",
" print(\" generating plots\")\n",
" generate_plots(dataframes, proj_name)"
],
"metadata": {
"collapsed": false
},
"id": "
c4b1ed2d809c54e2
"
"id": "
1c72379feadc98cb
"
},
{
"cell_type": "code",
...
...
@@ -145,7 +164,17 @@
"metadata": {
"collapsed": false
},
"id": "378bf3365dd9fde2"
"id": "7927831aecd5a02"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "fc698fc664867532"
}
],
"metadata": {
...
...
%% Cell type:code id:26fa58dbda0ddc40 tags:
```
python
from
result_generation
import
save_dataframes
,
load_dataframes
,
generate_metric_files
,
generate_plots
```
%% Cell type:markdown id:3e72d0fe76a55e56 tags:
# Downloading the dataframes (takes time)
%% Cell type:code id:
6695251b9af7ea4b
tags:
%% Cell type:code id:
5b81c8c9ba4b483d
tags:
```
python
df_base_dir
=
"
./dataframes
"
projects_with_model
=
[
(
"
greedy_vs_smart_ecg-non-iid_RESULT
"
,
"
tcn
"
),
(
"
greedy_vs_smart_cifar100_resnet20_RESULT
"
,
"
resnet20
"
),
(
"
greedy_vs_smart_ecg-iid_RESULT
"
,
"
tcn
"
),
(
"
greedy_vs_smart_PTBXL_equal_devices_RESULT
"
,
"
tcn
"
),
(
"
greedy_vs_smart_PTBXL_unequal_processors_RESULT
"
,
"
tcn
"
),
(
"
greedy_vs_smart_MNIST_unequal_processors_RESULT
"
,
"
simple_conv
"
),
(
"
greedy_vs_smart_MNIST_unequal_batteries_unequal_partition_RESULT
"
,
"
simple_conv
"
),
(
"
greedy_vs_smart_MNIST_equal_devices_RESULT
"
,
"
simple_conv
"
),
(
"
greedy_vs_smart_MNIST_unequal_batteries_RESULT
"
,
"
simple_conv
"
),
(
"
fed_vs_split_MNIST_limited_batteries_RESULT
"
,
"
simple_conv
"
),
(
"
fed_vs_split_MNIST_unlimited_batteries_RESULT
"
,
"
simple_conv
"
),
(
"
fed_vs_split_PTBXL_limited_batteries_RESULT
"
,
"
tcn
"
),
(
"
fed_vs_split_PTBXL_unlimited_batteries_RESULT
"
,
"
tcn
"
),
(
"
fed_vs_split_cifar100_unlimited_batteries_RESULT
"
,
"
resnet20
"
),
(
"
fed_vs_split_CIFAR100_limited_batteries_RESULT
"
,
"
resnet20
"
),
(
"
fed_vs_split_50_devices_RESULT
"
,
"
resnet110
"
),
(
"
greedy_vs_smart_CIFAR100_equal_devices_RESULT
"
,
"
resnet20
"
),
(
"
greedy_vs_smart_CIFAR100_unequal_processors_RESULT
"
,
"
resnet20
"
),
(
"
5_devices_unlimited_new
"
,
"
resnet110
"
),
(
"
50_devices_unlimited_new
"
,
"
resnet110
"
),
(
"
controller_comparison
"
,
"
resnet110
"
)
]
```
%% Cell type:code id:
b07913828b33ffcc
tags:
%% Cell type:code id:
118f1ed9e7537718
tags:
```
python
for
project_name
,
_
in
projects_with_model
:
save_dataframes
(
project_name
=
project_name
,
strategies
=
[
"
swarm_seq
"
,
"
fed
"
,
"
swarm_max
"
,
"
swarm_rand
"
,
"
swarm_smart
"
,
"
split
"
#"swarm_seq",
#"fed",
#"swarm_max",
#"swarm_rand",
#"swarm_smart",
#"split",
#"psl_rand_",
#"psl_sequential_",
#"psl_max_batteries_",
#"swarm_rand_",
#"swarm_sequential_",
#"swarm_max_batteries_",
"
psl_sequential__
"
,
"
fed___
"
,
"
split___
"
,
"
swarm_sequential__
"
,
"
swarm_max_battery__
"
,
"
swarm_smart__
"
,
"
psl_sequential_static_at_resnet_decoderpth
"
,
"
psl_sequential__resnet_decoderpth
"
,
"
psl_sequential_static_at_
"
,
])
```
%% Cell type:markdown id:
d8269abd823cdcc7
tags:
%% Cell type:markdown id:
bbc47124f3c80f1c
tags:
# Generating Results from saved dataframes
%% Cell type:code id:
828bcb4737b21c6d
tags:
%% Cell type:code id:
6e4e3bd0198fe7e7
tags:
```
python
# Required for total number of FLOPs computation
model_flops
=
{
"
resnet20
"
:
41498880
,
"
resnet20_ae
"
:
45758720
,
"
resnet110
"
:
258136320
,
"
resnet110_ae
"
:
262396160
,
"
tcn
"
:
27240000
,
"
simple_conv
"
:
16621560
}
client_model_flops
=
{
"
resnet20
"
:
15171584
,
"
resnet20_ae
"
:
19005440
,
"
resnet110
"
:
88408064
,
"
resnet110_ae
"
:
92241920
,
}
server_model_flops
=
{
"
resnet20
"
:
26327296
,
"
resnet20_ae
"
:
26753280
,
"
resnet110
"
:
169728256
,
"
resnet110_ae
"
:
170154240
,
}
experiment_batch_size
=
64
```
%% Cell type:code id:
b13f9e0e98b7ac5b
tags:
%% Cell type:code id:
ede70693af668f55
tags:
```
python
plots_base_path
=
"
./plots
"
metrics_base_path
=
"
./metrics
"
plots_base_path
=
"
./plots
"
metrics_base_path
=
"
./metrics
"
```
%% Cell type:code id:
c4b1ed2d809c54e2
tags:
%% Cell type:code id:
1c72379feadc98cb
tags:
```
python
def
generate_result_files
(
projects_with_model
):
for
proj_name
,
model_name
in
projects_with_model
:
print
(
proj_name
)
print
(
"
loading data from disk
"
)
dataframes
=
load_dataframes
(
proj_name
,
df_base_dir
)
print
(
"
generating metrics
"
)
generate_metric_files
(
dataframes
,
proj_name
,
model_flops
[
model_name
])
generate_metric_files
(
dataframes
,
proj_name
,
model_flops
[
model_name
],
client_model_flops
[
model_name
],
# TODO distinguish AE
base_path
=
metrics_base_path
,
batch_size
=
experiment_batch_size
)
print
(
"
generating plots
"
)
generate_plots
(
dataframes
,
proj_name
)
```
%% Cell type:code id:
378bf3365dd9fde
2 tags:
%% Cell type:code id:
7927831aecd5a0
2 tags:
```
python
generate_result_files
(
projects_with_model
)
```
%% Cell type:code id:fc698fc664867532 tags:
```
python
```
...
...
This diff is collapsed.
Click to expand it.
results/result_generation.py
+
193
−
30
View file @
be776bcc
...
...
@@ -5,6 +5,7 @@ import matplotlib.pyplot as plt
import
numpy
as
np
import
pandas
as
pd
import
wandb
import
math
# For plotting
STRATEGY_MAPPING
=
{
...
...
@@ -14,6 +15,16 @@ STRATEGY_MAPPING = {
"
swarm_rand
"
:
"
Swarm SL (Rand)
"
,
"
swarm_max
"
:
"
Swarm SL (Greedy)
"
,
"
fed
"
:
"
Vanilla FL
"
,
"
psl_sequential__
"
:
"
PSSL (Seq)
"
,
"
fed___
"
:
"
Vanilla FL
"
,
"
swarm_sequential__
"
:
"
Swarm SL (Seq)
"
,
"
swarm_smart__
"
:
"
Swarm SL (Smart)
"
,
"
swarm_rand__
"
:
"
Swarm SL (Rand)
"
,
"
swarm_max_battery__
"
:
"
Swarm SL (Greedy)
"
,
"
split___
"
:
"
Vanilla SL
"
,
"
psl_sequential_static_at_resnet_decoderpth
"
:
"
PSSL (Seq) AE Static
"
,
"
psl_sequential__resnet_decoderpth
"
:
"
PSSL (Seq) AE
"
,
"
psl_sequential_static_at_
"
:
"
PSSL (Seq) Static
"
,
}
LABEL_MAPPING
=
{
...
...
@@ -29,6 +40,75 @@ LABEL_MAPPING = {
}
def scale_parallel_time(run_df, scale_factor=1.0):
    """
    Compresses all time information of a run by the given factor.

    Columns named "_timestamp" or ending in ".start"/".end" hold points in time; they are
    scaled relative to the earliest "_timestamp" value, so the run keeps its starting point.
    Columns named "_runtime" or ending in ".duration" hold time spans and are simply divided.
    The dataframe is modified in place and returned.

    Args:
        run_df: The dataframe of the run
        scale_factor: (float) the factor to shorten time by, e.g. 2 halves the total time
    Returns:
        run_df: the same dataframe with scaled time columns; it is returned untouched if the
            factor is 1 or if there is no "_timestamp" column
    """
    # Nothing to rescale for the identity factor or for a frame without timestamps.
    if scale_factor == 1 or "_timestamp" not in run_df.columns:
        return run_df

    # Captured once up front: the "_timestamp" column itself is rewritten in the loop.
    origin = run_df["_timestamp"].min()
    for column in run_df.columns:
        if column.endswith((".start", ".end")) or column == "_timestamp":
            # point in time -> shrink its distance to the origin
            run_df[column] = (run_df[column] - origin) / scale_factor + origin
        if column.endswith(".duration") or column == "_runtime":
            # time span -> shrink directly
            run_df[column] = run_df[column] / scale_factor
    return run_df
def get_scale_factors(group):
    """
    Determines the scale factor to account for parallelism.

    For each set of runs (i.e. one run of controller, d0, d1, ...), the time overhead introduced by running a
    parallel operation sequentially is determined and the resulting factor to scale the runs down as well.
    If no parallel operations were simulated, no time is deduced and the scale factor will equal 1.

    Args:
        group: (dict) the group of runs, mapping each name (controller, d0, d1, ...) to the list of its
            run dataframes. The i-th dataframes of all names together form the i-th set of runs.
    Returns:
        A list of factors to scale down each set of runs (one entry per set, empty for an empty group).
        A factor of 1.0 is used when a set has no positive "_runtime" or when the corrected runtime
        would not be positive.
    """
    columns_to_count = [
        "parallel_client_train_time",
        "parallel_client_backprop_time",
        "parallel_client_model_update_time",
        "parallel_fed_time",
    ]
    # One accumulator slot per set of runs. The run lists are the dict *values*;
    # iterating the dict itself would only yield the names.
    num_runs = max((len(runs) for runs in group.values()), default=0)
    max_runtime = [0] * num_runs
    elapsed_time = [0] * num_runs
    parallel_time = [0] * num_runs

    for runs in group.values():
        for i, run_df in enumerate(runs):
            if "_runtime" in run_df.columns:  # assure that run_df is not empty
                run_max = run_df["_runtime"].max()
                if run_max > max_runtime[i]:
                    max_runtime[i] = run_max
            for col_name in columns_to_count:
                if f"{col_name}.parallel_time" in run_df.columns:
                    elapsed_time[i] += run_df[f"{col_name}.elapsed_time"].sum()
                    parallel_time[i] += run_df[f"{col_name}.parallel_time"].sum()
            if (
                "parallel_client_eval_time.parallel_time" in run_df.columns
                and "evaluate_batch_time.duration" in run_df.columns
            ):
                elapsed_time[i] += run_df["parallel_client_eval_time.elapsed_time"].sum()
                # evaluate batch time measured at server -> sequential either way
                parallel_time[i] += (
                    run_df["parallel_client_eval_time.parallel_time"].sum()
                    - run_df["evaluate_batch_time.duration"].sum()
                )

    scale_factors = []
    for max_rt, elapsed, parallel in zip(max_runtime, elapsed_time, parallel_time):
        # Runtime of the set with the sequentially elapsed time replaced by the parallel time.
        corrected_runtime = max_rt - elapsed + parallel
        if max_rt > 0 and corrected_runtime > 0:
            scale_factors.append(max_rt / corrected_runtime)
        else:
            # no usable runtime information -> leave the set unscaled
            scale_factors.append(1.0)
    return scale_factors
def
save_dataframes
(
project_name
,
strategies
,
base_dir
=
"
./dataframes
"
):
"""
Fetches the dataframes from wandb and saves them to the base_dir.
...
...
@@ -81,13 +161,22 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"):
history_groups
=
{}
for
(
strategy
,
job
),
group
in
run_groups
.
items
():
print
(
f
"
{
strategy
}
{
job
}
"
)
history
=
defaultdict
(
list
)
unscaled_runs
=
defaultdict
(
list
)
for
name
,
runs
in
group
.
items
():
print
(
f
"
{
name
}
"
)
for
run
in
runs
:
history_df
=
pd
.
DataFrame
(
run
.
scan_history
())
history
[
name
].
append
(
history_df
)
history_groups
[(
strategy
,
job
)]
=
history
unscaled_runs
[
name
].
append
(
history_df
)
# rescale if parallelism was only simulated
if
job
==
"
train
"
and
len
(
unscaled_runs
)
>
0
:
scale_factors
=
get_scale_factors
(
unscaled_runs
)
scaled_runs
=
defaultdict
(
list
)
for
name
,
runs
in
unscaled_runs
.
items
():
for
i
,
run
in
enumerate
(
runs
):
scaled_runs
[
name
].
append
(
scale_parallel_time
(
run
,
scale_factors
[
i
]))
history_groups
[(
strategy
,
job
)]
=
scaled_runs
else
:
history_groups
[(
strategy
,
job
)]
=
unscaled_runs
# save dataframe
print
(
"
saving data
"
)
for
(
strategy
,
job
),
group
in
history_groups
.
items
():
...
...
@@ -105,7 +194,7 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"):
def
load_dataframes
(
project_name
,
base_dir
=
"
./dataframes
"
):
"""
Load
e
s saved dataframes from the given project.
Loads saved dataframes from the given project.
Args:
project_name: (str) the name of the project folder
base_dir: (str) the base directory to fetch the dataframes from
...
...
@@ -130,7 +219,7 @@ def load_dataframes(project_name, base_dir="./dataframes"):
project_dir
,
strategy
,
job
,
device_id
=
path
.
split
(
os
.
sep
)
# Load dataframe from csv
df
=
pd
.
read_csv
(
os
.
path
.
join
(
root
,
file
))
df
=
pd
.
read_csv
(
os
.
path
.
join
(
root
,
file
)
,
low_memory
=
False
)
# Add dataframe to dictionary
if
(
strategy
,
job
)
not
in
history_groups
:
...
...
@@ -141,12 +230,13 @@ def load_dataframes(project_name, base_dir="./dataframes"):
return
history_groups
def
get_total_flops
(
groups
,
total_model_flops
):
def
get_total_flops
(
groups
,
total_model_flops
,
client_model_flops
,
batch_size
=
64
):
"""
Returns the total number of FLOPs for each group.
Args:
groups: The runs of one project, according to the structure of the wandb project
total_model_flops: (int) the total number of FLOPs of the model
client_model_flops: (int) the total number of FLOPs of the client model
Returns:
flops_per_group: (dict) the total number of FLOPs for each group
"""
...
...
@@ -155,6 +245,7 @@ def get_total_flops(groups, total_model_flops):
if
job
==
"
train
"
:
flops
=
0
num_runs
=
1
# avoid division by 0
num_clients
=
len
(
group
.
items
())
-
1
# minus controller
for
name
,
runs
in
group
.
items
():
if
(
name
!=
"
controller
"
...
...
@@ -170,6 +261,38 @@ def get_total_flops(groups, total_model_flops):
flops
+=
(
run_df
[
col_name
].
sum
()
*
total_model_flops
)
# 1x forward
if
col_name
==
"
adaptive_learning_threshold_applied
"
:
# deduce client model flops twice as client backprop is avoided
if
(
run_df
[
col_name
].
dtype
==
"
object
"
):
# if boolean values were logged
# assumptions: compute avg number of samples per batch
avg_samples_per_epoch
=
sum
(
run_df
[
"
train_accuracy.num_samples
"
].
dropna
()
)
/
len
(
run_df
[
"
train_accuracy.num_samples
"
].
dropna
()
)
avg_num_batches
=
(
math
.
ceil
(
avg_samples_per_epoch
/
num_clients
/
batch_size
)
*
num_clients
)
avg_samples_per_batch
=
(
avg_samples_per_epoch
/
avg_num_batches
)
flops
-=
(
len
(
run_df
[
col_name
].
dropna
())
*
client_model_flops
*
2
*
avg_samples_per_batch
)
else
:
# numbers of samples skipped are logged -> sum up
flops
-=
(
run_df
[
col_name
].
sum
()
*
client_model_flops
*
2
)
flops
=
flops
/
num_runs
flops_per_group
[
"
strategy
"
].
append
(
STRATEGY_MAPPING
[
strategy
])
flops_per_group
[
"
flops
"
].
append
(
round
(
flops
/
1000000000
,
3
))
# in GFLOPs
...
...
@@ -341,6 +464,27 @@ def accuracy_over_epoch(history_groups, phase="train"):
return
results
def accuracy_over_time(history_groups, phase="train"):
    """
    Collects, for every training group, the accuracy together with the runtime at which it was logged.

    Only the first controller run of a group is used. No averaging implemented yet if there are
    multiple runs per group!

    Args:
        history_groups: The runs of one project, according to the structure of the wandb project
        phase: (str) the phase to get the accuracy for, either "train" or "val"
    Returns:
        results: (dict) for each (strategy, job) key with job "train", a tuple
            (runtime, accuracy) of two columns restricted to the rows where both values are present
    """
    accuracy_column = f"{phase}_accuracy.value"
    results = {}
    for (strategy, job), group in history_groups.items():
        if job != "train":
            continue
        controller_df = group["controller"][0]  # first run only, no averaging
        # keep only rows in which accuracy and runtime were both logged
        logged = controller_df[[accuracy_column, "_runtime"]].dropna()
        runtimes = logged["_runtime"]
        accuracies = logged[accuracy_column]
        results[(strategy, job)] = (runtimes, accuracies)
    return results
def
plot_accuracies
(
accuracies_per_round
,
save_path
=
None
,
phase
=
"
train
"
):
"""
Plots the accuracy over the epoch for each group.
...
...
@@ -366,6 +510,28 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
plt
.
close
()
def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"):
    """
    Draws one accuracy-over-runtime curve per group into a single figure.

    Args:
        accuracies_per_time: (dict) for each (strategy, job) key a tuple (time, accuracies)
            with the x and y values of the curve
        save_path: (str) the path to save the plot to, without file extension; the figure is written
            as both .pdf and .png. If None, the figure is shown instead of saved.
        phase: (str) selects the y-axis label via the LABEL_MAPPING entry for this phase's accuracy
    """
    plt.figure()
    for (strategy, _job), (runtimes, accuracies) in accuracies_per_time.items():
        # curves are labelled with the display name of the strategy
        plt.plot(runtimes, accuracies, label=f"{STRATEGY_MAPPING[strategy]}")
    plt.xlabel(LABEL_MAPPING["runtime"])
    plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"])
    plt.legend()
    plt.tight_layout()
    if save_path is not None:
        for extension in ("pdf", "png"):
            plt.savefig(f"{save_path}.{extension}", format=extension)
    else:
        plt.show()
    plt.close()
def
battery_over_time
(
history_groups
,
num_intervals
=
1000
):
"""
Returns the average battery over time for each group.
...
...
@@ -1028,20 +1194,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
aggregated
=
False
,
)
train_times
=
get_train_times
(
history_groups
)
plot_batteries_over_time_with_activity
(
batteries_over_time
,
max_runtimes
,
train_times
,
save_path
=
f
"
{
project_path
}
/activity_over_time
"
,
)
plot_batteries_over_epoch_with_activity_at_time_scale
(
batteries_over_time
,
max_runtimes
,
train_times
,
save_path
=
f
"
{
project_path
}
/activity_over_time_with_epoch
"
,
)
# batteries over epoch
batteries_over_epoch
=
aggregated_battery_over_epoch
(
history_groups
,
num_intervals
=
1000
...
...
@@ -1058,13 +1210,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
save_path
=
f
"
{
project_path
}
/batteries_over_epoch
"
,
aggregated
=
False
,
)
training_times
=
get_train_times
(
history_groups
)
plot_batteries_over_epoch_with_activity_at_epoch_scale
(
batteries_over_epoch
,
training_times
=
training_times
,
save_path
=
f
"
{
project_path
}
/activity_over_epoch
"
,
)
# remaining devices
remaining_devices
=
remaining_devices_per_round
(
history_groups
)
plot_remaining_devices
(
...
...
@@ -1080,16 +1225,32 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
val_accs
=
accuracy_over_epoch
(
history_groups
,
"
val
"
)
plot_accuracies
(
val_accs
,
save_path
=
f
"
{
project_path
}
/val_accuracy
"
,
phase
=
"
val
"
)
time_train_accs
=
accuracy_over_time
(
history_groups
,
"
train
"
)
plot_accuracies_over_time
(
time_train_accs
,
save_path
=
f
"
{
project_path
}
/train_accuracy_over_time
"
,
phase
=
"
train
"
,
)
time_val_accs
=
accuracy_over_time
(
history_groups
,
"
val
"
)
plot_accuracies_over_time
(
time_val_accs
,
save_path
=
f
"
{
project_path
}
/val_accuracy_over_time
"
,
phase
=
"
val
"
)
def
generate_metric_files
(
history_groups
,
project_name
,
model_flops
,
base_path
=
"
./metrics
"
history_groups
,
project_name
,
total_model_flops
,
client_model_flops
,
base_path
=
"
./metrics
"
,
batch_size
=
64
,
):
"""
Generates metric file for the given history groups and saves them to the project_name folder.
Args:
history_groups: The runs of one project, according to the structure of the wandb project
project_name: (str) the name of the project
model_flops: (int) the total number of FLOPs of the model
total_
model_flops: (int) the total number of FLOPs of the model
base_path: (str) the base path to save the metrics to
"""
project_path
=
f
"
{
base_path
}
/
{
project_name
}
"
...
...
@@ -1103,7 +1264,9 @@ def generate_metric_files(
get_communication_overhead
(
history_groups
)
).
set_index
(
"
strategy
"
)
total_flops
=
pd
.
DataFrame
.
from_dict
(
get_total_flops
(
history_groups
,
model_flops
)
get_total_flops
(
history_groups
,
total_model_flops
,
client_model_flops
,
batch_size
)
).
set_index
(
"
strategy
"
)
df
=
pd
.
concat
([
test_acc
,
comm_overhead
,
total_flops
],
axis
=
1
)
df
.
to_csv
(
f
"
{
project_path
}
/metrics.csv
"
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment