diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ef61377d042a5fddc75a5ef7b978825897e9b373 --- /dev/null +++ b/.gitignore @@ -0,0 +1,315 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### PyCharm Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ 
+share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks + + + +# Data files. +*.csv +data + +# Temporary files. +logs +multirun +outputs +/wandb + +# Private files. 
+wandb_key.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ace604b3d220af079d4207ab1a9344dfa790061e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.2.0 + hooks: + - id: black + exclude: edml/generated + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v1.2.3 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: local + hooks: + - id: pytest-check + name: pytest-check + stages: [ pre-push ] + types: [ python ] + entry: pytest + language: system + pass_filenames: false + always_run: true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..20079283b2eecb5aa39a1e59a36866263e2de83e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Tim Bauerle, Ahmad Ayad, Sven Lechner + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index c0689cc6205a04bf371d2f84f6a243596caae8ff..e61b9064479425699e5b1645d81bff6e434d96c4 100644 --- a/README.md +++ b/README.md @@ -1,93 +1,43 @@ -# Swarm Split Learning +# SwarmSplitLearning +In this repository, we introduce our own fully distributed variant of the well-known split learning algorithm. -## Getting started - -To make it easy for you to get started with GitLab, here's a list of recommended next steps. - -Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)! 
- -## Add your files +## Installation -- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files -- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command: +For a faster installation, install the libmamba-solver first before creating the conda environment: +```bash +conda install -n base conda-libmamba-solver +conda config --set solver libmamba ``` -cd existing_repo -git remote add origin https://git.rwth-aachen.de/inda_ml/swarm-split-learning.git -git branch -M main -git push -uf origin main -``` - -## Integrate with your tools - -- [ ] [Set up project integrations](https://git.rwth-aachen.de/inda_ml/swarm-split-learning/-/settings/integrations) - -## Collaborate with your team - -- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/) -- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) -- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically) -- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/) -- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html) - -## Test and Deploy - -Use the built-in continuous integration in GitLab. - -- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html) -- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/) -- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html) -- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/) -- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html) - -*** -# Editing this README +Otherwise, the default conda solver will be used, which may take forever. -When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template. +Update the file path to your desired directory for the environment by changing the value of `prefix:` at the end of +the [environment.yml](environment.yml) file. +Then run: -## Suggestions for a good README - -Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information. - -## Name -Choose a self-explaining name for your project. - -## Description -Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. 
If there are alternatives to your project, this is a good place to list differentiating factors. - -## Badges -On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge. - -## Visuals -Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method. - -## Installation -Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection. - -## Usage -Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README. - -## Support -Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc. - -## Roadmap -If you have ideas for releases in the future, it is a good idea to list them in the README. +```bash +conda env create -f environment.yml +conda activate [ENV_NAME] +``` -## Contributing -State if you are open to contributions and what your requirements are for accepting them. +For tracking experiments using Weights and Biases, place your API key without any spaces in +the [wandb_key.txt](wandb_key.txt) file and make sure that `wandb: True` is set in +the [wandb config](edml/config/wandb.yaml). +Otherwise, metrics etc. will be printed to the console. -For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self. +If you plan to commit to this repository, please install **pre-commit** for a consistent code formatting upon committing. +Therefore, run the following command in the repo: -You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser. +```bash + pre-commit install +``` -## Authors and acknowledgment -Show your appreciation to those who have contributed to the project. +Optionally, for formatting without committing, you may run: +```bash +pre-commit run --all-files +``` -## License -For open source projects, say how it is licensed. -## Project status -If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. 
Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. diff --git a/config/controller/parallel_swarm2.yaml b/config/controller/parallel_swarm2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fcf824f56607b1f672a2f6d5d5cef8e457d5480a --- /dev/null +++ b/config/controller/parallel_swarm2.yaml @@ -0,0 +1,5 @@ +name: psl +_target_: edml.controllers.parallel_split_controller.ParallelSplitController +_partial_: true +scheduler: + _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler diff --git a/config/controller/parallel_swarm_ash_1.65.yaml b/config/controller/parallel_swarm_ash_1.65.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f5616323ef4deb46eefcf2bf4ba091b406437da --- /dev/null +++ b/config/controller/parallel_swarm_ash_1.65.yaml @@ -0,0 +1,7 @@ +_target_: edml.controllers.parallel_split_controller.ParallelSplitController +_partial_: true +scheduler: + _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler +adaptive_threshold_fn: + _target_: edml.controllers.adaptive_threshold.static.StaticAdaptiveThresholdFn + threshold: 1.65 diff --git a/config/default.yaml b/config/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ff71a77b1310a81a84df1bd7bfd040813ee1f9d --- /dev/null +++ b/config/default.yaml @@ -0,0 +1,30 @@ +# This file contains the default configuration for any experiment. The values do not matter as each experiment sweep +# Will override them. However, due to hydra limitations, the values must be present in the file. +defaults: + - dataset: mnist + - battery: flops_and_communication + - loss_fn: !!null + - experiment: default_experiment + - model_provider: mnist + - optimizer: !!null + - scheduler: !!null + - seed: default + - topology: equal_batteries + - grpc: default + - wandb: default + - _self_ + +own_device_id: "d0" +num_devices: ${len:${topology.devices}} + +# define config attributes for the group: +group_by: + - controller: [ name, scheduler: name, adaptive_threshold_fn: name ] + - model_provider: [ decoder: path ] +# group attribute determined by resolver with the given attributes +group: ${group_name:${group_by}} + +# This registers the framework-provided configuration files with hydra. +hydra: + searchpath: + - pkg://edml/config diff --git a/config/experiment/baseline.yaml b/config/experiment/baseline.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51dd14aed7bb47f4a5cbfc08b02b81cac5e85242 --- /dev/null +++ b/config/experiment/baseline.yaml @@ -0,0 +1,26 @@ +# Base properties for the experiment. +project: baseline +name: mnist +job: train + +# Training parameters. +batch_size: 64 +max_epochs: 1 +max_rounds: 20 +metrics: [ accuracy ] + +# Checkpoint saving and early stopping. +save_weights: True +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +early_stopping: True +early_stopping_patience: 5 +early_stopping_metric: accuracy + +# Dataset partitioning. +partition: True +fractions: !!null # set to !!null if dataset should not be partitioned or partitioned equally +latency: !!null # set to !!null for no latency + +# Debug. 
+load_single_batch_for_debugging: False diff --git a/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89d487cef2d286fb0139cf965c44cbb4551b05bd --- /dev/null +++ b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml @@ -0,0 +1,26 @@ +# Base properties for the experiment. +project: inda-ml-comparisons +name: cifar100-effectiveness-adaptive-threshold-mechanism-none +job: train + +# Training parameters. +batch_size: 64 +max_epochs: 1 +max_rounds: 200 +metrics: [ accuracy ] + +# Checkpoint saving and early stopping. +save_weights: True +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +early_stopping: True +early_stopping_patience: 200 +early_stopping_metric: accuracy + +# Dataset partitioning. +partition: True +fractions: !!null +latency: !!null + +# Debug. +load_single_batch_for_debugging: False diff --git a/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad7614dcf4a01d6c75776bef1f136bdea03c1010 --- /dev/null +++ b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml @@ -0,0 +1,26 @@ +# Base properties for the experiment. +project: inda-ml-comparisons +name: cifar100-effectiveness-adaptive-threshold-mechanism +job: train + +# Training parameters. +batch_size: 64 +max_epochs: 1 +max_rounds: 200 +metrics: [ accuracy ] + +# Checkpoint saving and early stopping. +save_weights: True +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +early_stopping: True +early_stopping_patience: 200 +early_stopping_metric: accuracy + +# Dataset partitioning. +partition: True +fractions: !!null +latency: !!null + +# Debug. 
+load_single_batch_for_debugging: False diff --git a/config/optimizer/adamw.yaml b/config/optimizer/adamw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ad79e3ac3f9939b0e927c42779c1c90e79de483 --- /dev/null +++ b/config/optimizer/adamw.yaml @@ -0,0 +1,2 @@ +_target_: torch.optim.AdamW +lr: 0.001 diff --git a/config/optimizer/sdg_with_momentum.yaml b/config/optimizer/sdg_with_momentum.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a32f2ff6ddbcb1777c847d651c39386711cb5252 --- /dev/null +++ b/config/optimizer/sdg_with_momentum.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.SGD +lr: 0.1 +momentum: 0.9 +weight_decay: 0.0001 diff --git a/config/scheduler/multistep.yaml b/config/scheduler/multistep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77af7c04aabbd32634c35cebc58bceb7426dc941 --- /dev/null +++ b/config/scheduler/multistep.yaml @@ -0,0 +1,3 @@ +_target_: torch.optim.lr_scheduler.MultiStepLR +milestones: [ 100, 150 ] +gamma: 0.1 diff --git a/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml b/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8f367ee18ad6235e9eaf81fc5c57393cca363f4 --- /dev/null +++ b/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - override /battery: unlimited + - override /dataset: cifar100 + - override /experiment: cifar100-effectiveness-adaptive-threshold-mechanism + - override /loss_fn: cross_entropy + - override /model_provider: resnet20 + - override /optimizer: sdg_with_momentum + - override /scheduler: multistep + - override /topology: equal_batteries + - _self_ + +hydra: + mode: MULTIRUN + sweeper: + params: + +controller: parallel_swarm #parallel_swarm_ash_1.65 + # +controller/scheduler: max_battery + # controller.adaptive_learning_threshold: 1.65 diff --git a/config/sweep/mnist/all.yaml b/config/sweep/mnist/all.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8137a13cc92426ad4c5eefb7b5ad1b04205f54e2 --- /dev/null +++ b/config/sweep/mnist/all.yaml @@ -0,0 +1,15 @@ +# @package _global_ +defaults: + - override /battery: flops_and_communication + - override /loss_fn: nll + - override /model_provider: mnist + - override /optimizer: adamw + - override /topology: equal_batteries + - _self_ + +hydra: + mode: MULTIRUN + sweeper: + params: + +controller: swarm,parallel_swarm + controller/scheduler: max_battery,sequential,rand diff --git a/config/wandb/.gitignore b/config/wandb/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1e82fc7debc189b62994e390082556b955d65abe --- /dev/null +++ b/config/wandb/.gitignore @@ -0,0 +1 @@ +*.yaml diff --git a/edml/__init__.py b/edml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/config/__init__.py b/edml/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/config/battery/flops_and_communication.yaml b/edml/config/battery/flops_and_communication.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8625e4592757e59743c1c53802e89ea1394007f --- /dev/null +++ b/edml/config/battery/flops_and_communication.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 0 +deduction_per_mflop: 0.001 
+deduction_per_mbyte_received: 0.001 +deduction_per_mbyte_sent: 0.01 diff --git a/edml/config/battery/only_flops.yaml b/edml/config/battery/only_flops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a531f6cc0df74c53787d3c47cdd8ac7a0b0cec47 --- /dev/null +++ b/edml/config/battery/only_flops.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 0 +deduction_per_mflop: 0.001 +deduction_per_mbyte_received: 0 +deduction_per_mbyte_sent: 0 diff --git a/edml/config/battery/ptbxl_tcn_cost.yaml b/edml/config/battery/ptbxl_tcn_cost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ae7b4e9447b978d53eb53b43b38c6165e97fdd3 --- /dev/null +++ b/edml/config/battery/ptbxl_tcn_cost.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 0.1 +deduction_per_mflop: 0.0001 +deduction_per_mbyte_received: 0.5 +deduction_per_mbyte_sent: 0.5 diff --git a/edml/config/battery/resnet20_cifar100_cost.yaml b/edml/config/battery/resnet20_cifar100_cost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3ed4e251535077f68cccbe6378546802e105ee5 --- /dev/null +++ b/edml/config/battery/resnet20_cifar100_cost.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 1 +deduction_per_mflop: 0.0000001 +deduction_per_mbyte_received: 0.0001 +deduction_per_mbyte_sent: 0.0001 diff --git a/edml/config/battery/time_flops_communication.yaml b/edml/config/battery/time_flops_communication.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60baa0ca8fdf145bd9d7c7476987d6cab13baf3f --- /dev/null +++ b/edml/config/battery/time_flops_communication.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 1 +deduction_per_mflop: 0.00025 +deduction_per_mbyte_received: 0.05 +deduction_per_mbyte_sent: 0.05 diff --git a/edml/config/battery/unlimited.yaml b/edml/config/battery/unlimited.yaml new file mode 100644 index 0000000000000000000000000000000000000000..40bb0e552fc474a80584cc9b6135c54ced4979ef --- /dev/null +++ b/edml/config/battery/unlimited.yaml @@ -0,0 +1,4 @@ +deduction_per_second: 0 +deduction_per_mflop: 0 +deduction_per_mbyte_received: 0 +deduction_per_mbyte_sent: 0 diff --git a/edml/config/config2.yaml b/edml/config/config2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32653ef0489b03cd1aa84b1e2fac9379ae6dc3c5 --- /dev/null +++ b/edml/config/config2.yaml @@ -0,0 +1,13 @@ +defaultsasdasd: + - dataset: mnist + - battery: flops_and_communication + - model: simple_conv +# - model_provider: resnet20 + - topology: equal_batteries + - experiment: default_experiment + - wandb + - grpc + - _self_ + +own_device_id: "d0" +num_devices: ${len:${topology.devices}} # Attribute num_devices is derived from the devices list or overridden from command line to train on a subset of devices diff --git a/edml/config/controller/README.md b/edml/config/controller/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b901e015bfd0dda3571671b2d272af1bb172d294 --- /dev/null +++ b/edml/config/controller/README.md @@ -0,0 +1,3 @@ +> Inheritance-based controller configurations. + +Allows custom next server scheduling strategies to be configured without changing code. 
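The controller configurations in `edml/config/controller/` rely on Hydra's config-group composition: `_target_` names the class to build, `_partial_: true` defers construction until the framework supplies the remaining arguments, and nested groups such as `scheduler` are instantiated recursively. The following is a minimal, self-contained sketch of how such a config is expected to resolve; the stand-in classes, the `devices=...` argument, and the omission of the `name` field (which, per `config/default.yaml`, appears to be used only for deriving the run group) are assumptions for illustration, not the framework's actual signatures.

```python
# Sketch (not part of the diff): resolving an inheritance-based controller config.
from functools import partial

from hydra.utils import instantiate
from omegaconf import OmegaConf


class SequentialNextServerScheduler:  # stand-in for edml.controllers.scheduler.sequential
    pass


class SwarmController:  # stand-in for edml.controllers.swarm_controller.SwarmController
    def __init__(self, scheduler, devices):
        self.scheduler = scheduler
        self.devices = devices


cfg = OmegaConf.create(
    {
        "_target_": "__main__.SwarmController",
        "_partial_": True,
        "scheduler": {"_target_": "__main__.SequentialNextServerScheduler"},
    }
)

# `_partial_: true` makes instantiate() return a functools.partial with the
# nested scheduler already constructed; the caller supplies the rest later.
controller_factory = instantiate(cfg)
assert isinstance(controller_factory, partial)
controller = controller_factory(devices=["d0", "d1"])  # hypothetical remaining argument
print(type(controller.scheduler).__name__)  # SequentialNextServerScheduler
```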
diff --git a/edml/config/controller/adaptive_threshold_fn/dynamic.yaml b/edml/config/controller/adaptive_threshold_fn/dynamic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d81a4797e2a99b0697029f37d8e94c27f9fa707 --- /dev/null +++ b/edml/config/controller/adaptive_threshold_fn/dynamic.yaml @@ -0,0 +1,5 @@ +_target_: edml.controllers.adaptive_threshold_mechanism.dynamic.LogarithmicDecayAdaptiveThresholdFn +name: log_decay_at +starting_value: 4 +approach_value: 1 +decay_rate: 0.05 diff --git a/edml/config/controller/adaptive_threshold_fn/static.yaml b/edml/config/controller/adaptive_threshold_fn/static.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ec21ff8dffffc281964d798c13b455ea9e43b15 --- /dev/null +++ b/edml/config/controller/adaptive_threshold_fn/static.yaml @@ -0,0 +1,3 @@ +name: static_at +_target_: edml.controllers.adaptive_threshold_mechanism.static.StaticAdaptiveThresholdFn +threshold: 1.65 diff --git a/edml/config/controller/fed.yaml b/edml/config/controller/fed.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0a0886c6aff43a87b8ca5d0616f44393a31ced3 --- /dev/null +++ b/edml/config/controller/fed.yaml @@ -0,0 +1,3 @@ +name: fed +_target_: edml.controllers.fed_controller.FedController +_partial_: true diff --git a/edml/config/controller/parallel_swarm.yaml b/edml/config/controller/parallel_swarm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be9d2a6b43df282facea48954ff6fa6ca7907fdc --- /dev/null +++ b/edml/config/controller/parallel_swarm.yaml @@ -0,0 +1,6 @@ +name: psl +_target_: edml.controllers.parallel_split_controller.ParallelSplitController +_partial_: true +defaults: + - scheduler: sequential + - adaptive_threshold_fn: !!null diff --git a/edml/config/controller/scheduler/max_battery.yaml b/edml/config/controller/scheduler/max_battery.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0a258cb82414f2adec4685c1e02cf559dd120d3 --- /dev/null +++ b/edml/config/controller/scheduler/max_battery.yaml @@ -0,0 +1,2 @@ +name: max_battery +_target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler diff --git a/edml/config/controller/scheduler/rand.yaml b/edml/config/controller/scheduler/rand.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1db64320ca7c09694b7108c176d8569e7fa88c17 --- /dev/null +++ b/edml/config/controller/scheduler/rand.yaml @@ -0,0 +1,2 @@ +name: rand +_target_: edml.controllers.scheduler.random.RandomNextServerScheduler diff --git a/edml/config/controller/scheduler/sequential.yaml b/edml/config/controller/scheduler/sequential.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f920dc24f8766f5a35b4ba2e4962f59fd041307 --- /dev/null +++ b/edml/config/controller/scheduler/sequential.yaml @@ -0,0 +1,2 @@ +name: sequential +_target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler diff --git a/edml/config/controller/scheduler/smart.yaml b/edml/config/controller/scheduler/smart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04a05edbc4c75baa121cfecdd1e9a9b400b4b12e --- /dev/null +++ b/edml/config/controller/scheduler/smart.yaml @@ -0,0 +1,4 @@ +name: smart +_target_: edml.controllers.scheduler.smart.SmartNextServerScheduler +fallback_scheduler: + _target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler diff --git a/edml/config/controller/split.yaml b/edml/config/controller/split.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..a9a5c80c2f72191504c0763fad4feab07b667cd0 --- /dev/null +++ b/edml/config/controller/split.yaml @@ -0,0 +1,3 @@ +name: split +_target_: edml.controllers.split_controller.SplitController +_partial_: true diff --git a/edml/config/controller/swarm.yaml b/edml/config/controller/swarm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f31a60a20c1e4cff7f8eec4bd9b1ff4958753ab5 --- /dev/null +++ b/edml/config/controller/swarm.yaml @@ -0,0 +1,5 @@ +name: swarm +_target_: edml.controllers.swarm_controller.SwarmController +_partial_: true +defaults: + - scheduler: sequential diff --git a/edml/config/dataset/cifar10.yaml b/edml/config/dataset/cifar10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b203a6b8036bd6446e8ad9f4f5ecbb4ea9d86bbc --- /dev/null +++ b/edml/config/dataset/cifar10.yaml @@ -0,0 +1,3 @@ +name: cifar10 +average_setting: micro +num_classes: 10 diff --git a/edml/config/dataset/cifar100.yaml b/edml/config/dataset/cifar100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0848f24d85ea89132df3b71ce79148226cbb7259 --- /dev/null +++ b/edml/config/dataset/cifar100.yaml @@ -0,0 +1,3 @@ +name: cifar100 +average_setting: micro +num_classes: 100 diff --git a/edml/config/dataset/mnist.yaml b/edml/config/dataset/mnist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bdd4b2604d70738a598b523656538e0c41925dea --- /dev/null +++ b/edml/config/dataset/mnist.yaml @@ -0,0 +1,3 @@ +name: mnist +average_setting: micro +num_classes: 10 diff --git a/edml/config/dataset/ptbxl.yaml b/edml/config/dataset/ptbxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ecd70bbb2dd9e90319e2cd4ca825531734ed097b --- /dev/null +++ b/edml/config/dataset/ptbxl.yaml @@ -0,0 +1,5 @@ +name: ptbxl +n_inputs: 12 +average_setting: micro +num_classes: 5 +distribution: iid # set to non-iid for non-iid distribution diff --git a/edml/config/experiment/cifar.yaml b/edml/config/experiment/cifar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e43040363e21fafd69686d7432cfd92d443e7ba --- /dev/null +++ b/edml/config/experiment/cifar.yaml @@ -0,0 +1,28 @@ +project: cifar100_with_resnet +name: swarm_test +job: train +loss_fn: cross_entropy +optimizer: sgd +momentum: 0.9 +weight_decay: 0.0001 +batch_size: 64 +learning_rate: 0.1 +scheduler: multistep +scheduler_milestones: [100, 150] +scheduler_gamma: 0.1 +max_epochs: 1 +max_rounds: 200 +metrics: [ accuracy ] +save_weights: True +server_model_load_path: "edml/models/weights/initial/Resnet18_Server_random_weights.pth" +client_model_load_path: "edml/models/weights/initial/Resnet18_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: !!null +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True +early_stopping_patience: 200 +early_stopping_metric: accuracy +latency: !!null diff --git a/edml/config/experiment/default_experiment.yaml b/edml/config/experiment/default_experiment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33e9648a601111be9351e0f011e5873568ed7dde --- /dev/null +++ b/edml/config/experiment/default_experiment.yaml @@ -0,0 +1,24 @@ +project: baseline +name: swarm_test +job: train +loss_fn: nll +batch_size: 64 +optimizer: adamw +learning_rate: 0.001 +max_epochs: 1 +max_rounds: 20 +metrics: [ accuracy ] +load_weights: True +save_weights: True 
+server_model_load_path: "edml/models/weights/initial/MNIST_Server_random_weights.pth" +client_model_load_path: "edml/models/weights/initial/MNIST_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: [ 0.1, 0.1, 0.1, 0.1, 0.1 ] # set to !!null if dataset should not be partitioned or partitioned equally +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True +early_stopping_patience: 5 +early_stopping_metric: accuracy +latency: [ 0.0, 1.0, 0.0, 0.0, 0.0 ] # set to !!null for no latency diff --git a/edml/config/experiment/ecg.yaml b/edml/config/experiment/ecg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6c559e81afb6fab973ec9d07258ff5aa0fc0bd7 --- /dev/null +++ b/edml/config/experiment/ecg.yaml @@ -0,0 +1,24 @@ +project: ecg +name: swarm_smart +job: train +loss_fn: bce +batch_size: 64 +optimizer: adamw +learning_rate: 0.001 +max_epochs: 1 +max_rounds: 100 +metrics: [ accuracy ] +load_weights: False +save_weights: True +server_model_load_path: "edml/models/weights/initial/TCN_Server_random_weights.pth" +client_model_load_path: "edml/models/weights/initial/TCN_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: [ 0.1, 0.1, 0.1, 0.1, 0.6 ] +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True +early_stopping_patience: 100 +early_stopping_metric: accuracy +latency: [ 5, 5, 5, 5, 0 ] diff --git a/edml/config/experiment/parallel_split_vs_swarm.yaml b/edml/config/experiment/parallel_split_vs_swarm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71606e984c9a413c9af7e6f5a674364938ec3485 --- /dev/null +++ b/edml/config/experiment/parallel_split_vs_swarm.yaml @@ -0,0 +1,29 @@ +project: testing_resnet_memory3 +name: tt3 +job: train +loss_fn: cross_entropy +optimizer: sgd +momentum: 0.9 +weight_decay: 0.0001 +batch_size: 64 +learning_rate: 0.1 +scheduler: multistep +scheduler_milestones: [ 100, 150 ] +scheduler_gamma: 0.1 +max_epochs: 1 +max_rounds: 50 +metrics: [ accuracy ] +load_weights: True +save_weights: True +server_model_load_path: "edml/models/weights/initial/Resnet18_Server_random_weights.pth" +client_model_load_path: "edml/models/weights/initial/Resnet18_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: !!null #[ 0.125, 0.125, 0.125, 0.125, 0.5 ] +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True # use only for saving best model +early_stopping_patience: 200 +early_stopping_metric: accuracy +latency: !!null #[ 5, 5, 5, 5, 0 ] diff --git a/edml/config/experiment/resnet20_cifar100.yaml b/edml/config/experiment/resnet20_cifar100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7935fbb0470c0dc8cc6ae70404c6a44f4c7d9787 --- /dev/null +++ b/edml/config/experiment/resnet20_cifar100.yaml @@ -0,0 +1,29 @@ +project: cifar100_with_resnet20 +name: swarm_test +job: train +loss_fn: cross_entropy +optimizer: sgd +momentum: 0.9 +weight_decay: 0.0001 +batch_size: 64 +learning_rate: 0.1 +scheduler: multistep +scheduler_milestones: [ 100, 150 ] +scheduler_gamma: 0.1 +max_epochs: 1 +max_rounds: 200 +metrics: [ accuracy ] +load_weights: False +save_weights: True +server_model_load_path: "edml/models/weights/Resnet18_Server_random_weights.pth" +client_model_load_path: 
"edml/models/weights/Resnet18_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: [ 0.125, 0.125, 0.125, 0.125, 0.5 ] +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True # use only for saving best model +early_stopping_patience: 200 +early_stopping_metric: accuracy +latency: [ 5, 5, 5, 5, 0 ] diff --git a/edml/config/experiment/resnet_fed_vs_split.yaml b/edml/config/experiment/resnet_fed_vs_split.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bfbee778c31c3d0d53daf973161bfb8d9d09c42a --- /dev/null +++ b/edml/config/experiment/resnet_fed_vs_split.yaml @@ -0,0 +1,29 @@ +project: resnet_with_many_devices +name: swarm_test +job: train +loss_fn: cross_entropy +optimizer: sgd +momentum: 0.9 +weight_decay: 0.0001 +batch_size: 64 +learning_rate: 0.1 +scheduler: multistep +scheduler_milestones: [ 100, 150 ] +scheduler_gamma: 0.1 +max_epochs: 1 +max_rounds: 200 +metrics: [ accuracy ] +load_weights: False +save_weights: True +server_model_load_path: "edml/models/weights/initial/Resnet_Server_random_weights.pth" +client_model_load_path: "edml/models/weights/initial/Resnet_Client_random_weights.pth" +server_model_save_path: "edml/models/weights/" +client_model_save_path: "edml/models/weights/" +partition: True +fractions: !!null +random_seed: 42 +load_single_batch_for_debugging: False +early_stopping: True # use only for saving best model +early_stopping_patience: 200 +early_stopping_metric: accuracy +latency: !!null diff --git a/edml/config/grpc/default.yaml b/edml/config/grpc/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de05fb8730616b2cfd25569e9e984a91e2798a33 --- /dev/null +++ b/edml/config/grpc/default.yaml @@ -0,0 +1,2 @@ +max_message_length: -1 # unlimited, (same for send and receive) +max_threads: 10 diff --git a/edml/config/loss_fn/bce.yaml b/edml/config/loss_fn/bce.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76a88f11c5e9d24aada781821c06b8ec900c1856 --- /dev/null +++ b/edml/config/loss_fn/bce.yaml @@ -0,0 +1 @@ +_target_: torch.nn.BCELoss diff --git a/edml/config/loss_fn/cross_entropy.yaml b/edml/config/loss_fn/cross_entropy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b44f215526e7d8f8596a59756dd1097196e3dfd7 --- /dev/null +++ b/edml/config/loss_fn/cross_entropy.yaml @@ -0,0 +1 @@ +_target_: torch.nn.CrossEntropyLoss diff --git a/edml/config/loss_fn/nll.yaml b/edml/config/loss_fn/nll.yaml new file mode 100644 index 0000000000000000000000000000000000000000..18a39db606b1ece99b252c60bfa594cc967551d3 --- /dev/null +++ b/edml/config/loss_fn/nll.yaml @@ -0,0 +1,2 @@ +_target_: hydra.utils.get_method +path: torch.nn.functional.nll_loss diff --git a/edml/config/model/resnet110.yaml b/edml/config/model/resnet110.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4475f3016ba43542f1fffc59329ac4951b2e318 --- /dev/null +++ b/edml/config/model/resnet110.yaml @@ -0,0 +1,2 @@ +name: resnet110 +cut_layer: 4 diff --git a/edml/config/model/resnet1202.yaml b/edml/config/model/resnet1202.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a159e1dea9fa07d265890de48f0f1c03987c270 --- /dev/null +++ b/edml/config/model/resnet1202.yaml @@ -0,0 +1,2 @@ +name: resnet1202 +cut_layer: 4 diff --git a/edml/config/model/resnet20.yaml b/edml/config/model/resnet20.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..a9e0e12fbc507ee0527f1fd8630c0dfca4e6344e --- /dev/null +++ b/edml/config/model/resnet20.yaml @@ -0,0 +1,2 @@ +name: resnet20 +cut_layer: 4 diff --git a/edml/config/model/resnet32.yaml b/edml/config/model/resnet32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d9ea97fb8aadd5d96e6a002c56ab3d9a5559353 --- /dev/null +++ b/edml/config/model/resnet32.yaml @@ -0,0 +1,2 @@ +name: resnet32 +cut_layer: 4 diff --git a/edml/config/model/resnet44.yaml b/edml/config/model/resnet44.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9946919b8edadf9f6eaf9edd5ad4c0616ad042e --- /dev/null +++ b/edml/config/model/resnet44.yaml @@ -0,0 +1,2 @@ +name: resnet44 +cut_layer: 4 diff --git a/edml/config/model/resnet56.yaml b/edml/config/model/resnet56.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4441a607eea90a93b074e816ddaa0c23ce75e079 --- /dev/null +++ b/edml/config/model/resnet56.yaml @@ -0,0 +1,2 @@ +name: resnet56 +cut_layer: 4 diff --git a/edml/config/model/simple_conv.yaml b/edml/config/model/simple_conv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a743bebb8e8a95c0f70b102e968ba88b066ca0df --- /dev/null +++ b/edml/config/model/simple_conv.yaml @@ -0,0 +1 @@ +name: simple_conv diff --git a/edml/config/model/tcn.yaml b/edml/config/model/tcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9d03e3bcc91e5b6070a9922e1517f84c33062e78 --- /dev/null +++ b/edml/config/model/tcn.yaml @@ -0,0 +1 @@ +name: tcn diff --git a/edml/config/model_provider/mnist.yaml b/edml/config/model_provider/mnist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67ca5345ff4e472787c15f5d98478898caabcc3c --- /dev/null +++ b/edml/config/model_provider/mnist.yaml @@ -0,0 +1,5 @@ +_target_: edml.models.provider.base.ModelProvider +client: + _target_: edml.models.mnist_models.ClientNet +server: + _target_: edml.models.mnist_models.ServerNet diff --git a/edml/config/model_provider/resnet110-with-autoencoder.yaml b/edml/config/model_provider/resnet110-with-autoencoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..318dc714baa29b050696b73cb4f94ca3c3079542 --- /dev/null +++ b/edml/config/model_provider/resnet110-with-autoencoder.yaml @@ -0,0 +1,22 @@ +_target_: edml.models.provider.autoencoder.AutoencoderModelProvider +model_provider: + # TODO: can this include other files next to it? 
+ _target_: edml.models.provider.cut_layer.CutLayerModelProvider + model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 18, 18, 18 ] + num_classes: 100 + cut_layer: 4 +decoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Decoder + path: resnet_decoder.pth +encoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Encoder + path: resnet_encoder.pth diff --git a/edml/config/model_provider/resnet110.yaml b/edml/config/model_provider/resnet110.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a436b97dcbe01ccf3dc9ae1b50b4ebea084ad338 --- /dev/null +++ b/edml/config/model_provider/resnet110.yaml @@ -0,0 +1,9 @@ +_target_: edml.models.provider.cut_layer.CutLayerModelProvider +model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 18, 18, 18 ] + num_classes: 100 +cut_layer: 4 diff --git a/edml/config/model_provider/resnet20-with-autoencoder.yaml b/edml/config/model_provider/resnet20-with-autoencoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a56b77ed33320a14851f91edaa40286e8b9c32aa --- /dev/null +++ b/edml/config/model_provider/resnet20-with-autoencoder.yaml @@ -0,0 +1,22 @@ +_target_: edml.models.provider.autoencoder.AutoencoderModelProvider +model_provider: + # TODO: can this include other files next to it? + _target_: edml.models.provider.cut_layer.CutLayerModelProvider + model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 3, 3, 3 ] + num_classes: 100 + cut_layer: 4 +decoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Decoder + path: resnet_decoder.pth +encoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Encoder + path: resnet_encoder.pth diff --git a/edml/config/model_provider/resnet20.yaml b/edml/config/model_provider/resnet20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..144ba60783f58c22e791f84c67725e779502a648 --- /dev/null +++ b/edml/config/model_provider/resnet20.yaml @@ -0,0 +1,9 @@ +_target_: edml.models.provider.cut_layer.CutLayerModelProvider +model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 3, 3, 3 ] + num_classes: 100 +cut_layer: 4 diff --git a/edml/config/model_provider/tcn.yaml b/edml/config/model_provider/tcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7dcafcd88ce7028fc97fe54b97f08c226162d726 --- /dev/null +++ b/edml/config/model_provider/tcn.yaml @@ -0,0 +1,7 @@ +_target_: edml.models.provider.base.ModelProvider +client: + _target_: edml.models.tcn_models.Small_TCN_5_Client + n_inputs: ${dataset.n_inputs} # TODO: think about these configurations and where to put them. +server: + _target_: edml.models.tcn_models.Small_TCN_5_Server + n_inputs: ${dataset.n_inputs} # TODO: think about these configurations and where to put them. 
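The `${dataset.n_inputs}` entries in `model_provider/tcn.yaml` are OmegaConf interpolations that resolve against whichever dataset config is composed into the run (e.g. `dataset/ptbxl.yaml`, which sets `n_inputs: 12`). A small sketch of the expected resolution, with the config tree built by hand purely for illustration:

```python
# Sketch (not part of the diff): how the ${dataset.n_inputs} interpolation resolves
# once the dataset and model_provider groups live in the same config tree.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "dataset": {"name": "ptbxl", "n_inputs": 12, "num_classes": 5},
        "model_provider": {
            "client": {
                "_target_": "edml.models.tcn_models.Small_TCN_5_Client",
                "n_inputs": "${dataset.n_inputs}",
            }
        },
    }
)

# Interpolations are resolved lazily on access, so the client model receives
# the channel count of the selected dataset without duplicating the value.
print(cfg.model_provider.client.n_inputs)  # 12
```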
diff --git a/edml/config/models/resnet110.yaml b/edml/config/models/resnet110.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85f3b6b77aa95f0db3963838c6302854cc305b17 --- /dev/null +++ b/edml/config/models/resnet110.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [18, 18, 18] +num_classes: 100 diff --git a/edml/config/models/resnet1202.yaml b/edml/config/models/resnet1202.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e31218e42c72dfe24b95f4bf45c0e142f14b4449 --- /dev/null +++ b/edml/config/models/resnet1202.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [200, 200, 200] +num_classes: 100 diff --git a/edml/config/models/resnet20.yaml b/edml/config/models/resnet20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab5d28bb89aef794c6cea5fb57030f57f2e2d863 --- /dev/null +++ b/edml/config/models/resnet20.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [3, 3, 3] +num_classes: 100 diff --git a/edml/config/models/resnet32.yaml b/edml/config/models/resnet32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1550e92a828a718c44f6a2fbc14ba60240baf2f --- /dev/null +++ b/edml/config/models/resnet32.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [5, 5, 5] +num_classes: 100 diff --git a/edml/config/models/resnet44.yaml b/edml/config/models/resnet44.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aa880d5bbe2e34981b830affd0a0dc45f34c68ce --- /dev/null +++ b/edml/config/models/resnet44.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [7, 7, 7] +num_classes: 100 diff --git a/edml/config/models/resnet56.yaml b/edml/config/models/resnet56.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9763b494ab116d1fb435c2a9f7d0192e0aa5905 --- /dev/null +++ b/edml/config/models/resnet56.yaml @@ -0,0 +1,6 @@ +_target_: edml.models.resnet_models.ResNet +block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock +num_blocks: [9, 9, 9] +num_classes: 100 diff --git a/edml/config/models/tcn.yaml b/edml/config/models/tcn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/config/seed/default.yaml b/edml/config/seed/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d96a47d6fc03263e11a92912f0fe8b7f0a4d01a9 --- /dev/null +++ b/edml/config/seed/default.yaml @@ -0,0 +1,2 @@ +value: 1 +torch_deterministic: True diff --git a/edml/config/topology/50_devices.yaml b/edml/config/topology/50_devices.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8cfbab79de6275389a0a9b35a56b67e8c8782c0 --- /dev/null +++ b/edml/config/topology/50_devices.yaml @@ -0,0 +1,302 @@ +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 
1000000, + torch_device: cuda:1 + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d5", + address: "localhost:50056", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d6", + address: "localhost:50057", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d7", + address: "localhost:50058", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d8", + address: "localhost:50059", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d9", + address: "localhost:50060", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d10", + address: "localhost:50061", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d11", + address: "localhost:50062", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d12", + address: "localhost:50063", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d13", + address: "localhost:50064", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d14", + address: "localhost:50065", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d15", + address: "localhost:50066", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d16", + address: "localhost:50067", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d17", + address: "localhost:50068", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d18", + address: "localhost:50069", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d19", + address: "localhost:50070", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d20", + address: "localhost:50071", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d21", + address: "localhost:50072", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d22", + address: "localhost:50073", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d23", + address: "localhost:50074", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d24", + address: "localhost:50075", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d25", + address: "localhost:50076", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d26", + address: "localhost:50077", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d27", + address: "localhost:50078", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d28", + address: "localhost:50079", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d29", + address: "localhost:50080", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d30", + address: "localhost:50081", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d31", + address: "localhost:50082", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d32", + address: "localhost:50083", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d33", + address: "localhost:50084", + battery_capacity: 1000000, + 
torch_device: cuda:0 + }, + { + device_id: "d34", + address: "localhost:50085", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d35", + address: "localhost:50086", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d36", + address: "localhost:50087", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d37", + address: "localhost:50088", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d38", + address: "localhost:50089", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d39", + address: "localhost:50090", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d40", + address: "localhost:50091", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d41", + address: "localhost:50092", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d42", + address: "localhost:50093", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d43", + address: "localhost:50094", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d44", + address: "localhost:50095", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d45", + address: "localhost:50096", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d46", + address: "localhost:50097", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d47", + address: "localhost:50098", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d48", + address: "localhost:50099", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d49", + address: "localhost:50100", + battery_capacity: 1000000, + torch_device: cuda:1 + }, +] diff --git a/edml/config/topology/equal_batteries.yaml b/edml/config/topology/equal_batteries.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a43e5d82cf8112e07ad834229aa8312cedfae46 --- /dev/null +++ b/edml/config/topology/equal_batteries.yaml @@ -0,0 +1,28 @@ +num_clients: 2 # Deprecated +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 45000, + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 45000, + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 45000, + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 45000, + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 45000, + } +] diff --git a/edml/config/topology/resnet20_cifar100_batteries.yaml b/edml/config/topology/resnet20_cifar100_batteries.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b1026da83e668e9bfb41e3c1b148e23866a532f --- /dev/null +++ b/edml/config/topology/resnet20_cifar100_batteries.yaml @@ -0,0 +1,27 @@ +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 5500 + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 5500 + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 5500 + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 5500 + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 5500 + } +] diff --git a/edml/config/topology/resnet_fed_vs_split_batteries.yaml b/edml/config/topology/resnet_fed_vs_split_batteries.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f83c4736305ffdc7e8bad0f54ab61d66b553592 --- 
/dev/null +++ b/edml/config/topology/resnet_fed_vs_split_batteries.yaml @@ -0,0 +1,252 @@ +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 1000000 + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 1000000 + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 1000000 + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 1000000 + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 1000000 + }, + { + device_id: "d5", + address: "localhost:50056", + battery_capacity: 1000000 + }, + { + device_id: "d6", + address: "localhost:50057", + battery_capacity: 1000000 + }, + { + device_id: "d7", + address: "localhost:50058", + battery_capacity: 1000000 + }, + { + device_id: "d8", + address: "localhost:50059", + battery_capacity: 1000000 + }, + { + device_id: "d9", + address: "localhost:50060", + battery_capacity: 1000000 + }, + { + device_id: "d10", + address: "localhost:50061", + battery_capacity: 1000000 + }, + { + device_id: "d11", + address: "localhost:50062", + battery_capacity: 1000000 + }, + { + device_id: "d12", + address: "localhost:50063", + battery_capacity: 1000000 + }, + { + device_id: "d13", + address: "localhost:50064", + battery_capacity: 1000000 + }, + { + device_id: "d14", + address: "localhost:50065", + battery_capacity: 1000000 + }, + { + device_id: "d15", + address: "localhost:50066", + battery_capacity: 1000000 + }, + { + device_id: "d16", + address: "localhost:50067", + battery_capacity: 1000000 + }, + { + device_id: "d17", + address: "localhost:50068", + battery_capacity: 1000000 + }, + { + device_id: "d18", + address: "localhost:50069", + battery_capacity: 1000000 + }, + { + device_id: "d19", + address: "localhost:50070", + battery_capacity: 1000000 + }, + { + device_id: "d20", + address: "localhost:50071", + battery_capacity: 1000000 + }, + { + device_id: "d21", + address: "localhost:50072", + battery_capacity: 1000000 + }, + { + device_id: "d22", + address: "localhost:50073", + battery_capacity: 1000000 + }, + { + device_id: "d23", + address: "localhost:50074", + battery_capacity: 1000000 + }, + { + device_id: "d24", + address: "localhost:50075", + battery_capacity: 1000000 + }, + { + device_id: "d25", + address: "localhost:50076", + battery_capacity: 1000000 + }, + { + device_id: "d26", + address: "localhost:50077", + battery_capacity: 1000000 + }, + { + device_id: "d27", + address: "localhost:50078", + battery_capacity: 1000000 + }, + { + device_id: "d28", + address: "localhost:50079", + battery_capacity: 1000000 + }, + { + device_id: "d29", + address: "localhost:50080", + battery_capacity: 1000000 + }, + { + device_id: "d30", + address: "localhost:50081", + battery_capacity: 1000000 + }, + { + device_id: "d31", + address: "localhost:50082", + battery_capacity: 1000000 + }, + { + device_id: "d32", + address: "localhost:50083", + battery_capacity: 1000000 + }, + { + device_id: "d33", + address: "localhost:50084", + battery_capacity: 1000000 + }, + { + device_id: "d34", + address: "localhost:50085", + battery_capacity: 1000000 + }, + { + device_id: "d35", + address: "localhost:50086", + battery_capacity: 1000000 + }, + { + device_id: "d36", + address: "localhost:50087", + battery_capacity: 1000000 + }, + { + device_id: "d37", + address: "localhost:50088", + battery_capacity: 1000000 + }, + { + device_id: "d38", + address: "localhost:50089", + battery_capacity: 1000000 + }, + { + device_id: "d39", + address: "localhost:50090", + 
battery_capacity: 1000000 + }, + { + device_id: "d40", + address: "localhost:50091", + battery_capacity: 1000000 + }, + { + device_id: "d41", + address: "localhost:50092", + battery_capacity: 1000000 + }, + { + device_id: "d42", + address: "localhost:50093", + battery_capacity: 1000000 + }, + { + device_id: "d43", + address: "localhost:50094", + battery_capacity: 1000000 + }, + { + device_id: "d44", + address: "localhost:50095", + battery_capacity: 1000000 + }, + { + device_id: "d45", + address: "localhost:50096", + battery_capacity: 1000000 + }, + { + device_id: "d46", + address: "localhost:50097", + battery_capacity: 1000000 + }, + { + device_id: "d47", + address: "localhost:50098", + battery_capacity: 1000000 + }, + { + device_id: "d48", + address: "localhost:50099", + battery_capacity: 1000000 + }, + { + device_id: "d49", + address: "localhost:50100", + battery_capacity: 1000000 + } +] diff --git a/edml/config/topology/unequal_batteries.yaml b/edml/config/topology/unequal_batteries.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e0802c096bf7f639c14c33e1903fbbce56b7a29 --- /dev/null +++ b/edml/config/topology/unequal_batteries.yaml @@ -0,0 +1,28 @@ +num_clients: 2 # Deprecated +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 5750 + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 4750 + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 3750 + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 2750 + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 1750 + } +] diff --git a/edml/config/wandb/default.yaml b/edml/config/wandb/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f38d0c9032f161a7f6b69381d1eaf5a83be5c51 --- /dev/null +++ b/edml/config/wandb/default.yaml @@ -0,0 +1,3 @@ +enabled: False +key_path: wandb_key.txt +entity: !!null diff --git a/edml/controllers/__init__.py b/edml/controllers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/controllers/adaptive_threshold_mechanism/__init__.py b/edml/controllers/adaptive_threshold_mechanism/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e85555cd7127d35fa7be4ffb2d0e1ce9c1517fd --- /dev/null +++ b/edml/controllers/adaptive_threshold_mechanism/__init__.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + + +class AdaptiveThresholdFn(ABC): + """A function that returns the adaptive threshold value based on the current round.""" + + @abstractmethod + def invoke(self, round_no: int) -> float: + """Return the adaptive threshold value for the given round number.""" diff --git a/edml/controllers/adaptive_threshold_mechanism/dynamic.py b/edml/controllers/adaptive_threshold_mechanism/dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..b73e3e1cfdb82a9183d7ad7a4afdecb6dfe2c91d --- /dev/null +++ b/edml/controllers/adaptive_threshold_mechanism/dynamic.py @@ -0,0 +1,19 @@ +import numpy as np + +from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn + + +class LogarithmicDecayAdaptiveThresholdFn(AdaptiveThresholdFn): + + def __init__( + self, starting_value: float, approach_value: float, decay_rate: float = 1.0 + ): + super().__init__() + self._start = starting_value + self._end = approach_value + self._decay_rate = decay_rate + + def invoke(self, round_no: int): + return 
self._end + (self._start - self._end) * np.exp( + -self._decay_rate * round_no + ) diff --git a/edml/controllers/adaptive_threshold_mechanism/static.py b/edml/controllers/adaptive_threshold_mechanism/static.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6e78e12d9fc7238b231f591faccf5b8c6b1562 --- /dev/null +++ b/edml/controllers/adaptive_threshold_mechanism/static.py @@ -0,0 +1,10 @@ +from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn + + +class StaticAdaptiveThresholdFn(AdaptiveThresholdFn): + def __init__(self, threshold: float): + super().__init__() + self._threshold = threshold + + def invoke(self, round_no: int) -> float: + return self._threshold diff --git a/edml/controllers/base_controller.py b/edml/controllers/base_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..a2adca85736844b23f40a7c8c5c1a5c4574ef280 --- /dev/null +++ b/edml/controllers/base_controller.py @@ -0,0 +1,203 @@ +import abc +import concurrent.futures +from typing import Optional + +import torch + +from edml.controllers.early_stopping import create_early_stopping_callback +from edml.core.device import DeviceRequestDispatcher +from edml.core.start_device import _get_models +from edml.helpers.logging import SimpleLogger, create_logger +from edml.helpers.metrics import ModelMetricResultContainer +from edml.helpers.types import DeviceBatteryStatus + + +class BaseController(abc.ABC): + """Base class for all controllers. Provides the basic functionality for training on multiple devices. + Specific training can be implemented by overriding the _train method.""" + + def __init__(self, cfg): + self.cfg = cfg + self.cfg.own_device_id = ( + "controller" # set own device id here since needed by the logger + ) + self.logger: SimpleLogger = create_logger(cfg) + self.devices = cfg.topology.devices[: cfg.num_devices] + self.request_dispatcher = DeviceRequestDispatcher(self.devices) + self.active_devices = self.request_dispatcher.active_devices() + self.device_batteries_status = {} + self.early_stopping = create_early_stopping_callback(cfg) + + def train(self): + """Runs the training algorithm implemented by the controller.""" + print("Starting training") + if ( + "load_weights" in self.cfg.experiment + and not self.cfg.experiment.load_weights + ): + self.logger.log( + "WARNING: loading models via experiment configurations is deprecated. Please use a " + "dedicated model provider instead." 
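# Illustrative usage sketch (not part of this patch): comparing the two
# AdaptiveThresholdFn implementations defined above. Despite its name,
# LogarithmicDecayAdaptiveThresholdFn decays exponentially from starting_value
# towards approach_value. All numbers below are invented.
from edml.controllers.adaptive_threshold_mechanism.dynamic import (
    LogarithmicDecayAdaptiveThresholdFn,
)
from edml.controllers.adaptive_threshold_mechanism.static import (
    StaticAdaptiveThresholdFn,
)

decaying = LogarithmicDecayAdaptiveThresholdFn(
    starting_value=1.0, approach_value=0.1, decay_rate=0.5
)
constant = StaticAdaptiveThresholdFn(threshold=0.25)

for round_no in range(5):
    # Round 0 yields 1.0; later rounds approach 0.1, while the static value stays 0.25.
    print(round_no, round(decaying.invoke(round_no), 4), constant.invoke(round_no))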
+ ) + + # if no weights are loaded, initialize the models randomly and set them on all devices + client_model, server_model = _get_models(self.cfg) + self._set_weights_on_all_devices( + client_model.state_dict(), on_client=True, wait_for_ready=True + ) + self._set_weights_on_all_devices( + server_model.state_dict(), on_client=False, wait_for_ready=True + ) + self._start_experiment_on_active_devices() + self.logger.start_experiment() # start controller logging to wandb + self._refresh_active_devices() + print(f"Active devices: {self.active_devices}") + self._train() + self._end_experiment_on_active_devices() + self.logger.end_experiment() + + @abc.abstractmethod + def _train(self): + """Hook for the training algorithm of the specific controller.""" + pass + + def _start_experiment_on_active_devices(self): + """Starts the experiment on all given devices in parallel.""" + print(f"Starting experiment on {self.active_devices}") + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(len(self.active_devices), 1) + ) as executor: # avoid exception when setting 0 workers + for device_id in self.active_devices: + executor.submit( + self.request_dispatcher.start_experiment_on, device_id, True + ) + + def _end_experiment_on_active_devices(self): + """Ends the experiment on all given devices.""" + for device_id in self.active_devices: + self.request_dispatcher.end_experiment_on(device_id) + + def _get_battery_status(self): + """Returns the battery status of all active devices. Inactive devices receive None as battery status.""" + battery_status: dict[str, Optional[DeviceBatteryStatus]] = {} + for device_id in self.active_devices: + status = self.request_dispatcher.get_battery_status_on(device_id) + if status is not False: + battery_status[device_id] = status + else: + battery_status[device_id] = None + return battery_status + + def _update_devices_battery_status(self): + """Updates the battery status of all active devices. + Inactive devices receive None as battery status. Also refreshes the list of active devices. + """ + self.device_batteries_status = { + device_id: None for device_id in self.device_batteries_status + } # reset to None to detect inactive devices + self.device_batteries_status = ( + self.device_batteries_status | self._get_battery_status() + ) # update with new values + self._refresh_active_devices() + + def _refresh_active_devices(self): + """Refreshes the active devices by checking if the request dispatcher is still connected to them. + Device failure is only detected after the next request to the device.""" + self.active_devices = self.request_dispatcher.active_devices() + + def _get_device_ids(self) -> list[str]: + return [device.device_id for device in self.devices] + + def _save_weights(self, client_weights, server_weights, round_no: int): + """ + Saves the weights of the given round if saving weights is configured. + + Args: + client_weights: The weights of the client model. + server_weights: The weights of the server model. + round_no: The number of the current round. + + Returns: + None + + Raises: + None + + Notes: + If early stopping is configured, the weights are only saved if the current round is the best round so far. 
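# Minimal sketch (not part of this patch) of the merge performed by
# _update_devices_battery_status above: statuses are first reset to None, then the
# dict union keeps None for every device that did not report back this round.
previous = {"d0": None, "d1": None, "d2": None}   # reset state before polling
fresh = {"d0": 1200.0, "d1": 950.0}               # invented values; "d2" did not answer
merged = previous | fresh
print(merged)  # {'d0': 1200.0, 'd1': 950.0, 'd2': None}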
+ """ + if self.cfg.experiment.save_weights: + print("\n###SAVING WEIGHTS###") + if ( + self.cfg.experiment.early_stopping + and self.early_stopping.best_epoch != round_no + ): + print(" triggered early stopping") + return + if client_weights: + print(" saving client weights...") + torch.save( + client_weights, + f"{self.cfg.experiment.client_model_save_path}{self.__model_prefix__()}_client_{round_no}.pth", + ) + if server_weights: + print(" saving server weights...") + torch.save( + server_weights, + f"{self.cfg.experiment.server_model_save_path}{self.__model_prefix__()}_server_{round_no}.pth", + ) + print("\n") + + def _load_weights(self, round_no: int): + """Loads the weights from the configured directory if loading weights is configured.""" + return ( + torch.load( + f"{self.cfg.experiment.client_model_save_path}{self.__model_prefix__()}_client_{round_no}.pth" + ), + torch.load( + f"{self.cfg.experiment.server_model_save_path}{self.__model_prefix__()}_server_{round_no}.pth" + ), + ) + + def _set_weights_on_all_devices( + self, weights, on_client=True, wait_for_ready=False + ): + """Sets the weights on all devices.""" + for device_id in self.active_devices: + self.request_dispatcher.set_weights_on( + device_id, weights, on_client, wait_for_ready=wait_for_ready + ) + + def _devices_empty_or_only_server_left(self, server_device_id): + if len(self.active_devices) == 0: + return True + elif ( + len(self.active_devices) == 1 and self.active_devices[0] == server_device_id + ): + return True + return False + + def _aggregate_and_log_metrics( + self, metrics: Optional[ModelMetricResultContainer], round_no: int + ): + """ + Aggregates and logs the metrics of the current round. + + Args: + metrics (ModelMetricResultContainer): The metrics of the current epoch. + round_no (int): The number of the current round. + Returns: + None + Raises: + None + Notes: + None + """ + if metrics is not None: + aggregated_metrics = metrics.get_aggregated_metrics() + for metric_result in aggregated_metrics.get_as_list(): + self.logger.log(metric_result.as_loggable_dict(round_no)) + + def __model_prefix__(self): + """Returns the model prefix for the current experiment.""" + return f"{self.cfg.experiment.project}_{self.cfg.group}" diff --git a/edml/controllers/early_stopping.py b/edml/controllers/early_stopping.py new file mode 100644 index 0000000000000000000000000000000000000000..058c1d5b4b9f6aa3578c19bcc3dafab77eead215 --- /dev/null +++ b/edml/controllers/early_stopping.py @@ -0,0 +1,97 @@ +import warnings + +from edml.helpers.metrics import ModelMetricResultContainer + + +class EarlyStopping(object): + def __init__(self, metric: str, patience=10, phase="val"): + """ + Implementation of early stopping for the training loop. + + Args: + metric (string): The metric to be monitored. + patience (int): The number of epochs of no improvement to stop after. + phase (string): The phase to monitor the metric on. + + Returns: + None + + Raises: + ValueError: If the given metric is None. + + Notes: + Only works if the metric of interest is present in the given metrics when called later on. + Uses the metric values obtained during validation by default. 
+ """ + if metric is None: + raise ValueError("Early stopping metric must not be None") + self.metric = metric + self.patience = patience + self.phase = phase + self.best_score = None + self.best_epoch = None + self.counter = 0 + self.early_stop = False + + def __call__(self, metrics: ModelMetricResultContainer, epoch): + """ + Implementation of early stopping for the training loop. + + Args: + metrics (ModelMetricResultContainer): The metrics of the current epoch. + epoch (int): The number the current epoch. + + Returns: + bool: Whether to stop training. + + Raises: + Warning if the metric of interest is not present in the given metrics. + Notes: + Requires the metric of interest (in the defined phase) to be present in the given metrics. + Otherwise, a warning is raised and the round is not counted. + """ + try: + score = ( + metrics.get_aggregated_metrics() + .results[(self.metric, self.phase)][0] + .value + ) + except KeyError: + warnings.warn( + f"Early stopping metric {self.metric} not found in given metrics. Skipping this round." + ) + return False + if self.best_score is None: + self.best_score = score + self.best_epoch = epoch + elif score <= self.best_score: + self.counter += 1 + if self.counter >= self.patience: + return True + else: + self.best_score = score + self.best_epoch = epoch + self.counter = 0 + return False + + +def create_early_stopping_callback(cfg): + """ + A factory function to create an early stopping callback. + + Args: + cfg (OmegaConf): The configuration object. + + Returns: + EarlyStopping: The early stopping callback. + Raises: + None + Notes: + If early stopping is deactivated, a callback that always returns false is returned. + """ + if cfg.experiment.early_stopping: + return EarlyStopping( + metric=cfg.experiment.early_stopping_metric, + patience=cfg.experiment.early_stopping_patience, + ) + return lambda *args: False diff --git a/edml/controllers/fed_controller.py b/edml/controllers/fed_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf6b6e86d2df4d59f5ea025288fb5e28d3d4efc --- /dev/null +++ b/edml/controllers/fed_controller.py @@ -0,0 +1,113 @@ +import concurrent.futures +import threading +from typing import Dict, List + +from overrides import override + +from edml.controllers.base_controller import BaseController +from edml.helpers.metrics import ModelMetricResultContainer + + +def _drop_empty_weights(weights: list[Dict]) -> list[Dict]: + """Drops all weights that are empty.""" + return [weight for weight in weights if weight is not None and len(weight) > 0] + + +def fed_average(model_weights: list[Dict], weighting_scheme: List[float] = None): + """Computes the federated average of the given weights. + If a weighting scheme (percentage or number of examples per model) is specified, the weighted average is computed. 
+ """ + averaged_weights = {} + model_weights = _drop_empty_weights(model_weights) + if len(model_weights) > 0: + if weighting_scheme is None: + for key in model_weights[0].keys(): + averaged_weights[key] = sum( + weight[key] for weight in model_weights + ) / len(model_weights) + else: + for key in model_weights[0].keys(): + averaged_weights[key] = sum( + weight[key] * weighting_scheme[i] + for i, weight in enumerate(model_weights) + ) / sum(weighting_scheme) + return averaged_weights + return None + + +class FedController(BaseController): + """Controller for federated learning.""" + + def __init__(self, cfg): + super().__init__(cfg) + + def _fed_train_round(self, round_no: int = -1): + """Returns new client and server weights.""" + client_weights_lock = threading.Lock() + server_weights_lock = threading.Lock() + samples_count_lock = threading.Lock() + metrics_lock = threading.Lock() + client_weights = [] + server_weights = [] + samples_count = [] + metrics_container = ModelMetricResultContainer() + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(len(self.active_devices), 1) + ) as executor: # avoid exception when setting 0 workers + futures = [ + executor.submit( + self.request_dispatcher.federated_train_on, device_id, round_no + ) + for device_id in self.active_devices + ] + for future in concurrent.futures.as_completed(futures): + response = future.result() + if response is not False: + new_client_weights, new_server_weights, num_samples, metrics, _ = ( + response # skip diagnostic metrics + ) + with client_weights_lock: + client_weights.append(new_client_weights) + with server_weights_lock: + server_weights.append(new_server_weights) + with samples_count_lock: + samples_count.append(num_samples) + with metrics_lock: + metrics_container.merge(metrics) + + print(f"samples count {samples_count}") + + return ( + fed_average(model_weights=client_weights, weighting_scheme=samples_count), + fed_average(model_weights=server_weights, weighting_scheme=samples_count), + metrics_container, + ) + + @override + def _train(self): + """Runs the federated training algorithm for the configured number of rounds.""" + print("Training with FedController") + for i in range(self.cfg.experiment.max_rounds): + print(f"Round {i}") + self._refresh_active_devices() + + avg_client_weights, avg_server_weights, metrics = self._fed_train_round( + round_no=i + ) + + self.logger.log( + {"remaining_devices": {"devices": len(self.active_devices), "round": i}} + ) + + self._aggregate_and_log_metrics(metrics, i) + + early_stop = self.early_stopping(metrics, i) + if early_stop: + break + + self._save_weights(avg_client_weights, avg_server_weights, i) + self._refresh_active_devices() + + # set new weights + self._set_weights_on_all_devices(avg_client_weights, on_client=True) + self._set_weights_on_all_devices(avg_server_weights, on_client=False) diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..82bf519a06dc456d79ba25368914c376236275c9 --- /dev/null +++ b/edml/controllers/parallel_split_controller.py @@ -0,0 +1,93 @@ +from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn +from edml.controllers.adaptive_threshold_mechanism.static import ( + StaticAdaptiveThresholdFn, +) +from edml.controllers.base_controller import BaseController +from edml.controllers.scheduler.base import NextServerScheduler +from edml.helpers.config_helpers import get_device_index_by_id + + 
+class ParallelSplitController(BaseController): + def __init__( + self, + cfg, + scheduler: NextServerScheduler, + adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0), + ): + super().__init__(cfg) + scheduler.initialize(self) + self._next_server_scheduler = scheduler + self._adaptive_threshold_fn = adaptive_threshold_fn + + def _train(self): + client_weights = None + server_weights = None + server_device_id = self.cfg.topology.devices[0].device_id + optimizer_state = None + + for i in range(self.cfg.experiment.max_rounds): + print(f"=================================Round {i}") + + # We fetch the newest device information to check and see what active devices are still available. + # After that, we can also update the next server device if applicable. + self._update_devices_battery_status() + # break if no active devices or only server device left + if self._devices_empty_or_only_server_left(server_device_id): + print("No active client devices left.") + break + if self._next_server_scheduler: + server_device_id = self._next_server() + print(f"<> training on server: {server_device_id} <>") + + # set latest server weights once we did a single round of training. + if server_weights is not None: + print(f">>> Propagating newest server weights to {server_device_id}") + self.request_dispatcher.set_weights_on( + device_id=server_device_id, + state_dict=server_weights, + on_client=False, + ) + + # Start parallel training of all client devices. + adaptive_threshold = self._adaptive_threshold_fn.invoke(i) + self.logger.log({"adaptive-threshold": adaptive_threshold}) + training_response = self.request_dispatcher.train_parallel_on_server( + server_device_id=server_device_id, + epochs=1, + round_no=i, + adaptive_learning_threshold=adaptive_threshold, + optimizer_state=optimizer_state, + ) + + self._refresh_active_devices() + self.logger.log( + {"remaining_devices": {"devices": len(self.active_devices), "round": i}} + ) + self.logger.log( + { + "server_device": { + "device": get_device_index_by_id(self.cfg, server_device_id) + }, + "round": i, + } + ) # log the server device index for convenience + + if training_response is False: # server device unavailable + print(f"Training response was false.") + break + else: + cw, server_weights, metrics, optimizer_state, _ = training_response + + self._aggregate_and_log_metrics(metrics, i) + + early_stop = self.early_stopping(metrics, i) + if early_stop: + print(f"Early stopping triggered.") + break + + self._save_weights( + client_weights=cw, server_weights=server_weights, round_no=i + ) + + def _next_server(self) -> str: + return self._next_server_scheduler.next_server(self.active_devices) diff --git a/edml/controllers/scheduler/__init__.py b/edml/controllers/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/controllers/scheduler/base.py b/edml/controllers/scheduler/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0b95b1d7c6d8097b270ba14533e03bb524a27a --- /dev/null +++ b/edml/controllers/scheduler/base.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Sequence + +if TYPE_CHECKING: + from edml.controllers.base_controller import BaseController + + +class NextServerScheduler(ABC): + """ + This class can be used to implement a custom scheduler algorithm. Controller implementations can then use them to + determine the next server to be used for the round. 
+ + Attributes: + devices: A list of device IDs that can be used for scheduler. + """ + + def __init__(self): + self.devices: Optional[list[str]] = None + + def initialize(self, controller: "BaseController"): + """ + Initializes the scheduler with the list of devices that can be used for scheduler. + + Args: + controller: The controller that is using this scheduler. The instance should not be used outside of this + method. + + Notes: + This method should not be overridden. Instead, the `_initialize` method should be overridden. + """ + self.devices = controller._get_device_ids() + self._initialize(controller) + + def _initialize(self, controller: "BaseController"): + """Custom hook for implementations to initialize themselves.""" + + def next_server(self, active_devices: Sequence[str], **kwargs) -> Optional[str]: + """ + Returns the next active server device. This method only calls the `_next_server` method after verifying the + list of active devices is not empty. + For using multiple schedulers interchangeably, pass all required args for all schedulers as kwargs. + + Args: + active_devices: The list of currently active device IDs. This list should not be empty. + last_server_device_id: Optional kwarg, the id of the last server device. + diagnostic_metric_container: Optional kwarg, a DiagnosticMetricResultContainer. + + Returns: + The next server device ID. In the case that no device is available, `None` is returned. + + Raises: + ValueError: If the list of active devices is empty. + """ + if len(active_devices) == 0: + raise ValueError("list cannot be empty") + return self._next_server(active_devices, **kwargs) + + @abstractmethod + def _next_server(self, active_devices: Sequence[str], **kwargs) -> Optional[str]: + """ + Custom hook to return the next active device. Concrete implementations that need more arguments, + should define these as additional keyword args, so the base signature does not have to change. + However, when relying on polymorphism, the next_server template should receive all keyword args for the + different schedulers. Also, the concrete implementations should define **kwargs in the signature, + so it does not have to change when new kwargs are added. 
+ """ + raise NotImplementedError() diff --git a/edml/controllers/scheduler/dispatcher.py b/edml/controllers/scheduler/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9d70ad3f30614943916ad769f1c81f2c51ae2c --- /dev/null +++ b/edml/controllers/scheduler/dispatcher.py @@ -0,0 +1,16 @@ +from typing import Optional, Sequence + +from edml.controllers.scheduler.base import NextServerScheduler +from edml.helpers.types import KeyedNextServerScheduler + + +class DispatchingNextServerScheduler(NextServerScheduler): + def __init__( + self, schedulers: Sequence[KeyedNextServerScheduler], active_scheduler: str + ): + super().__init__() + self._schedulers = {scheduler.KEY: scheduler for scheduler in schedulers} + self.active_scheduler = active_scheduler + + def _next_server(self, active_devices: list[str]) -> Optional[str]: + return self._schedulers[self.active_scheduler].next_server(active_devices) diff --git a/edml/controllers/scheduler/max_battery.py b/edml/controllers/scheduler/max_battery.py new file mode 100644 index 0000000000000000000000000000000000000000..afe19ff3a1512812bb1fbcee4bc17c5281764bc6 --- /dev/null +++ b/edml/controllers/scheduler/max_battery.py @@ -0,0 +1,40 @@ +from typing import Optional, Callable, Sequence + +from edml.controllers.base_controller import BaseController +from edml.controllers.scheduler.base import NextServerScheduler +from edml.helpers.types import DeviceBatteryStatus + +UpdateBatteryStatusCallback = Callable[[], dict[str, Optional[DeviceBatteryStatus]]] + + +class MaxBatteryNextServerScheduler(NextServerScheduler): + """ + This scheduler selects the server with the highest battery level. + """ + + KEY: str = "max_battery" + + def __init__(self): + super().__init__() + + # We require some kind of callback that allows us to fetch the newest battery levels for all active devices. + # This is set during initialization. + self._update_batteries_cb: Optional[UpdateBatteryStatusCallback] = None + + def _initialize(self, controller: BaseController): + self._update_batteries_cb = controller._get_battery_status + + def _next_server(self, _: Sequence[str], **kwargs) -> Optional[str]: + battery_status = self._get_active_devices_battery_levels() + if len(battery_status) == 0: + return None + # return the device_id with the highest battery level + return max(battery_status, key=lambda device_id: battery_status[device_id]) + + def _get_active_devices_battery_levels(self) -> dict[str, float]: + """Returns the current battery levels of active devices only.""" + return { + key: battery_info.current_capacity + for key, battery_info in self._update_batteries_cb().items() + if battery_info is not None + } diff --git a/edml/controllers/scheduler/random.py b/edml/controllers/scheduler/random.py new file mode 100644 index 0000000000000000000000000000000000000000..2638827fca941b1a1b9ffbe596a6b0fee2522aef --- /dev/null +++ b/edml/controllers/scheduler/random.py @@ -0,0 +1,16 @@ +from random import choice +from typing import Sequence + +from edml.controllers.scheduler.base import NextServerScheduler + + +class RandomNextServerScheduler(NextServerScheduler): + """ + A scheduler that always returns a random server ID from the list of active devices. To seed the random number, use + `random.seed()` before calling this scheduler. 
+ """ + + KEY: str = "random" + + def _next_server(self, active_devices: Sequence[str], **kwargs) -> str: + return choice(active_devices) diff --git a/edml/controllers/scheduler/sequential.py b/edml/controllers/scheduler/sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c4759327af83a704e5e0b38e893d0871932bfd --- /dev/null +++ b/edml/controllers/scheduler/sequential.py @@ -0,0 +1,48 @@ +from typing import Optional, Sequence + +from edml.controllers.scheduler.base import NextServerScheduler + + +class SequentialNextServerScheduler(NextServerScheduler): + """ + A scheduler that always returns the next server by incrementing a wrapping index into the list of active + devices. + + Args: + last_server_device_id (optional): The device ID of the last server device used. If not provided, the first + server device will be the first active device in the list. + """ + + KEY: str = "sequential" + + def __init__(self, last_server_device_id: Optional[str] = None): + super().__init__() + self._last_server_device_id = last_server_device_id + + def _next_server( + self, + active_devices: Sequence[str], + last_server_device_id: Optional[str] = None, + **kwargs + ) -> str: + # Special case if we do not have an initial first server. We simply return the first server ID in the list. + if self._last_server_device_id is None: + next_server = active_devices[0] + self._last_server_device_id = next_server + return next_server + + # Find the index of the last server in the list of all devices. Create a temporary list of devices with all the + # devices after the last server and all the devices before the last server. + # The next server is the first device in the temporary list that is also in the list of active devices. + last_server_index = self.devices.index(self._last_server_device_id) + remaining_devices = ( + self.devices[last_server_index + 1 :] + + self.devices[: last_server_index + 1] + ) + next_server = next( + device_id for device_id in remaining_devices if device_id in active_devices + ) + + # Update the last server ID for the next call. + self._last_server_device_id = next_server + return next_server diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py new file mode 100644 index 0000000000000000000000000000000000000000..48940331247f647a855d72bf3e8a54b44f13ba78 --- /dev/null +++ b/edml/controllers/scheduler/smart.py @@ -0,0 +1,119 @@ +from typing import Optional, Sequence + +from edml.controllers.scheduler.base import NextServerScheduler +from edml.controllers.strategy_optimization import ( + GlobalParams, + ServerChoiceOptimizer, + DeviceParams, +) +from edml.helpers.metrics import ( + DiagnosticMetricResultContainer, + compute_metrics_for_optimization, +) + + +class SmartNextServerScheduler(NextServerScheduler): + """ + This scheduler optimizes the server device selection so that the number of rounds with all devices participating is maximized. + Therefore, training metrics are needed, so in the first round, the device is chosen by the specified fallback scheduler. + Afterward, the optimal selection schedule is computed. If all devices have been picked according to the schedule, + i.e. one device will run out of battery in the next round, again the server device is chosen by the fallback. 
+ """ + + KEY: str = "smart" + + def __init__(self, fallback_scheduler: NextServerScheduler): + super().__init__() + self.fallback_scheduler = fallback_scheduler + self.selection_schedule = None + + def _initialize(self, controller: "BaseController"): + self.cfg = controller.cfg + self._update_batteries_cb = controller._get_battery_status + self._data_model_cb = ( + controller._get_active_devices_dataset_sizes_and_model_flops + ) + self.fallback_scheduler.initialize(controller=controller) + + def _next_server( + self, + active_devices: Sequence[str], + last_server_device_id=None, + diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None, + **kwargs, + ) -> str: + if diagnostic_metric_container is None: + return self.fallback_scheduler.next_server(active_devices) + else: + if self.selection_schedule is None or len(self.selection_schedule) == 0: + self.selection_schedule = self._get_selection_schedule( + diagnostic_metric_container, last_server_device_id + ) + print(f"server device schedule: {self.selection_schedule}") + try: + return self.selection_schedule.pop(0) + except IndexError: # no more devices left in schedule + return self.fallback_scheduler.next_server( + active_devices, + diagnostic_metric_container=diagnostic_metric_container, + kwargs=kwargs, + ) + + def _get_selection_schedule( + self, diagnostic_metric_container, last_server_device_id=None + ): + device_params_list = [] + device_battery_levels = self._update_batteries_cb() + # get num samples and flops per device + dataset_sizes, model_flops = self._data_model_cb() + try: + optimization_metrics = compute_metrics_for_optimization( + diagnostic_metric_container, + dataset_sizes, + self.cfg.experiment.batch_size, + ) + except ( + KeyError + ): # if some metrics are not available e.g. 
because a device ran out of battery + return [] # return empty schedule + for device_id, battery_level in device_battery_levels.items(): + device_params = DeviceParams( + device_id=device_id, + initial_battery=battery_level.initial_capacity, + current_battery=battery_level.current_capacity, + train_samples=dataset_sizes[device_id][0], + validation_samples=dataset_sizes[device_id][1], + comp_latency_factor=optimization_metrics["comp_latency_factor"][ + device_id + ], + ) + device_params_list.append(device_params) + global_params = GlobalParams() + global_params.fill_values_from_config(self.cfg) + global_params.client_model_flops = model_flops["client"] + global_params.server_model_flops = model_flops["server"] + global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"] + global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"] + global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"] + global_params.server_norm_bw_time = optimization_metrics["server_norm_bw_time"] + global_params.gradient_size = optimization_metrics["gradient_size"] + global_params.label_size = optimization_metrics["label_size"] + global_params.smashed_data_size = optimization_metrics["smashed_data_size"] + global_params.client_weights_size = optimization_metrics["client_weight_size"] + global_params.server_weights_size = optimization_metrics["server_weight_size"] + global_params.optimizer_state_size = optimization_metrics[ + "optimizer_state_size" + ] + if "train_global_time" in optimization_metrics.keys(): + global_params.train_global_time = optimization_metrics["train_global_time"] + global_params.last_server_device_id = last_server_device_id + print(f"global params: {vars(global_params)}") + print(f"device params: {[vars(device) for device in device_params_list]}") + server_choice_optimizer = ServerChoiceOptimizer( + device_params_list, global_params + ) + solution, status = server_choice_optimizer.optimize() + schedule = [] + for device_id in solution.keys(): + schedule += [device_id] * int(solution[device_id]) + return schedule diff --git a/edml/controllers/split_controller.py b/edml/controllers/split_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..309f6d95f586417261c01b185ef0f033e4ac8a46 --- /dev/null +++ b/edml/controllers/split_controller.py @@ -0,0 +1,47 @@ +from edml.controllers.base_controller import BaseController + + +class SplitController(BaseController): + + def _train(self): + client_weights = None + server_device = self.cfg.topology.devices[0] + for i in range(self.cfg.experiment.max_rounds): + print(f"Round {i}") + + self._update_devices_battery_status() + # break if no active devices or only server device left + if self._devices_empty_or_only_server_left(server_device.device_id): + print("No active client devices left.") + break + + # set latest client weights on first device to train on + self.request_dispatcher.set_weights_on( + device_id=self.active_devices[0], + state_dict=client_weights, + on_client=True, + ) + + training_response = self.request_dispatcher.train_global_on( + server_device.device_id, epochs=1, round_no=i + ) + + self._refresh_active_devices() + self.logger.log( + {"remaining_devices": {"devices": len(self.active_devices), "round": i}} + ) + + if training_response is False: # server device unavailable + break + else: + client_weights, server_weights, metrics, _, _ = ( + training_response # no need for optimizer state and diagnostic metrics + ) + + 
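# Small illustration (not part of this patch) of how _get_selection_schedule above
# expands the optimizer solution (rounds per server device) into a flat schedule.
solution = {"d0": 2.0, "d1": 0.0, "d2": 1.0}   # invented solver output
schedule = []
for device_id in solution.keys():
    schedule += [device_id] * int(solution[device_id])
print(schedule)  # ['d0', 'd0', 'd2']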
self._aggregate_and_log_metrics(metrics, i) + + early_stop = self.early_stopping(metrics, i) + if early_stop: + break + + self._save_weights(client_weights, server_weights, i) diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..eddbf5742a526285cf6a3793b208900beda9f3ec --- /dev/null +++ b/edml/controllers/strategy_optimization.py @@ -0,0 +1,721 @@ +import math + +import numpy as np +from ortools.linear_solver import pywraplp + + +class DeviceParams: + def __init__( + self, + device_id, + initial_battery, + current_battery, + train_samples, + validation_samples, + comp_latency_factor, + ): + self.device_id = device_id + self.initial_battery = initial_battery + self.current_battery = current_battery + self.train_samples = train_samples + self.validation_samples = validation_samples + self.comp_latency_factor = comp_latency_factor + + +class GlobalParams: + def __init__( + self, + cost_per_sec=None, + cost_per_byte_sent=None, + cost_per_byte_received=None, + cost_per_flop=None, + client_model_flops=None, + server_model_flops=None, + smashed_data_size=None, + label_size=None, + gradient_size=None, + batch_size=None, + client_weights_size=None, + server_weights_size=None, + optimizer_state_size=None, + client_norm_fw_time=None, + client_norm_bw_time=None, + server_norm_fw_time=None, + server_norm_bw_time=None, + train_global_time=None, + last_server_device_id=None, + ): + self.cost_per_sec = cost_per_sec + self.cost_per_byte_sent = cost_per_byte_sent + self.cost_per_byte_received = cost_per_byte_received + self.cost_per_flop = cost_per_flop + self.client_model_flops = client_model_flops + self.server_model_flops = server_model_flops + self.smashed_data_size = smashed_data_size + self.batch_size = batch_size + self.client_weights_size = client_weights_size + self.server_weights_size = server_weights_size + self.optimizer_state_size = optimizer_state_size + self.train_global_time = train_global_time + self.last_server_device_id = last_server_device_id + # metrics per sample + self.label_size = label_size + self.gradient_size = gradient_size + self.client_norm_fw_time = client_norm_fw_time + self.client_norm_bw_time = client_norm_bw_time + self.server_norm_fw_time = server_norm_fw_time + self.server_norm_bw_time = server_norm_bw_time + + def fill_values_from_config(self, cfg): + self.cost_per_sec = cfg.battery.deduction_per_second + self.cost_per_byte_sent = cfg.battery.deduction_per_mbyte_sent / 1000000 + self.cost_per_byte_received = cfg.battery.deduction_per_mbyte_received / 1000000 + self.cost_per_flop = cfg.battery.deduction_per_mflop / 1000000 + self.batch_size = cfg.experiment.batch_size + + def all_values_set(self): + return ( + self.cost_per_sec is not None + and self.cost_per_byte_sent is not None + and self.cost_per_byte_received is not None + and self.cost_per_flop is not None + and self.client_model_flops is not None + and self.server_model_flops is not None + and self.optimizer_state_size is not None + and self.smashed_data_size is not None + and self.label_size is not None + and self.gradient_size is not None + and self.batch_size is not None + and self.client_norm_fw_time is not None + and self.client_norm_bw_time is not None + and self.server_norm_fw_time is not None + and self.server_norm_bw_time is not None + ) + + +class ServerChoiceOptimizer: + + def __init__(self, device_params_list, global_params): + self.device_params_list = device_params_list + 
self.global_params = global_params + + # derived properties + def _num_devices(self): + return len(self.device_params_list) + + def _total_battery(self): + return sum(device.current_battery for device in self.device_params_list) + + def _total_train_dataset_size(self): + return sum(device.train_samples for device in self.device_params_list) + + def _total_validation_dataset_size(self): + return sum(device.validation_samples for device in self.device_params_list) + + def _get_device_params(self, device_id): + for device in self.device_params_list: + if device.device_id == device_id: + return device + return None + + def _transmission_latency(self): + if ( + self.global_params.train_global_time is not None + and self.global_params.last_server_device_id is not None + ): + return ( + self.global_params.train_global_time + - self._round_runtime_with_server_no_latency( + self.global_params.last_server_device_id + ) + ) + return 0 # latency not known + + def _round_runtime_with_server_no_latency(self, server_device_id): + """ + Computes the runtime of a round with the given server device. + Params: + server_device_id: the id of the server device + Returns: + the runtime of a round with the given server device + Notes: + Does not consider any transmission latency and may underestimate the runtime + """ + total_time = 0 + for device in self.device_params_list: + if device.device_id == server_device_id: + # server processes all data with its own speed on server model (fw + bw) + total_time += device.comp_latency_factor * ( + ( + self._total_train_dataset_size() + + self._total_validation_dataset_size() + ) + * self.global_params.server_norm_fw_time + + self._total_train_dataset_size() + * self.global_params.server_norm_bw_time + ) # backprop only on train data + # client time (for every device including server device) + total_time += ( + (device.train_samples + device.validation_samples) + * self.global_params.client_norm_fw_time + + device.train_samples * self.global_params.client_norm_bw_time + ) * device.comp_latency_factor + return total_time + + def round_runtime_with_server(self, server_device_id): + """ + Computes the runtime of a round with the given server device. 
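# Back-of-the-envelope illustration (not part of this patch) of the transmission
# latency estimate in _transmission_latency above: the measured global round time
# minus the modelled compute-only runtime of the previous server choice. Numbers invented.
measured_round_time = 42.0    # seconds, train_global_time reported for the last round
modelled_compute_time = 36.5  # _round_runtime_with_server_no_latency(last_server_device_id)
transmission_latency = measured_round_time - modelled_compute_time
print(transmission_latency)   # 5.5 seconds attributed to communication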
+ Params: + server_device_id: the id of the server device + Returns: + the runtime of a round with the given server device + Notes: + """ + return ( + self._round_runtime_with_server_no_latency(server_device_id) + + self._transmission_latency() + ) + + def num_flops_per_round_on_device(self, device_id, server_device_id): + device = self._get_device_params(device_id) + total_flops = 0 + if device_id == server_device_id: + total_flops += ( + self.global_params.server_model_flops + * self._total_train_dataset_size() + * 3 + ) # fw + 2bw + total_flops += ( + self.global_params.server_model_flops + * self._total_validation_dataset_size() + ) # fw + total_flops += ( + self.global_params.client_model_flops * device.train_samples * 3 + ) # fw + 2bw + total_flops += ( + self.global_params.client_model_flops * device.validation_samples + ) # fw + return total_flops + + def num_bytes_sent_per_round_on_device(self, device_id, server_device_id): + device = self._get_device_params(device_id) + total_bytes = 0 + if device_id == server_device_id: + total_bytes += ( + self.global_params.gradient_size * self._total_train_dataset_size() + ) + total_bytes += self.global_params.client_weights_size * ( + self._num_devices() - 1 + ) # setting client weights before training + # return server weights and optimizer state in the end + total_bytes += self.global_params.server_weights_size + total_bytes += self.global_params.optimizer_state_size + # exclude server device's own gradients + total_bytes -= self.global_params.gradient_size * device.train_samples + else: + total_bytes += self.global_params.label_size * ( + device.train_samples + device.validation_samples + ) + total_bytes += self.global_params.smashed_data_size * ( + device.train_samples + device.validation_samples + ) + total_bytes += ( + self.global_params.client_weights_size + ) # clients return weights to server + return total_bytes + + def num_bytes_received_per_round_on_device(self, device_id, server_device_id): + device = self._get_device_params(device_id) + total_bytes = 0 + if device_id == server_device_id: + total_bytes += self.global_params.label_size * ( + self._total_train_dataset_size() + self._total_validation_dataset_size() + ) + total_bytes += self.global_params.smashed_data_size * ( + self._total_train_dataset_size() + self._total_validation_dataset_size() + ) + total_bytes += self.global_params.client_weights_size * ( + self._num_devices() - 1 + ) # clients return their weights to server + # server weights and optimizer state set once in the beginning + total_bytes += self.global_params.server_weights_size + total_bytes += self.global_params.optimizer_state_size + # exclude server device's own data and labels + total_bytes -= self.global_params.label_size * ( + device.train_samples + device.validation_samples + ) + total_bytes -= self.global_params.smashed_data_size * ( + device.train_samples + device.validation_samples + ) + else: + total_bytes += self.global_params.gradient_size * device.train_samples + total_bytes += ( + self.global_params.client_weights_size + ) # client weights set in the beginning of the training + return total_bytes + + def energy_per_round_on_device(self, device_id, server_device_id): + total_energy = 0 + total_energy += ( + self.num_flops_per_round_on_device(device_id, server_device_id) + * self.global_params.cost_per_flop + ) + total_energy += ( + self.num_bytes_sent_per_round_on_device(device_id, server_device_id) + * self.global_params.cost_per_byte_sent + ) + total_energy += ( + 
self.num_bytes_received_per_round_on_device(device_id, server_device_id) + * self.global_params.cost_per_byte_received + ) + total_energy += ( + self.round_runtime_with_server(server_device_id) + * self.global_params.cost_per_sec + ) + return total_energy + + def max_rounds_upper_bound(self): + max_rounds = [] + for server_device in self.device_params_list: + max_rounds_with_server = [] + for device in self.device_params_list: + max_rounds_with_server += [ + math.floor( + device.current_battery + / self.energy_per_round_on_device( + device.device_id, server_device.device_id + ) + ) + ] + max_rounds.append(max(max_rounds_with_server)) + return max(max_rounds) + + def optimize(self): + solver = pywraplp.Solver.CreateSolver("SAT") + # variables + # rounds_as_server[i] = number of rounds device i is server + rounds_as_server = {} + for device in self.device_params_list: + rounds_as_server[device.device_id] = solver.IntVar( + 0, self.max_rounds_upper_bound(), f"rounds_{device.device_id}" + ) + # constraints + energy = np.array( + [ + self.energy_per_round_on_device( + device.device_id, server_device.device_id + ) + for device in self.device_params_list + for server_device in self.device_params_list + ] + ) + # energy[i][j] = energy per round for device i with server j + energy = energy.reshape( + (len(self.device_params_list), len(self.device_params_list)) + ) + # for each device, sum up the energy over all rounds with each device's frequency being the server device + for i in range(len(self.device_params_list)): + solver.Add( + solver.Sum( + [ + rounds_as_server[self.device_params_list[j].device_id] + * energy[i][j] + for j in range(len(self.device_params_list)) + ] + ) + <= self.device_params_list[i].current_battery + ) + # objective: maximize the total number of rounds, i.e. the sum of times each device is server + solver.Maximize(solver.Sum(rounds_as_server.values())) + + print( + solver.ExportModelAsLpFormat(False).replace("\\", "").replace(",_", ","), + sep="\n", + ) + status = solver.Solve() + + return { + device.device_id: rounds_as_server[device.device_id].solution_value() + for device in self.device_params_list + }, status + + +class EnergySimulator: + + def __init__(self, device_params_list, global_params): + """ + Simulator class to simulate the different server choice algorithms. Allows to simulate the algorithms for different + scenarios instead of running it on the actual devices. This is useful for testing and debugging but requires actual values though. + Args: + device_params_list: list of DeviceParams objects + global_params: GlobalParams object + """ + self.device_params_list = device_params_list + self.global_params = global_params + self.server_choice_optimizer = ServerChoiceOptimizer( + device_params_list, global_params + ) + + def simulate_greedy_selection(self): + """ + Simulates the greedy server choice algorithm. 
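# Self-contained toy version (not part of this patch) of the integer program solved
# in ServerChoiceOptimizer.optimize above: pick how many rounds each of two devices
# acts as server so that no battery budget is exceeded and the total number of
# rounds is maximised. All energy and battery numbers are invented.
from ortools.linear_solver import pywraplp

solver = pywraplp.Solver.CreateSolver("SAT")
r0 = solver.IntVar(0, 100, "rounds_with_d0_as_server")
r1 = solver.IntVar(0, 100, "rounds_with_d1_as_server")
# energy[i][j] = energy spent by device i in one round when device j is the server
energy = [[5.0, 3.0], [2.0, 6.0]]
batteries = [60.0, 60.0]
for i in range(2):
    solver.Add(r0 * energy[i][0] + r1 * energy[i][1] <= batteries[i])
solver.Maximize(r0 + r1)
status = solver.Solve()
# With these numbers the maximum is 14 rounds in total (e.g. 7 rounds each).
print(status == pywraplp.Solver.OPTIMAL, r0.solution_value(), r1.solution_value())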
+ Returns: + num_rounds: number of rounds until the first device runs out of battery + server_selection_schedule: list of server device ids for each round + device_batteries: list of battery levels for each device after the last successful round + """ + + def __get_device_with_max_battery__(device_battery_list): + return max( + range(len(device_battery_list)), key=device_battery_list.__getitem__ + ) + + def __all_devices_alive__(device_battery_list): + return all(battery > 0 for battery in device_battery_list) + + energy = np.array( + [ + self.server_choice_optimizer.energy_per_round_on_device( + device.device_id, server_device.device_id + ) + for device in self.device_params_list + for server_device in self.device_params_list + ] + ) + # energy[i][j] = energy per round for device i with server j + energy = energy.reshape( + (len(self.device_params_list), len(self.device_params_list)) + ) + + device_batteries = [ + device.current_battery for device in self.device_params_list + ] + all_devices_alive = True + server_selection_schedule = [] + num_rounds = 0 + while all_devices_alive: + server_device_idx = __get_device_with_max_battery__(device_batteries) + new_batteries = device_batteries.copy() + for idx, device in enumerate(self.device_params_list): + new_batteries[idx] = new_batteries[idx] - energy[idx][server_device_idx] + all_devices_alive = __all_devices_alive__(new_batteries) + if all_devices_alive: + device_batteries = new_batteries + num_rounds += 1 + server_selection_schedule.append( + self.device_params_list[server_device_idx].device_id + ) + else: + break + return num_rounds, server_selection_schedule, device_batteries + + def simulate_smart_selection(self): + """ + Simulates the smart server choice algorithm. + Returns: + num_rounds: number of rounds until the first device runs out of battery + solution: the solution computed by the optimizer + device_batteries: list of battery levels for each device after in the end + """ + solution, status = self.server_choice_optimizer.optimize() + return sum(solution.values()), solution, self._remaining_batteries(solution) + + def _remaining_batteries(self, solution): + device_batteries = [ + device.current_battery for device in self.device_params_list + ] + for server_device_id, rounds in solution.items(): + for idx, device in enumerate(self.device_params_list): + device_batteries[idx] -= ( + self.server_choice_optimizer.energy_per_round_on_device( + device.device_id, server_device_id + ) + * rounds + ) + return device_batteries + + def _fl_round_time(self): + train_times = [] + for device in self.device_params_list: + model_bw_time = ( + self.global_params.client_norm_bw_time + + self.global_params.server_norm_bw_time + ) + model_fw_time = ( + self.global_params.client_norm_fw_time + + self.global_params.server_norm_fw_time + ) + total_time = ( + (device.train_samples + device.validation_samples) * model_fw_time + + device.train_samples * model_bw_time + ) * device.comp_latency_factor + train_times.append(total_time) + return max(train_times) + + def _fl_flops_on_device(self, device_id): + device = self._get_device_params(device_id) + total_flops = 0 + total_flops += self.global_params.client_model_flops * device.train_samples * 3 + total_flops += self.global_params.client_model_flops * device.validation_samples + total_flops += self.global_params.server_model_flops * device.train_samples * 3 + total_flops += self.global_params.server_model_flops * device.validation_samples + return total_flops + + def _fl_data_sent_per_device(self): + 
return ( + self.global_params.server_weights_size + + self.global_params.client_weights_size + ) + + def _fl_data_received_per_device(self): + return ( + self.global_params.server_weights_size + + self.global_params.client_weights_size + ) + + def _fl_energy_per_round_on_device(self, device_id): + total_energy = 0 + total_energy += ( + self._fl_flops_on_device(device_id) * self.global_params.cost_per_flop + ) + total_energy += ( + self._fl_data_sent_per_device() * self.global_params.cost_per_byte_sent + ) + total_energy += ( + self._fl_data_received_per_device() + * self.global_params.cost_per_byte_received + ) + total_energy += self._fl_round_time() * self.global_params.cost_per_sec + return total_energy + + def simulate_federated_learning(self): + """ + Simulates the federated learning algorithm. + Returns: + num_rounds: number of rounds until the first device runs out of battery + device_batteries: list of battery levels for each device after the last successful round + """ + + def __all_devices_alive__(device_battery_list): + return all(battery > 0 for battery in device_battery_list) + + device_batteries = [ + device.current_battery for device in self.device_params_list + ] + all_devices_alive = True + energy = [ + self._fl_energy_per_round_on_device(device.device_id) + for device in self.device_params_list + ] + num_rounds = 0 + while all_devices_alive: + new_batteries = device_batteries.copy() + for idx, device in enumerate(self.device_params_list): + new_batteries[idx] = new_batteries[idx] - energy[idx] + all_devices_alive = __all_devices_alive__(new_batteries) + if all_devices_alive: + num_rounds += 1 + device_batteries = new_batteries + return num_rounds, device_batteries + + def _get_device_params(self, device_id): + for device in self.device_params_list: + if device.device_id == device_id: + return device + return None + + +def run_grid_search( + device_params_list, + global_params, + batteries=None, + latencies=None, + partitions=None, + cost_per_sec=None, + cost_per_byte_sent=None, + cost_per_byte_received=None, + cost_per_flop=None, +): + """ + Runs a grid search for the given device parameters and global parameters. + Params: + device_params_list: list of DeviceParams objects + global_params: GlobalParams object + **kwargs: optional parameters for the grid search: should be a list of lists for device params and a list for costs + If e.g. cost_per_second is provided, the grid search will be run for all values in the list overriding existing values in the global params object. + If no cost_per_second is provided, the grid search will use the value from the global params object. 
+ Returns: + list of dicts containing the results for each combination of parameters + """ + if batteries is None: + batteries = [[device.current_battery for device in device_params_list]] + if latencies is None: + latencies = [[device.comp_latency_factor for device in device_params_list]] + total_train_samples = sum(device.train_samples for device in device_params_list) + total_val_samples = sum(device.validation_samples for device in device_params_list) + if partitions is None: + partitions = [ + [ + device.train_samples / total_train_samples + for device in device_params_list + ] + ] + if cost_per_sec is None: + cost_per_sec = [global_params.cost_per_sec] + if cost_per_byte_sent is None: + cost_per_byte_sent = [global_params.cost_per_byte_sent] + if cost_per_byte_received is None: + cost_per_byte_received = [global_params.cost_per_byte_received] + if cost_per_flop is None: + cost_per_flop = [global_params.cost_per_flop] + results = [] + for battery in batteries: + for latency in latencies: + for partition in partitions: + for cost_sec in cost_per_sec: + for cost_sent in cost_per_byte_sent: + for cost_received in cost_per_byte_received: + for cost_flop in cost_per_flop: + global_params.cost_per_sec = cost_sec + global_params.cost_per_byte_sent = cost_sent + global_params.cost_per_byte_received = cost_received + global_params.cost_per_flop = cost_flop + for idx, device in enumerate(device_params_list): + device.current_battery = battery[idx] + device.comp_latency_factor = latency[idx] + device.train_samples = ( + partition[idx] * total_train_samples + ) + device.validation_samples = ( + partition[idx] * total_val_samples + ) + energy_simulator = EnergySimulator( + device_params_list, global_params + ) + num_rounds_smart, _, _ = ( + energy_simulator.simulate_smart_selection() + ) + num_rounds_greedy, _, _ = ( + energy_simulator.simulate_greedy_selection() + ) + num_rounds_fl, _ = ( + energy_simulator.simulate_federated_learning() + ) + results.append( + { + "battery": battery, + "latency": latency, + "partition": partition, + "cost_per_sec": cost_sec, + "cost_per_byte_sent": cost_sent, + "cost_per_byte_received": cost_received, + "cost_per_flop": cost_flop, + "num_rounds_smart": num_rounds_smart, + "num_rounds_greedy": num_rounds_greedy, + "num_rounds_fl": num_rounds_fl, + } + ) + return results + + +def run_grid_search_with_variable_devices( + num_devices_list, + global_params, + battery_per_device, + total_train_samples, + total_val_samples, + max_latencies=None, + max_split=None, + cost_per_sec=None, + cost_per_byte_sent=None, + cost_per_byte_received=None, + cost_per_flop=None, +): + """ + Runs a grid search for the given device parameters and global parameters. + Params: + num_devices: list(number of devices) + global_params: GlobalParams object + battery_per_device: list(battery per device) + total_train_samples: total number of train samples + total_val_samples: total number of validation samples + max_latencies: optional list(max latency per device) Sets all devices except the last one to the max latency if provided + max_split: optional list(max split per device) Sets the largest split for last device if provided and distributes the rest equally among the other devices + **kwargs: optional parameters for the grid search: should be a list of lists for device params and a list for costs + If e.g. cost_per_second is provided, the grid search will be run for all values in the list overriding existing values in the global params object. 
+ If no cost_per_second is provided, the grid search will use the value from the global params object. + Returns: + list of dicts containing the results for each combination of parameters + """ + if max_latencies is None: + max_latencies = [1.0] + if cost_per_sec is None: + cost_per_sec = [global_params.cost_per_sec] + if cost_per_byte_sent is None: + cost_per_byte_sent = [global_params.cost_per_byte_sent] + if cost_per_byte_received is None: + cost_per_byte_received = [global_params.cost_per_byte_received] + if cost_per_flop is None: + cost_per_flop = [global_params.cost_per_flop] + results = [] + for num_devices in num_devices_list: + if max_split is None: + max_split = [1 / num_devices] + for battery in battery_per_device: + for latency in max_latencies: + for partition in max_split: + device_params_list = [ + DeviceParams( + device_id=f"d{i}", + initial_battery=0, + current_battery=battery, + train_samples=( + partition * total_train_samples + if i == num_devices - 1 + else total_train_samples // num_devices + ), + validation_samples=( + partition * total_val_samples + if i == num_devices - 1 + else total_val_samples // num_devices + ), + comp_latency_factor=latency if i < num_devices - 1 else 1.0, + ) + for i in range(num_devices) + ] + for cost_sec in cost_per_sec: + for cost_sent in cost_per_byte_sent: + for cost_received in cost_per_byte_received: + for cost_flop in cost_per_flop: + global_params.cost_per_sec = cost_sec + global_params.cost_per_byte_sent = cost_sent + global_params.cost_per_byte_received = cost_received + global_params.cost_per_flop = cost_flop + energy_simulator = EnergySimulator( + device_params_list, global_params + ) + num_rounds_smart, _, _ = ( + energy_simulator.simulate_smart_selection() + ) + num_rounds_greedy, _, _ = ( + energy_simulator.simulate_greedy_selection() + ) + num_rounds_fl, _ = ( + energy_simulator.simulate_federated_learning() + ) + results.append( + { + "num_devices": num_devices, + "battery": battery, + "latency": latency, + "partition": partition, + "cost_per_sec": cost_sec, + "cost_per_byte_sent": cost_sent, + "cost_per_byte_received": cost_received, + "cost_per_flop": cost_flop, + "num_rounds_smart": num_rounds_smart, + "num_rounds_greedy": num_rounds_greedy, + "num_rounds_fl": num_rounds_fl, + } + ) + return results diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..4463e2cc8ffd3a89471e8728ba0a117e163e8fea --- /dev/null +++ b/edml/controllers/swarm_controller.py @@ -0,0 +1,137 @@ +from typing import Any + +from edml.controllers.base_controller import BaseController +from edml.controllers.scheduler.base import NextServerScheduler +from edml.helpers.config_helpers import get_device_index_by_id + + +class SwarmController(BaseController): + + def __init__(self, cfg, scheduler: NextServerScheduler): + super().__init__(cfg) + scheduler.initialize(self) + self._next_server_scheduler = scheduler + + def _train(self): + client_weights = None + server_weights = None + server_device_id = None + diagnostic_metric_container = None + optimizer_state = None + for i in range(self.cfg.experiment.max_rounds): + print(f"Round {i}") + self._update_devices_battery_status() + server_device_id = self._select_server_device( + last_server_device_id=server_device_id, + diagnostic_metric_container=diagnostic_metric_container, + ) + + # break if no active devices or only server device left + if server_device_id is None or 
self._devices_empty_or_only_server_left( + server_device_id + ): + print("No active client devices left.") + break + + ( + client_weights, + server_weights, + metrics, + optimizer_state, + diagnostic_metric_container, + ) = self._swarm_train_round( + client_weights, + server_weights, + server_device_id, + round_no=i, + optimizer_state=optimizer_state, + ) + + self._refresh_active_devices() + self.logger.log( + {"remaining_devices": {"devices": len(self.active_devices), "round": i}} + ) + self.logger.log( + { + "server_device": { + "device": get_device_index_by_id(self.cfg, server_device_id) + }, + "round": i, + } + ) # log the server device index for convenience + + if metrics is not None: + self._aggregate_and_log_metrics(metrics, i) + + early_stop = self.early_stopping(metrics, i) + if early_stop: + break + + self._save_weights(client_weights, server_weights, i) + + def _swarm_train_round( + self, + client_weights, + server_weights, + server_device_id, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, + ): + self._refresh_active_devices() + # set latest client weights on first device to train on + self.request_dispatcher.set_weights_on( + device_id=self.active_devices[0], state_dict=client_weights, on_client=True + ) + # set latest server weights on server device + self.request_dispatcher.set_weights_on( + device_id=server_device_id, state_dict=server_weights, on_client=False + ) + + training_response = self.request_dispatcher.train_global_on( + server_device_id, + epochs=1, + round_no=round_no, + optimizer_state=optimizer_state, + ) + + if training_response is not False: # server device unavailable + return training_response + else: + return ( + client_weights, + server_weights, + None, + optimizer_state, + None, + ) # return most recent weights and no metrics + + def _select_server_device( + self, last_server_device_id=None, diagnostic_metric_container=None + ): + """Returns the id of the server device for the given round.""" + if len(self.active_devices) == 0: + return None + return self._next_server_scheduler.next_server( + self.active_devices, + last_server_device_id=last_server_device_id, + diagnostic_metric_container=diagnostic_metric_container, + ) + + def _get_active_devices_dataset_sizes_and_model_flops(self): + """Returns the dataset sizes and model flops of active devices only.""" + dataset_sizes = {} + model_flops = {} + client_flop_list = [] + server_flop_list = [] + for device_id in self.active_devices: + train_samples, val_samples, client_flops, server_flops = ( + self.request_dispatcher.get_dataset_model_info_on(device_id) + ) + dataset_sizes[device_id] = (train_samples, val_samples) + client_flop_list.append(client_flops) + server_flop_list.append(server_flops) + # avoid that flops are taken from a device that wasn't used for training and thus has no flops + # apart from that, FLOPs should be the same everywhere + model_flops["client"] = max(client_flop_list) + model_flops["server"] = max(server_flop_list) + return dataset_sizes, model_flops diff --git a/edml/controllers/test_controller.py b/edml/controllers/test_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5e7c187a745d13f8dae9673d3518fc9015b9d2 --- /dev/null +++ b/edml/controllers/test_controller.py @@ -0,0 +1,53 @@ +import concurrent.futures +import os + +from omegaconf.errors import ConfigAttributeError + +from edml.controllers.base_controller import BaseController + + +class TestController(BaseController): + """Controller for evaluating the models with the test 
data on all devices.""" + + def _train(self): + """ + Raises: + ValueError: If no models are found with the path is defined in the config and the prefix saved. + Notes: + Evaluates the test data on all devices locally using the full network. + Assumes sufficient (or infinite) battery on all devices. + If config defines a 'best_round' attribute, the weights numbered with the corresponding round are used. + Otherwise, the weights with the highest postfix number are used. + """ + try: + best_round = self.cfg.best_round + except ConfigAttributeError: + # assume client and server model have same highest postfix number + best_round = self._get_model_with_highest_postfix_number( + self.cfg.experiment.client_model_save_path + ) + client_weights, server_weights = self._load_weights(best_round) + self._set_weights_on_all_devices(client_weights, on_client=True) + self._set_weights_on_all_devices(server_weights, on_client=False) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(len(self.active_devices), 1) + ) as executor: # avoid exception when setting 0 workers + futures = [ + executor.submit( + self.request_dispatcher.evaluate_global_on, device_id, False, True + ) + for device_id in self.active_devices + ] + concurrent.futures.wait(futures) + + def _get_model_with_highest_postfix_number(self, model_save_path): + """Returns the highest postfix number in the given directory for the configured model_prefix.""" + model_prefix = self.__model_prefix__() + highest_postfix_number = 0 + for file in os.listdir(model_save_path): + if file.startswith(model_prefix): + postfix_number = int(file.split("_")[-1].split(".")[0]) + if postfix_number > highest_postfix_number: + highest_postfix_number = postfix_number + return highest_postfix_number diff --git a/edml/core/__init__.py b/edml/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/core/battery.py b/edml/core/battery.py new file mode 100644 index 0000000000000000000000000000000000000000..783f62a5dccfa2436898c499def9976d645cb4dd --- /dev/null +++ b/edml/core/battery.py @@ -0,0 +1,186 @@ +import threading +import time +from typing import Final + +from edml.helpers.units import MEGA_FACTOR + + +class Battery: + """ + A battery used to simulate energy consumption for participating devices. + + Batteries are capable of tracking energy consumption for network data usage + and training energy consumption based on flop count. + """ + + def __init__( + self, + capacity: int, + deduction_per_second: float = 1.0, + deduction_per_mflop: float = 0.1, + deduction_per_mbyte_received: float = 0, + deduction_per_mbyte_sent: float = 0, + ): + """ + Initializes a new battery with the given deduction rates. + + Args: + capacity (int): the total battery capacity. + deduction_per_second (float): the deduction rate per second. Defaults to 1.0. + deduction_per_mflop (float): the deduction rate per mega-flop. Defaults to 0.1. + deduction_per_mbyte_received (float): The deduction rate per megabyte of received network data. + Defaults to 0. + deduction_per_mbyte_sent (float): the deduction rate per megabyte of sent network data. Defaults to 0. + """ + + # Batteries can be in a state where they have not been "started". Meaning + # that they do not deduce any energy consumption yet. + # To put batteries into an active state, you have to call `start_experiment`. 
+        self.__running__: bool = False
+
+        self.__capacity__: int = capacity
+        self.__initial_capacity__: Final[int] = capacity
+        self.__last_update__: float = time.time()
+        self.__deduction_per_second__: float = deduction_per_second
+        self.__deduction_per_mflop__: float = deduction_per_mflop
+        self.__deduction_per_mbyte_received__: float = deduction_per_mbyte_received
+        self.__deduction_per_mbyte_sent__: float = deduction_per_mbyte_sent
+        self.__lock__ = threading.Lock()
+
+    def update_flops(self, flops: int):
+        """
+        Decreases the battery capacity by the energy cost of the given number of flops.
+
+        Args:
+            flops (int): the number of flops by which to decrease the battery capacity.
+
+        Raises:
+            BatteryEmptyException: if the energy deduction is greater than or equal to
+                the remaining battery capacity.
+        """
+        self.__decrease_by__(self.__deduction_per_mflop__ * flops / MEGA_FACTOR)
+
+    def update_time(self):
+        """
+        Decreases the battery capacity based on the amount of time passed since the
+        last call of `update_time`.
+
+        Raises:
+            BatteryEmptyException: if the energy deduction is greater than or equal to
+                the remaining battery capacity.
+
+        Notes:
+            Acquires a lock to prevent race conditions when accessing the battery state.
+        """
+        now = time.time()
+        difference = now - self.__last_update__
+        if difference > 0:
+            self.__decrease_by__(self.__deduction_per_second__ * difference)
+            self.__last_update__ = now
+
+    def update_communication_received(self, size_in_bytes: int):
+        """
+        Reduces the remaining battery capacity based on the amount of network data received.
+
+        The consumed energy amount is calculated by multiplying the number of received bytes,
+        converted to megabytes, by the `deduction_per_mbyte_received` factor.
+
+        Args:
+            size_in_bytes (int): the number of received bytes.
+
+        Raises:
+            BatteryEmptyException: if the energy deduction is greater than or equal to
+                the remaining battery capacity.
+
+        Notes:
+            Acquires a lock to prevent race conditions when accessing the battery state.
+        """
+        self.__decrease_by__(
+            size_in_bytes * self.__deduction_per_mbyte_received__ / MEGA_FACTOR
+        )
+
+    def update_communication_sent(self, size_in_bytes: int):
+        """
+        Reduces the remaining battery capacity based on the amount of network data sent.
+
+        The consumed energy amount is calculated by multiplying the number of sent bytes,
+        converted to megabytes, by the `deduction_per_mbyte_sent` factor.
+
+        Args:
+            size_in_bytes (int): the number of sent bytes.
+
+        Raises:
+            BatteryEmptyException: if the energy deduction is greater than or equal to
+                the remaining battery capacity.
+
+        Notes:
+            Acquires a lock to prevent race conditions when accessing the battery state.
+        """
+        self.__decrease_by__(
+            size_in_bytes * self.__deduction_per_mbyte_sent__ / MEGA_FACTOR
+        )
+
+    def remaining_capacity(self) -> int:
+        """Returns the remaining battery capacity."""
+        return self.__capacity__
+
+    def initial_capacity(self) -> int:
+        """Returns the initial battery capacity."""
+        return self.__initial_capacity__
+
+    def is_empty(self) -> bool:
+        """
+        Returns whether the battery is empty or not.
+
+        Returns:
+            bool: `True` if the battery is empty, `False` otherwise.
+
+        Notes:
+            Acquires a lock to prevent race conditions when accessing the battery state.
+        """
+        with self.__lock__:
+            return self.__capacity__ <= 0
+
+    def __decrease_by__(self, amount):
+        """
+        Decreases the battery capacity by the given amount.
+
+        Raises:
+            BatteryEmptyException: if the battery is empty after the deduction.
+
+        Notes:
+            Decreases the capacity only if :py:meth:`~edml.core.battery.Battery.start_experiment` has been called. 
+        """
+        with self.__lock__:
+            if self.__running__:
+                self.__capacity__ -= amount
+                # Check the capacity directly instead of calling is_empty(), which would
+                # try to re-acquire the non-reentrant lock and deadlock.
+                if self.__capacity__ <= 0:
+                    raise BatteryEmptyException("Battery empty")
+
+    def start_experiment(self):
+        """
+        Initializes the battery and enables time-based capacity depletion.
+
+        Notes:
+            This method should be called during the `start_experiment` state of the controller. The method has to be
+            called before any method that reduces the battery's capacity. If not, methods that reduce the battery's
+            capacity will be a no-op.
+
+        See:
+            - :py:meth:`~edml.core.battery.Battery.update_flops`
+            - :py:meth:`~edml.core.battery.Battery.update_time`
+            - :py:meth:`~edml.core.battery.Battery.update_communication_received`
+            - :py:meth:`~edml.core.battery.Battery.update_communication_sent`
+        """
+        self.__running__ = True
+
+
+class BatteryEmptyException(Exception):
+    """
+    This exception is raised when the battery runs empty.
+
+    The exception is caught inside gRPC network dispatchers, allowing us to gracefully
+    close the TCP connections.
+    """
+
+    pass
diff --git a/edml/core/client.py b/edml/core/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe41e2cdb3c0be6847b37df8ae3c7cae8c4eba7d
--- /dev/null
+++ b/edml/core/client.py
@@ -0,0 +1,384 @@
+from __future__ import annotations
+
+import itertools
+import time
+from typing import Optional, Tuple, TYPE_CHECKING, Any
+
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+
+from edml.helpers.config_helpers import get_torch_device_id
+from edml.helpers.decorators import (
+    check_device_set,
+    simulate_latency_decorator,
+    LatencySimulator,
+)
+from edml.helpers.flops import estimate_model_flops
+from edml.helpers.load_optimizer import get_optimizer_and_scheduler
+from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
+from edml.helpers.types import StateDict, SLTrainBatchResult
+
+if TYPE_CHECKING:
+    from edml.core.device import Device
+
+
+class DeviceClient:
+    """
+    A client in the context of split learning, i.e., a device that trains the first `n` layers on the client. The
+    remaining layers will then be trained on the server device.
+
+    Attributes:
+        - TODO: latency_factor and node_device?
+
+    See:
+        - py:meth:`~edml.core.server.DeviceServer`
+        - py:meth:`~edml.controllers.split_learning.SplitController`
+        - py:meth:`~edml.controllers.swarm_learning.SwarmController`
+
+    Split learning client that runs on a device and communicates with servers on (potentially) other devices
+    through the interface provided by its device."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        cfg: DictConfig,
+        train_dl: DataLoader,
+        val_dl: DataLoader,
+        test_dl: DataLoader,
+        latency_factor: float = 0.0,
+    ):
+        """
+        Initializes the split learning client with its (partial) model and training, validation and test data.
+
+        Args:
+            model (nn.Module): The PyTorch neural network trained by the client.
+            cfg (DictConfig): The experiment's configuration.
+            train_dl (DataLoader): The data loader responsible for loading training data.
+            val_dl (DataLoader): The data loader responsible for loading validation data.
+            test_dl (DataLoader): The data loader responsible for loading testing data.
+            latency_factor (float): Factor used to simulate additional computation latency on client operations.
+                Defaults to 0.0.
+
+        Notes:
+            This class moves the model to the GPU if CUDA is available. If not, the model will be moved to the CPU. 
+        """
+
+        self._train_data, self._val_data, self._test_data = train_dl, val_dl, test_dl
+        self._batchable_data_loader = None
+        self._device = torch.device(get_torch_device_id(cfg))
+        self._model = model.to(self._device)
+        self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
+            cfg, self._model.parameters()
+        )
+        # get first sample from train data to estimate model flops
+        sample = self._train_data.dataset.__getitem__(0)[0]
+        if not isinstance(sample, torch.Tensor):
+            sample = torch.tensor(data=sample)
+        self._model_flops = estimate_model_flops(
+            self._model, sample.to(self._device).unsqueeze(0)
+        )
+        self._cfg = cfg
+        self.node_device: Optional[Device] = None
+        self.latency_factor = latency_factor
+
+        self._psl_cache = None
+
+    @simulate_latency_decorator(latency_factor_attr="latency_factor")
+    def set_device(self, node_device: Device):
+        """
+        Sets the device that this client-side is part of.
+
+        Notes:
+            If a latency factor is specified, this function sleeps for said amount before returning.
+        """
+        self.node_device = node_device
+
+    @simulate_latency_decorator(latency_factor_attr="latency_factor")
+    def set_weights(self, state_dict: StateDict):
+        """
+        Updates the model's weights with the ones specified.
+
+        Args:
+            state_dict (StateDict): The model's parameters and buffers.
+
+        Notes:
+            If a latency factor is specified, this function sleeps for said amount before returning.
+        """
+        if state_dict is not None:
+            self._model.load_state_dict(state_dict=state_dict)
+
+    @simulate_latency_decorator(latency_factor_attr="latency_factor")
+    def get_weights(self) -> StateDict:
+        """
+        Returns the model's parameters and buffers, including its weights.
+
+        Returns:
+            StateDict
+
+        Notes:
+            If a latency factor is specified, this function sleeps for said amount before returning.
+        """
+        return self._model.state_dict()
+
+    @simulate_latency_decorator(latency_factor_attr="latency_factor")
+    def get_num_samples(self) -> int:
+        """
+        Returns the number of samples in the client's training data.
+
+        Returns:
+            int: The length of the training data set.
+
+        Notes:
+            If a latency factor is specified, this function sleeps for said amount before returning.
+        """
+        return len(self._train_data.dataset)
+
+    @check_device_set()
+    def train_single_batch(
+        self, batch_index: int
+    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+        torch.cuda.set_device(self._device)
+        # We have to re-initialize the data loader in case we do another epoch.
+        if batch_index == 0:
+            self._batchable_data_loader = iter(self._train_data)
+
+        # Used to measure training time. The problem we have with parallel split learning is that forward- and backward-
+        # passes are orchestrated by the current server.
+        # Thus, we need to cache the time required for the forward pass to ensure that we collect the right execution
+        # time.
+        start_time = time.time()
+
+        self._model.train()
+
+        # We need to get the number of batches that the DataLoader can provide us with to properly index and retrieve
+        # the correct batch.
+        #
+        # TODO: is there another way to do this? gRPC streaming does not work here, since we have to keep the streams
+        # alive while doing other RPC calls like setting weights, sending/averaging gradient data, ...
+        num_batches = self.get_approximated_num_batches()
+        assert 0 <= batch_index < num_batches
+
+        # Safety check to ensure that we train same-sized batches only. 
+ batch_data, batch_labels = next(self._batchable_data_loader) + + # Updates the battery capacity by simulating the required energy consumption for conducting the training step. + self.node_device.battery.update_flops(self._model_flops * len(batch_data)) + + # We train the model using the single batch and return the activations and labels. These get send over to the + # server to be then further processed + + with LatencySimulator(latency_factor=self.latency_factor): + batch_data_to = batch_data.to(self._device) + + self._optimizer.zero_grad() + smashed_data = self._model(batch_data_to) + + end_time = time.time() + + self._psl_cache = { + "batch_data": batch_data, + "smashed_data": smashed_data, + "start_time": start_time, + "end_time": end_time, + } + return smashed_data, batch_labels + + @check_device_set() + def backward_single_batch( + self, gradients + ) -> Tuple[DiagnosticMetricResultContainer, torch.Tensor]: + torch.cuda.set_device(self._device) + batch_data, smashed_data, start_time, end_time = ( + self._psl_cache["batch_data"], + self._psl_cache["smashed_data"], + self._psl_cache["start_time"], + self._psl_cache["end_time"], + ) + + start_time_2 = time.time() + + self.node_device.battery.update_flops( + self._model_flops * len(batch_data) * 2 + ) # 2x for backward pass + gradients = gradients.to(self._device) + smashed_data.backward(gradients) + # self._optimizer.step() + + # We need to store a reference to the smashed_data to make it possible to finalize the training step. + self._psl_cache["smashed_data"] = smashed_data + + end_time_2 = time.time() + + metric = DiagnosticMetricResult( + device_id=self.node_device.device_id, + name="comp_time", + value=end_time - start_time + (end_time_2 - start_time_2), + method="client_train_batch_time", + ) + metrics_container = DiagnosticMetricResultContainer([metric]) + + gradients = [] + for param in self._model.parameters(): + if param is not None: + gradients.append(param.grad) + else: + gradients.append(torch.zeros_like(param)) + + return metrics_container, gradients + + def get_approximated_num_batches(self) -> int: + return len(self._train_data) + + @check_device_set() + def train_epoch( + self, server_device_id: str, round_no: int = -1 + ) -> Tuple[StateDict, DiagnosticMetricResultContainer]: + """ + Trains the model on the client's data for one epoch, returning the new weights and training metrics. + The server model is run on the device with the given id. + + Args: + server_device_id (str): The id of the device on which the server model is run. + round_no (int, optional): The current epoch number. Required when using a learning rate scheduler. + + Returns: + StateDict: The updated weights of the client's model. + DiagnosticMetricResultContainer: The diagnostic metrics collected when training on the server and the actual + client model execution time. + + Notes: + If configured, runtime latency is simulated on neural network operations. + + For optimizing the server device selection, the training time for the client model is needed. Therefore, the + execution time (without the time for the server to process the batches) is measured and added as a + diagnostic metrics. + + Usual designs measure the execution time at the device level (including batch processing time). Contrary to + that, this approach does not require to deduce server batch processing time after a "traditional" + measurement. 
+ """ + client_train_start_time = time.time() + server_train_batch_times = ( + [] + ) # collects the time for the server to process the batches + self._model.train() + diagnostic_metric_container = DiagnosticMetricResultContainer() + for idx, (batch_data, batch_labels) in enumerate(self._train_data): + self.node_device.battery.update_flops(self._model_flops * len(batch_data)) + + with LatencySimulator(latency_factor=self.latency_factor): + batch_data = batch_data.to(self._device) + batch_labels = batch_labels.to(self._device) + + self._optimizer.zero_grad() + smashed_data = self._model(batch_data) + + # measure the time for the server to process the batch + start_time = time.time() + train_batch_response = self.node_device.train_batch_on( + server_device_id, smashed_data, batch_labels + ) + server_train_batch_times.append(time.time() - start_time) + + with LatencySimulator(latency_factor=self.latency_factor): + if ( + train_batch_response is False or train_batch_response is None + ): # server device unavailable + break + server_grad, _server_loss, diagnostic_metrics = train_batch_response + diagnostic_metric_container.merge(diagnostic_metrics) + self.node_device.battery.update_flops( + self._model_flops * len(batch_data) * 2 + ) # 2x for backward pass + server_grad = server_grad.to(self._device) + smashed_data.backward(server_grad) + self._optimizer.step() + + if self._lr_scheduler is not None: + if round_no != -1: + self._lr_scheduler.step(round_no) + else: + self._lr_scheduler.step() + + client_train_time = ( + time.time() - client_train_start_time - sum(server_train_batch_times) + ) + diagnostic_metric_container.add_result( + DiagnosticMetricResult( + device_id=self.node_device.device_id, + name="comp_time", + value=client_train_time, + method="client_train_epoch_time", + ) + ) + return self._model.state_dict(), diagnostic_metric_container + + @check_device_set() + def evaluate(self, server_device: str, val=True) -> DiagnosticMetricResultContainer: + """ + Evaluates the model on the client's data. + + Args: + server_device (str): The server device to run the server model on. + val (bool, optional): If `True`, uses the validation data, otherwise the test data. + Set to `True` by default. + + Returns: + DiagnosticMetricResultContainer: The diagnostic metrics collected when training on the server and the actual + client model execution time. + + Notes: + If configured, runtime latency is simulated on neural network operations. + + For optimizing the server device selection, the training time for the client model is needed. Therefore, the + execution time (without the time for the server to process the batches) is measured and added as a + diagnostic metrics. + + Usual designs measure the execution time at the device level (including batch processing time). Contrary to + that, this approach does not require to deduce server batch processing time after a "traditional" + measurement. 
+ """ + client_eval_start_time = time.time() + server_eval_batch_times = ( + [] + ) # collects the time for the server to process the batches + self._model.eval() + diagnostic_metric_results = DiagnosticMetricResultContainer() + with torch.no_grad(): + dataloader = self._val_data if val else self._test_data + for b, (batch_data, batch_labels) in enumerate(dataloader): + with LatencySimulator(latency_factor=self.latency_factor): + self.node_device.battery.update_flops( + self._model_flops * len(batch_data) + ) + batch_data = batch_data.to(self._device) + batch_labels = batch_labels.to(self._device) + + # measure the time for the server to process the batch + start_time = time.time() + diagnostic_metrics = self.node_device.evaluate_batch_on( + server_device, self._model(batch_data), batch_labels + ) + server_eval_batch_times.append(time.time() - start_time) + + diagnostic_metric_results.merge(diagnostic_metrics) + client_eval_time = ( + time.time() - client_eval_start_time - sum(server_eval_batch_times) + ) + diagnostic_metric_results.add_result( + DiagnosticMetricResult( + device_id=self.node_device.device_id, + name="comp_time", + value=client_eval_time, + method="client_eval_epoch_time", + ) + ) + return diagnostic_metric_results + + def set_gradient_and_finalize_training(self, gradients: Any): + for param, grad in zip(self._model.parameters(), gradients): + param.grad = grad.to(self._device) + + self._optimizer.step() + self._psl_cache = None diff --git a/edml/core/device.py b/edml/core/device.py new file mode 100644 index 0000000000000000000000000000000000000000..51278532bfb7f1888f13671a24a289fc10e868b2 --- /dev/null +++ b/edml/core/device.py @@ -0,0 +1,1036 @@ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, List, Union, Tuple, cast + +import grpc +from google.protobuf.message import Message +from omegaconf import DictConfig +from torch import Tensor +from torch.autograd import Variable + +from edml.core.battery import Battery +from edml.core.client import DeviceClient +from edml.core.server import DeviceServer +from edml.generated import connection_pb2 +from edml.generated.connection_pb2 import ( + SetGradientsRequest, + SetWeightsRequest, + TrainBatchRequest, + TrainGlobalResponse, + TrainEpochResponse, + TrainBatchResponse, + EvalGlobalResponse, + EvalResponse, + EvalBatchResponse, + FullModelTrainResponse, + BatteryStatusResponse, + DatasetModelInfoResponse, + EndExperimentResponse, + StartExperimentResponse, + SingleBatchTrainingResponse, + SingleBatchBackwardRequest, + TrainGlobalParallelSplitLearningResponse, + SingleBatchBackwardResponse, +) +from edml.generated.connection_pb2_grpc import DeviceServicer, DeviceStub +from edml.generated.datastructures_pb2 import ( + Gradients, + Weights, + DeviceInfo, + Activations, + Labels, + BatteryStatus, + Empty, + Metrics, +) +from edml.helpers.decorators import ( + log_execution_time, + update_battery, + add_time_to_diagnostic_metrics, +) +from edml.helpers.interceptors import DeviceClientInterceptor +from edml.helpers.logging import SimpleLogger +from edml.helpers.metrics import ( + ModelMetricResultContainer, + DiagnosticMetricResultContainer, +) +from edml.helpers.proto_helpers import ( + proto_to_tensor, + tensor_to_proto, + state_dict_to_proto, + proto_to_state_dict, + proto_to_weights, + metrics_to_proto, + proto_to_metrics, + _proto_size_per_field, +) +from edml.helpers.types import ( + HasMetrics, + StateDict, + DeviceBatteryStatus, + 
DeviceBatteryStatusReport, +) + + +class Device(ABC): + """ + Base class that represents a (physical or virtual) device. Every device is split into two parts: a client part and + a server part. + + Attributes: + device_id (str): This device's id. + logger (SimpleLogger): The logger instance the device can use. + battery (Battery): The device's battery. Certain function consume energy and drain the battery. + client (DeviceClient): The client part of this device. Initialized later by explicitly calling + py:meth:`set_client`. + server (DeviceServer): The server part of this device. Initialized later by explicitly calling + py:meth:`set_server`. + """ + + def __init__(self, device_id: str, logger: SimpleLogger, battery: Battery): + self.client: DeviceClient = cast(DeviceClient, None) + self.server: DeviceServer = cast(DeviceServer, None) + self.device_id = device_id + self.logger = logger + self.battery = battery + + def set_client(self, client: DeviceClient): + """ + Sets the client part of this device. + + Notes: + Also sets the `device` instance on the client part. + """ + self.client = client + self.client.set_device(self) + + def set_server(self, server: DeviceServer): + """ + Sets the server part of this device. + + Notes: + Also sets the `device` instance on the server part. + """ + self.server = server + self.server.set_device(self) + + def log(self, message: Any): + """Logging wrapper to be accessed by server and client""" + self.logger.log(message) + + def start_experiment(self): + """ + Lifecycle hook that is called at the start of an experiment. + + Notes: + Ensure that the super method is called if you override this method. + """ + self.logger.start_experiment() + self.battery.start_experiment() + + def end_experiment(self): + """ + Lifecycle hook that is called at the end of an experiment. + + Notes: + Ensure that the super method is called if you override this method. + """ + self.logger.end_experiment() + + def get_battery_status(self) -> Tuple[int, int]: + """ + Returns the initial and remaining battery of this device. + + Returns: + Tuple[int, int]: The first component holds the initial battery capacity, the second the current capacity. + """ + return self.battery.initial_capacity(), self.battery.remaining_capacity() + + def shutdown(self): + """ + Shuts the device down and cleans up resources. 
+        """
+        self.end_experiment()
+
+    @abstractmethod
+    def train_global(self, epochs: int):
+        """Trains globally for a given number of epochs using the device's server"""
+
+    @abstractmethod
+    def set_devices(self, devices):
+        """Sets references to all devices in the network"""
+
+    @abstractmethod
+    def set_weights(self, state_dict, on_client: bool):
+        """Sets the weights for one of the device's networks"""
+
+    @abstractmethod
+    def set_weights_on(self, device_id: str, state_dict, on_client: bool):
+        """Sets the weights for one of the networks on the device with the given id"""
+
+    @abstractmethod
+    def train_epoch(self, server_device: str):
+        """Trains an epoch on the device's client"""
+
+    @abstractmethod
+    def train_epoch_on(self, device_id: str, server_device: str, round_no: int):
+        """Trains an epoch on the client of the device with the given id"""
+
+    @abstractmethod
+    def train_batch(self, smashed_data, labels):
+        """Trains a batch on the device's server"""
+
+    @abstractmethod
+    def train_batch_on(self, device_id: str, smashed_data, labels):
+        """Trains a batch on the server of the device with the given id"""
+
+    @abstractmethod
+    def evaluate_global(self, val: bool = True, fed: bool = False):
+        """Starts evaluation on all devices' clients using the device's server. val determines whether the validation (True) or test (False) set is used"""
+
+    @abstractmethod
+    def evaluate(self, server_device: str, val: bool = True):
+        """Starts evaluation on the device's client using the specified server. val determines whether the validation (True) or test (False) set is used"""
+
+    @abstractmethod
+    def evaluate_on(self, device_id: str, server_device: str, val: bool):
+        """Starts evaluation on the client of the device with the given id using the specified server. 
val determines whether the validation (True) or test (False) set is used""" + + @abstractmethod + def evaluate_batch(self, smashed_data, labels): + """Evaluates a batch on the device's server""" + + @abstractmethod + def evaluate_batch_on(self, device_id: str, smashed_data, labels): + """Evaluates a batch on the server of the device with the given id""" + + @abstractmethod + def train_batch_on_client_only_on(self, device_id: str, batch_index: int): + """""" + + @abstractmethod + def backpropagation_on_client_only_on(self, client_id: str, gradients: Any): + """""" + + @abstractmethod + def set_gradient_and_finalize_training_on_client_only_on( + self, client_id: str, gradients: Any + ): + """""" + + +class NetworkDevice(Device): + @update_battery + def set_gradient_and_finalize_training_on_client_only_on( + self, client_id: str, gradients: Any + ): + if client_id == self.device_id: + self.client.set_gradient_and_finalize_training(gradients) + else: + return self.request_dispatcher.set_gradient_and_finalize_training_on_client_only( + client_id, gradients + ) + + @update_battery + @log_execution_time("logger", "train_parallel_split_learning") + def train_parallel_split_learning( + self, + clients: list[str], + round_no: int, + adaptive_learning_threshold: Optional[float] = None, + optimizer_state: dict[str, Any] = None, + ): + return self.server.train_parallel_split_learning( + clients=clients, + round_no=round_no, + adaptive_learning_threshold=adaptive_learning_threshold, + optimizer_state=optimizer_state, + ) + + @update_battery + @log_execution_time("logger", "client_only_backpropagation_train") + def backpropagation_on_client_only(self, gradients: Any): + return self.client.backward_single_batch(gradients) + + @update_battery + def backpropagation_on_client_only_on(self, client_id: str, gradients: Any): + if client_id == self.device_id: + return self.backpropagation_on_client_only(gradients=gradients) + else: + return self.request_dispatcher.backpropagation_on_client_only( + device_id=client_id, gradients=gradients + ) + + @update_battery + @log_execution_time("logger", "client_only_batch_train") + def train_batch_on_client_only(self, batch_index: int): + smashed_data, labels = self.client.train_single_batch(batch_index=batch_index) + return smashed_data, labels + + @update_battery + def train_batch_on_client_only_on(self, device_id: str, batch_index: int): + if self.device_id == device_id: + return self.train_batch_on_client_only(batch_index=batch_index) + else: + return self.request_dispatcher.train_batch_on_client_only( + device_id=device_id, batch_index=batch_index + ) + + def __init__( + self, + device_id: str, + logger: SimpleLogger, + battery: Battery, + stop_event: Optional[threading.Event] = None, + ): + self.devices: List[DictConfig[str, Any]] = [] + self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id) + self.stop_event = stop_event + super().__init__(device_id, logger, battery) + + @add_time_to_diagnostic_metrics("train_global") + @update_battery + @log_execution_time("logger", "train_global_time") + def train_global( + self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None + ) -> Tuple[ + Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer + ]: + return self.server.train( + devices=self.__get_device_ids__(), + epochs=epochs, + round_no=round_no, + optimizer_state=optimizer_state, + ) + + def __get_device_ids__(self) -> List[str]: + return [d.device_id for d in self.devices] + + def set_devices(self, devices: 
List[DictConfig[str, Any]]): + """ + Sets a Dictionary with references to all devices in the network. + + Expects a list of dictionaries containing with keys for the device_id and address of each device. + """ + self.devices = devices + self.request_dispatcher = DeviceRequestDispatcher( + devices, + self.logger, + self.battery, + self.stop_event, + device_id=self.device_id, + ) + + @update_battery + def set_weights(self, state_dict, on_client: bool = True): + if on_client: + self.client.set_weights(state_dict) + else: + self.server.set_weights(state_dict) + + @update_battery + def set_weights_on(self, device_id: str, state_dict, on_client: bool = True): + if device_id == self.device_id: + self.set_weights(state_dict, on_client) + else: + self.request_dispatcher.set_weights_on(device_id, state_dict, on_client) + + @update_battery + @log_execution_time("logger", "client_train_epoch_time") + def train_epoch(self, server_device: str, round_no: int = -1): + # the execution time is measured in the client in order to deduct the time for the server + return self.client.train_epoch(server_device, round_no=round_no) + + @update_battery + def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1): + if device_id == self.device_id: + return self.train_epoch(server_device, round_no) + return self.request_dispatcher.train_epoch_on( + device_id, server_device, round_no + ) + + @add_time_to_diagnostic_metrics("train_batch") + @update_battery + def train_batch(self, smashed_data, labels) -> Variable: + result = self.server.train_batch(smashed_data, labels) + self._log_current_battery_capacity() + return result + + @update_battery + def train_batch_on(self, device_id: str, smashed_data, labels): + if device_id == self.device_id: + return self.train_batch(smashed_data, labels) + result = self.request_dispatcher.train_batch_on(device_id, smashed_data, labels) + self._log_current_battery_capacity() + return result + + @add_time_to_diagnostic_metrics("evaluate_global") + @update_battery + @log_execution_time("logger", "evaluate_global_time") + def evaluate_global( + self, val: bool = True, fed: bool = False + ) -> ModelMetricResultContainer: + if fed: + return self.server.evaluate_global(devices=[self.device_id], val=val) + else: + return self.server.evaluate_global( + devices=self.__get_device_ids__(), val=val + ) + + @update_battery + @log_execution_time("logger", "client_evaluate_time") + def evaluate(self, server_device: str, val=True) -> DiagnosticMetricResultContainer: + # the execution time is measured in the client in order to deduct the time for the server + return self.client.evaluate(server_device, val) + + @update_battery + def evaluate_on( + self, device_id, server_device, val + ) -> DiagnosticMetricResultContainer: + if device_id == self.device_id: + return self.evaluate(server_device, val) + else: + return self.request_dispatcher.evaluate_on(device_id, server_device, val) + + @add_time_to_diagnostic_metrics("evaluate_batch") + @update_battery + def evaluate_batch(self, smashed_data, labels): + result = self.server.evaluate_batch(smashed_data, labels) + self._log_current_battery_capacity() + return result + + @update_battery + def evaluate_batch_on(self, device_id, smashed_data, labels): + if device_id == self.device_id: + return self.evaluate_batch(smashed_data, labels) + else: + self._log_current_battery_capacity() + return self.request_dispatcher.evaluate_batch_on( + device_id, smashed_data, labels + ) + + @add_time_to_diagnostic_metrics("federated_train") + 
@update_battery + @log_execution_time("logger", "fed_train_time") + def federated_train( + self, round_no: int = -1 + ) -> Tuple[ + Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer + ]: + """Returns client and server weights, the number of samples used for training and metrics""" + client_weights, server_weights, metrics, _, diagnostic_metrics = ( + self.server.train(devices=[self.device_id], epochs=1, round_no=round_no) + ) + num_samples = self.client.get_num_samples() + return client_weights, server_weights, num_samples, metrics, diagnostic_metrics + + def _log_current_battery_capacity(self): + """Wrapper for logging the current battery capacity""" + self.logger.log({"battery": self.battery.remaining_capacity()}) + + +class RPCDeviceServicer(DeviceServicer): + def __init__(self, device: NetworkDevice): + self.device = device + + def TrainGlobal(self, request, context): + print(f"Called TrainGlobal on device {self.device.device_id}") + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.device.train_global(request.epochs, request.round_no) + ) + response = connection_pb2.TrainGlobalResponse( + client_weights=Weights(weights=state_dict_to_proto(client_weights)), + server_weights=Weights(weights=state_dict_to_proto(server_weights)), + metrics=metrics_to_proto(metrics), + optimizer_state=state_dict_to_proto(optimizer_state), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + return response + + def SetWeights(self, request, context): + print(f"Called SetWeights on device {self.device.device_id}") + weights = proto_to_state_dict(request.weights.weights) + self.device.set_weights(weights, request.on_client) + return connection_pb2.SetWeightsResponse() + + def TrainEpoch(self, request, context): + print(f"Called TrainEpoch on device {self.device.device_id}") + device_info = request.server + device_id = device_info.device_id + round_no = request.round_no + weights, diagnostic_metrics = self.device.train_epoch(device_id, round_no) + proto_weights = state_dict_to_proto(weights) + return connection_pb2.TrainEpochResponse( + weights=Weights(weights=proto_weights), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def TrainBatch(self, request, context): + activations = proto_to_tensor(request.smashed_data.activations) + labels = proto_to_tensor(request.labels.labels) + gradients, loss, diagnostic_metrics = self.device.train_batch( + activations, labels + ) + proto_gradients = Gradients(gradients=tensor_to_proto(gradients)) + return connection_pb2.TrainBatchResponse( + gradients=proto_gradients, + loss=loss, + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def EvaluateGlobal(self, request, context): + print(f"Called EvaluateGlobal on device {self.device.device_id}") + metrics, diagnostic_metrics = self.device.evaluate_global( + request.validation, request.federated + ) + return connection_pb2.EvalGlobalResponse( + metrics=metrics_to_proto(metrics), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def Evaluate(self, request, context): + print(f"Called Evaluate on device {self.device.device_id}") + diagnostic_metrics = self.device.evaluate( + request.server.device_id, request.validation + ) + return connection_pb2.EvalResponse( + diagnostic_metrics=metrics_to_proto(diagnostic_metrics) + ) + + def EvaluateBatch(self, request, context): + activations = proto_to_tensor(request.smashed_data.activations) + labels = proto_to_tensor(request.labels.labels) + diagnostic_metrics = 
self.device.evaluate_batch(activations, labels) + return connection_pb2.EvalBatchResponse( + diagnostic_metrics=metrics_to_proto(diagnostic_metrics) + ) + + def FullModelTraining(self, request, context): + print(f"Called Full Training on device {self.device.device_id}") + client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( + self.device.federated_train(request.round_no) + ) + return connection_pb2.FullModelTrainResponse( + client_weights=Weights(weights=state_dict_to_proto(client_weights)), + server_weights=Weights(weights=state_dict_to_proto(server_weights)), + num_samples=num_samples, + metrics=metrics_to_proto(metrics), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def StartExperiment(self, request, context) -> StartExperimentResponse: + print(f"Start Experiment on device {self.device.device_id}") + self.device.start_experiment() + return connection_pb2.StartExperimentResponse() + + def EndExperiment(self, request, context) -> EndExperimentResponse: + print(f"End Experiment on device {self.device.device_id}") + print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}") + self.device.end_experiment() + return connection_pb2.EndExperimentResponse() + + def GetBatteryStatus(self, request, context): + print(f"Get Battery Status on device {self.device.device_id}") + initial_capacity, remaining_capacity = self.device.get_battery_status() + + return connection_pb2.BatteryStatusResponse( + status=BatteryStatus( + initial_battery_level=initial_capacity, + current_battery_level=remaining_capacity, + ) + ) + + def GetDatasetModelInfo(self, request, context): + print(f"Get Dataset and Model Info on device {self.device.device_id}") + return connection_pb2.DatasetModelInfoResponse( + train_samples=len(self.device.client._train_data.dataset), + validation_samples=len(self.device.client._val_data.dataset), + client_model_flops=int(self.device.client._model_flops), + server_model_flops=int(self.device.server._model_flops), + ) + + def TrainGlobalParallelSplitLearning(self, request, context): + print(f"Starting parallel split learning") + clients = self.device.__get_device_ids__() + round_no = request.round_no + adaptive_learning_threshold = request.adaptive_learning_threshold + + cw, sw, model_metrics, optimizer_state, diagnostic_metrics = ( + self.device.train_parallel_split_learning( + clients=clients, + round_no=round_no, + adaptive_learning_threshold=adaptive_learning_threshold, + ) + ) + response = connection_pb2.TrainGlobalParallelSplitLearningResponse( + client_weights=Weights(weights=state_dict_to_proto(cw)), + server_weights=Weights(weights=state_dict_to_proto(sw)), + metrics=metrics_to_proto(model_metrics), + optimizer_state=state_dict_to_proto(optimizer_state), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + return response + + def TrainSingleBatchOnClient(self, request, context): + batch_index = request.batch_index + print(f"Starting single batch@{batch_index}") + + smashed_data, labels = self.device.client.train_single_batch(batch_index) + + smashed_data = Activations(activations=tensor_to_proto(smashed_data)) + labels = Labels(labels=tensor_to_proto(labels)) + return connection_pb2.SingleBatchTrainingResponse( + smashed_data=smashed_data, + labels=labels, + ) + + def BackwardPropagationSingleBatchOnClient( + self, request: SingleBatchBackwardRequest, context + ): + gradients = proto_to_tensor(request.gradients.gradients) + + metrics, gradients = self.device.client.backward_single_batch( + 
gradients=gradients + ) + return connection_pb2.SingleBatchBackwardResponse( + metrics=metrics_to_proto(metrics), + gradients=Gradients(gradients=tensor_to_proto(gradients)), + ) + + def SetGradientsAndFinalizeTrainingStep( + self, request: SetGradientsRequest, context + ): + gradients = proto_to_tensor(request.gradients.gradients) + self.device.client.set_gradient_and_finalize_training(gradients=gradients) + return Empty() + + +class DeviceRequestDispatcher: + + def __init__( + self, + devices: List[DictConfig[str, Any]], + logger: Optional[SimpleLogger] = None, + battery: Optional[Battery] = None, + stop_event: Optional[threading.Event] = None, + device_id: Optional[str] = None, + ): + self.devices = devices + self.connections: Dict[str, DeviceStub] = {} + # optional, interceptor only works if all three are set + self.logger = logger + self.battery = battery + self.stop_event = stop_event + + self._establish_connections() + self.connections_lock = threading.Lock() + self.device_id = device_id # used for diagnostic metrics to assign the source device correctly + + def __get_device_address__(self, device_id: str) -> Optional[str]: + for device in self.devices: + if device.device_id == device_id: + return device.address + return None + + def train_parallel_on_server( + self, + server_device_id: str, + epochs: int, + round_no: int, + adaptive_learning_threshold: Optional[float] = None, + optimizer_state: dict[str, Any] = None, + ): + print(f"><><><> {adaptive_learning_threshold}") + + try: + response: TrainGlobalParallelSplitLearningResponse = self._get_connection( + server_device_id + ).TrainGlobalParallelSplitLearning( + connection_pb2.TrainGlobalParallelSplitLearningRequest( + round_no=round_no, + adaptive_learning_threshold=adaptive_learning_threshold, + optimizer_state=state_dict_to_proto(optimizer_state), + ) + ) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + proto_to_metrics(response.metrics), + proto_to_state_dict(response.optimizer_state), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + except grpc.RpcError: + self._handle_rpc_error(server_device_id) + except KeyError: + self._handle_unknown_device_id(server_device_id) + return False + + def _establish_connections(self): + for device in self.devices: + channel = grpc.insecure_channel( + device.address, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + if ( + self.logger is not None + and self.battery is not None + and self.stop_event is not None + ): + channel = grpc.intercept_channel( + channel, + DeviceClientInterceptor(self.logger, self.battery, self.stop_event), + ) + stub = DeviceStub(channel) + self.connections[device.device_id] = stub + print(f"active devices {self.connections.keys()}") + + def _get_connection(self, device_id: str) -> DeviceStub: + """Returns the connection to the device with the given id. + If the device_id is not a valid dictionary key, the caller has to handle the KeyError. + """ + with self.connections_lock: + return self.connections[device_id] + + def active_devices(self) -> List[str]: + with self.connections_lock: # lock connections to prevent remove while iterating + return list(self.connections.keys()) + + def _handle_rpc_error(self, device_id: str): + """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling + the request or is not reachable (i.e. most probably shut down). 
+ Here, the connection to the device is removed using a lock to avoid race conditions. + """ + with self.connections_lock: + del self.connections[device_id] + + def _handle_unknown_device_id(self, device_id: str): + """Handler in case the requested device id is not in the connection dictionary. + Could happen if the active devices weren't refreshed properly.""" + if device_id in [device.device_id for device in self.devices]: + print(f"Device {device_id} not active") + else: + print(f"Unknown Device ID {device_id}") + + def _add_byte_size_to_diagnostic_metrics( + self, response: Message, device_id: str, request=None + ): + """ + Adds the byte size of the response to the diagnostic metrics. + + Args: + response: A protobuf message. + device_id: The id of the device the request was sent from. + request: A protobuf message, if the request is of interest. + Returns: + A DiagnosticMetricResultContainer including the added byte size and previous metrics from the response (and possibly the request). + Raises: + None + Notes: + """ + if response.HasField("diagnostic_metrics"): + response: HasMetrics + diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics) + else: + diagnostic_metrics = DiagnosticMetricResultContainer() + diagnostic_metrics.merge(_proto_size_per_field(response, device_id)) + if request is not None: + diagnostic_metrics.merge(_proto_size_per_field(request, device_id)) + return diagnostic_metrics + + def train_global_on( + self, + device_id: str, + epochs: int, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, + ) -> Union[ + Tuple[ + Dict[str, Any], + Dict[str, Any], + ModelMetricResultContainer, + Dict[str, Any], + DiagnosticMetricResultContainer, + ], + bool, + ]: + try: + response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal( + connection_pb2.TrainGlobalRequest( + epochs=epochs, + round_no=round_no, + optimizer_state=state_dict_to_proto(optimizer_state), + ) + ) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + proto_to_metrics(response.metrics), + proto_to_state_dict(response.optimizer_state), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def set_weights_on( + self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False + ): + try: + self._get_connection(device_id).SetWeights( + SetWeightsRequest( + weights=Weights(weights=state_dict_to_proto(state_dict)), + on_client=on_client, + ), + wait_for_ready=wait_for_ready, + ) + return True + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1): + try: + response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch( + connection_pb2.TrainEpochRequest( + server=DeviceInfo( + device_id=server_device, + address=self.__get_device_address__(server_device), + ), + round_no=round_no, + ) + ) + return proto_to_state_dict( + response.weights.weights + ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def train_batch_on(self, device_id: str, smashed_data, labels): + try: + request = TrainBatchRequest( + 
smashed_data=Activations(activations=tensor_to_proto(smashed_data)), + labels=Labels(labels=tensor_to_proto(labels)), + ) + response: TrainBatchResponse = self._get_connection(device_id).TrainBatch( + request + ) + return ( + proto_to_tensor(response.gradients.gradients), + response.loss, + self._add_byte_size_to_diagnostic_metrics( + response, self.device_id, request + ), + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False): + try: + response: EvalGlobalResponse = self._get_connection( + device_id + ).EvaluateGlobal( + connection_pb2.EvalGlobalRequest(validation=val, federated=fed) + ) + return proto_to_metrics( + response.metrics + ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def evaluate_on(self, device_id: str, server_device: str, val: bool): + try: + response: EvalResponse = self._get_connection(device_id).Evaluate( + connection_pb2.EvalRequest( + server=DeviceInfo( + device_id=server_device, + address=self.__get_device_address__(server_device), + ), + validation=val, + ) + ) + return self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def evaluate_batch_on(self, device_id: str, smashed_data, labels): + try: + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(smashed_data)), + labels=Labels(labels=tensor_to_proto(labels)), + ) + response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch( + request + ) + return self._add_byte_size_to_diagnostic_metrics( + response, self.device_id, request + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[ + Tuple[ + StateDict, + StateDict, + int, + ], + bool, + ]: + try: + response: FullModelTrainResponse = self._get_connection( + device_id + ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no)) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + response.num_samples, + proto_to_metrics(response.metrics), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool: + try: + self._get_connection(device_id).StartExperiment( + connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready + ) + return True + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def end_experiment_on(self, device_id: str) -> bool: + try: + self._get_connection(device_id).EndExperiment( + connection_pb2.EndExperimentRequest() + ) + return True + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def get_battery_status_on(self, device_id: str) -> 
DeviceBatteryStatusReport: + try: + response: BatteryStatusResponse = self._get_connection( + device_id + ).GetBatteryStatus(connection_pb2.BatteryStatusRequest()) + return DeviceBatteryStatus( + current_capacity=response.status.current_battery_level, + initial_capacity=response.status.initial_battery_level, + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def train_batch_on_client_only( + self, device_id: str, batch_index: int + ) -> Tuple[Tensor, Tensor] | None: + try: + response: SingleBatchTrainingResponse = self._get_connection( + device_id + ).TrainSingleBatchOnClient( + connection_pb2.SingleBatchTrainingRequest(batch_index=batch_index) + ) + + # The response can only be None if the last batch was smaller than the configured batch size. + if response.HasField("smashed_data"): + return ( + proto_to_tensor(response.smashed_data.activations), + proto_to_tensor(response.labels.labels), + ) + + return None + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def get_dataset_model_info_on( + self, device_id: str + ) -> Union[Tuple[int, int, float, float], bool]: + try: + response: DatasetModelInfoResponse = self._get_connection( + device_id + ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest()) + return ( + response.train_samples, + response.validation_samples, + response.client_model_flops, + response.server_model_flops, + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def backpropagation_on_client_only(self, device_id, gradients): + try: + response: SingleBatchBackwardResponse = self._get_connection( + device_id + ).BackwardPropagationSingleBatchOnClient( + connection_pb2.SingleBatchBackwardRequest( + gradients=Gradients(gradients=tensor_to_proto(gradients)) + ) + ) + return ( + None, + proto_to_tensor(response.gradients.gradients), + ) + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + return False + + def set_gradient_and_finalize_training_on_client_only( + self, client_id: str, gradients: Any + ): + try: + response: Empty = self._get_connection( + client_id + ).SetGradientsAndFinalizeTrainingStep( + connection_pb2.SetGradientsRequest( + gradients=Gradients(gradients=tensor_to_proto(gradients)) + ) + ) + return response + except grpc.RpcError: + self._handle_rpc_error(client_id) + except KeyError: + self._handle_unknown_device_id(client_id) + return False diff --git a/edml/core/server.py b/edml/core/server.py new file mode 100644 index 0000000000000000000000000000000000000000..8853bd92bc3d5dc438e7ada06cb4dce5e62967c5 --- /dev/null +++ b/edml/core/server.py @@ -0,0 +1,399 @@ +from __future__ import annotations + +import concurrent.futures +import time +from typing import List, Optional, Tuple, Any, TYPE_CHECKING + +import torch +from omegaconf import DictConfig +from colorama import init, Fore +from torch import nn +from torch.autograd import Variable + +from edml.helpers.config_helpers import get_torch_device_id +from edml.helpers.decorators import check_device_set, simulate_latency_decorator +from edml.helpers.executor import create_executor_with_threads +from edml.helpers.flops import estimate_model_flops +from edml.helpers.load_optimizer import get_optimizer_and_scheduler +from edml.helpers.metrics import ( + 
create_metrics, + ModelMetricResultContainer, + DiagnosticMetricResultContainer, +) +from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn + +if TYPE_CHECKING: + from edml.core.device import Device + + +class DeviceServer: + """Split learning server that runs on a device and communicates with clients on (potentially) other devices + through the provided interface by its device.""" + + def __init__( + self, + model: nn.Module, + loss_fn: LossFn, + cfg: DictConfig, + latency_factor: float = 0.0, + ): + """Initializes the server with the given model, loss function, configuration and reference to its device.""" + self._device = torch.device(get_torch_device_id(cfg)) + self._model = model.to(self._device) + self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler( + cfg, self._model.parameters() + ) + self._model_flops = 0 # determine later + self._metrics = create_metrics( + cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting + ) + self._loss_fn = loss_fn + self._cfg = cfg + self.node_device: Optional[Device] = None + self.latency_factor = latency_factor + + def set_device(self, node_device: Device): + """Sets the device reference for the server.""" + self.node_device = node_device + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def set_weights(self, state_dict: StateDict): + """Sets the weights of the server's model""" + if state_dict is not None: + self._model.load_state_dict(state_dict=state_dict) + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def get_weights(self): + """Returns the weights of the server's model""" + return self._model.state_dict() + + @check_device_set() + def train( + self, + devices: List[str], + epochs: int = 1, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, + ) -> Tuple[ + Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer + ]: + """Train the model on the given devices for the given number of epochs. + Shares the weights among clients and saves the final weights to the configured paths. + Args: + devices: The devices to train on + epochs: Optionally, the number of epochs to train. + round_no: Optionally, the current global epoch number if a learning rate scheduler is used. 
+ optimizer_state: Optionally, the optimizer_state to proceed from + """ + client_weights = None + metrics = ModelMetricResultContainer() + diagnostic_metric_container = DiagnosticMetricResultContainer() + if optimizer_state is not None: + self._optimizer.load_state_dict(optimizer_state) + for epoch in range(epochs): + for device_id in devices: + print( + f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}" + ) + if client_weights is not None: + self.node_device.set_weights_on( + device_id=device_id, + state_dict=client_weights, + on_client=True, # we want to set client weights + ) + train_epoch_response = self.node_device.train_epoch_on( + device_id, self.node_device.device_id, round_no=round_no + epoch + ) + if ( + train_epoch_response is not False + and train_epoch_response is not None + ): + client_weights, diagnostic_metrics = train_epoch_response + train_metrics = self.finalize_metrics(str(device_id), "train") + diagnostic_metric_container.merge(diagnostic_metrics) + + diagnostic_metrics = self.node_device.evaluate_on( + device_id, server_device=self.node_device.device_id, val=True + ) + if diagnostic_metrics is not None: + diagnostic_metric_container.merge(diagnostic_metrics) + val_metrics = self.finalize_metrics(str(device_id), "val") + + metrics.add_results(train_metrics) + metrics.add_results(val_metrics) + if self._lr_scheduler is not None: + if round_no != -1: + self._lr_scheduler.step(round_no + epoch) + else: + self._lr_scheduler.step() + return ( + client_weights, + self.get_weights(), + metrics, + self._optimizer.state_dict(), + diagnostic_metric_container, + ) + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]: + """Train the model on the given batch of data and labels. + Returns the gradients of the model's parameters.""" + smashed_data, labels = smashed_data.to(self._device), labels.to(self._device) + + self._set_model_flops(smashed_data) + + self._optimizer.zero_grad() + + self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) + smashed_data = Variable(smashed_data, requires_grad=True) + output_train = self._model(smashed_data) + + loss_train = self._loss_fn(output_train, labels) + + self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2) + loss_train.backward() + self._optimizer.step() + + # Capturing training metrics for the current batch. + self.node_device.log({"loss": loss_train.item()}) + self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int()) + + return smashed_data.grad, loss_train.item() + + def _set_model_flops(self, smashed_data): + """Helper to determine the model flops when smashed data are available for the first time.""" + if self._model_flops == 0: + self._model_flops = estimate_model_flops(self._model, smashed_data) / len( + smashed_data + ) + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def finalize_metrics(self, device_id: str, phase: str): + """Computes the total results of the metrics. Logs the results clears the cached predictions. 
+ Returns a list of results.""" + metric_result_list = self._metrics.compute_metrics( + phase=phase, device_id=device_id + ) + for metric_result in metric_result_list: + self.node_device.log(metric_result.as_loggable_dict()) + self._metrics.reset_metrics() + return metric_result_list + + @check_device_set() + def evaluate_global( + self, devices: List[str], val: bool + ) -> Tuple[ModelMetricResultContainer, DiagnosticMetricResultContainer]: + """Evaluates on the given devices using the own server model. Returns the gathered metrics.""" + result_metrics = ModelMetricResultContainer() + diagnostic_metric_results = DiagnosticMetricResultContainer() + for device_id in devices: + phase = "val" if val else "test" + print( + f"Evaluate with {phase} data on client {device_id} with server {self.node_device.device_id}" + ) + + diagnostic_metrics = self.node_device.evaluate_on( + device_id, server_device=self.node_device.device_id, val=val + ) + + metrics = self.finalize_metrics(str(device_id), f"{phase}") + result_metrics.add_results(metrics) + diagnostic_metric_results.merge(diagnostic_metrics) + return result_metrics, diagnostic_metric_results + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def evaluate_batch(self, smashed_data, labels): + """Evaluates the model on the given batch of data and labels""" + with torch.no_grad(): + smashed_data = smashed_data.to(self._device) + self._set_model_flops(smashed_data) + self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) + pred = self._model(smashed_data) + self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int()) + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def train_parallel_split_learning( + self, + clients: List[str], + round_no: int, + adaptive_learning_threshold: Optional[float] = None, + optimizer_state: dict[str, Any] = None, + ): + def client_training_job(client_id: str, batch_index: int): + result = self.node_device.train_batch_on_client_only_on( + device_id=client_id, batch_index=batch_index + ) + return (client_id, result) + + def client_backpropagation_job(client_id: str, gradients: Any): + return self.node_device.backpropagation_on_client_only_on( + client_id=client_id, gradients=gradients + ) + + def client_set_gradient_and_finalize_training_job( + client_id: str, gradients: Any + ): + return ( + self.node_device.set_gradient_and_finalize_training_on_client_only_on( + client_id=client_id, gradients=gradients + ) + ) + + if optimizer_state is not None: + self._optimizer.load_state_dict(optimizer_state) + + num_threads = len(clients) + executor = create_executor_with_threads(num_threads) + + # batches = [] + model_metrics = ModelMetricResultContainer() + diagnostic_metrics = DiagnosticMetricResultContainer() + + # We iterate over each batch, initializing all client training at once and processing the results afterward. + num_batches = self.node_device.client.get_approximated_num_batches() + print(f"\n\n:: BATCHES :: {num_batches}\n\n") + for batch_index in range(num_batches): + client_forward_pass_responses = [] + + futures = [ + executor.submit(client_training_job, client_id, batch_index) + for client_id in clients + ] + for future in concurrent.futures.as_completed(futures): + (client_id, result) = future.result() + if result is not None and result is not False: + client_forward_pass_responses.append((client_id, result)) + + # We want to split up the responses into a list of client IDs and batches again. 
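+            # Each collected response is a (client_id, (smashed_data, labels)) tuple; splitting it
+            # by index keeps the client order aligned with the gradient chunks sent back below.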
+            client_ids = [b[0] for b in client_forward_pass_responses]
+            client_batches = [b[1] for b in client_forward_pass_responses]
+
+            print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
+            server_batch = _concat_smashed_data(
+                [b[0].to(self._device) for b in client_batches]
+            )
+            server_labels = _concat_smashed_data(
+                [b[1].to(self._device) for b in client_batches]
+            )
+
+            # Train the server-side part of the model, then send the gradients back to each client to
+            # continue the backward pass. The server gradients have to be split into per-client chunks first.
+            server_gradients, server_loss, server_metrics = (
+                self.node_device.train_batch(server_batch, server_labels)
+            )  # DiagnosticMetricResultContainer
+
+            # If an adaptive learning threshold is configured, the client backpropagation is skipped
+            # for this batch whenever the server loss is already below the threshold.
+            print(
+                f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
+            )
+            if (
+                adaptive_learning_threshold
+                and server_loss < adaptive_learning_threshold
+            ):
+                print(
+                    f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
+                )
+                self.node_device.log({"adaptive_learning_threshold_applied": True})
+                continue
+
+            num_client_gradients = len(client_forward_pass_responses)
+            print(
+                f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
+            )
+
+            client_gradients = torch.chunk(server_gradients, num_client_gradients)
+
+            futures = [
+                executor.submit(
+                    client_backpropagation_job, client_id, client_gradients[idx]
+                )
+                for (idx, client_id) in enumerate(client_ids)
+            ]
+            client_backpropagation_results = []
+            for future in concurrent.futures.as_completed(futures):
+                client_backpropagation_results.append(future.result())
+
+            client_backpropagation_gradients = [
+                result[1]
+                for result in client_backpropagation_results
+                if result is not None and result is not False
+            ]
+
+            # We average the clients' backpropagation gradients and send them over again to finalize
+            # the current training step.
+            averaged_gradient = _calculate_gradient_mean(
+                client_backpropagation_gradients, self._device
+            )
+            futures = [
+                executor.submit(
+                    client_set_gradient_and_finalize_training_job,
+                    client_id,
+                    averaged_gradient,
+                )
+                for client_id in clients
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
+
+        # Now we have to determine the model metrics for each client.
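+        # finalize_metrics() logs the aggregated results and clears the cached per-batch
+        # predictions before the next client is evaluated.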
+ for client_id in clients: + train_metrics = self.finalize_metrics(str(client_id), "train") + + print(f"::: evaluating on {client_id}") + evaluation_diagnostics_metrics = self.node_device.evaluate_on( + device_id=client_id, + server_device=self.node_device.device_id, + val=True, + ) + # if evaluation_diagnostics_metrics: + # diagnostic_metrics.merge(evaluation_diagnostics_metrics) + val_metrics = self.finalize_metrics(str(client_id), "val") + + model_metrics.add_results(train_metrics) + model_metrics.add_results(val_metrics) + + optimizer_state = self._optimizer.state_dict() + if self._lr_scheduler is not None: + if round_no != -1: + self._lr_scheduler.step(round_no + 1) # epoch=1 + else: + self._lr_scheduler.step() + # delete references and free GPU memory manually + server_batch = None + server_labels = None + server_gradients = None + client_gradients = None + concatenated_client_gradients = None + mean_tensor = None + torch.cuda.empty_cache() + torch.cuda.set_device(self._device) + return ( + self.node_device.client.get_weights(), + self.get_weights(), + model_metrics, + optimizer_state, + diagnostic_metrics, + ) + + +def _calculate_gradient_mean( + gradients: List[Variable], device: str = "cpu" +) -> Variable: + num_devices = len(gradients) + weights = [1 / num_devices] * num_devices + + # We need to move all tensors to the same device to do calculations. + for i, client_gradients in enumerate(gradients): + for j, grad in enumerate(client_gradients): + gradients[i][j] = grad.to(device) + + return [ + sum(gradients[i][j] * weights[i] for i in range(num_devices)) + for j in range(len(gradients[0])) + ] + + +def _concat_smashed_data(data: List[Any]) -> Any: + """Creates a single batch tensor from a list of tensors.""" + return torch.cat(data, dim=0) diff --git a/edml/core/start_device.py b/edml/core/start_device.py new file mode 100644 index 0000000000000000000000000000000000000000..49df053b5073a334d185ce21bc30c2188a545716 --- /dev/null +++ b/edml/core/start_device.py @@ -0,0 +1,140 @@ +import threading +from concurrent import futures + +import grpc +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch import nn + +from edml.core.battery import Battery +from edml.core.client import DeviceClient +from edml.core.device import NetworkDevice, RPCDeviceServicer +from edml.core.server import DeviceServer +from edml.dataset_utils.mnist.mnist import single_batch_dataloaders +from edml.generated import connection_pb2_grpc +from edml.helpers.config_helpers import get_device_address_by_id, get_device_index_by_id +from edml.helpers.data_partitioning import DataPartitioner +from edml.helpers.interceptors import DeviceServerInterceptor +from edml.helpers.load_dataset import get_dataloaders +from edml.helpers.load_model import get_models +from edml.helpers.logging import SimpleLogger, create_logger +from edml.models.provider.base import ModelProvider + + +def _start_device_server( + device_id, + address, + client: "DeviceClient", + server: "DeviceServer", + logger: "SimpleLogger", + battery: "Battery", + devices_info, + max_threads=10, + max_message_length=-1, +): + stop_event = threading.Event() + + device = NetworkDevice( + device_id=device_id, logger=logger, battery=battery, stop_event=stop_event + ) + device.set_client(client) + device.set_server(server) + device.set_devices(devices_info) + + grpc_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_threads), + options=[ + ("grpc.max_send_message_length", max_message_length), + 
("grpc.max_receive_message_length", max_message_length), + ], + interceptors=[ + DeviceServerInterceptor( + logger=logger, stop_event=stop_event, battery=battery + ) + ], + ) + connection_pb2_grpc.add_DeviceServicer_to_server( + RPCDeviceServicer(device=device), grpc_server + ) + grpc_server.add_insecure_port(address) + grpc_server.start() + print(f"started server {device_id}") + stop_event.wait() + grpc_server.stop(grace=None) # grace=None to stop the server immediately + print(f"stopped server {device_id}") + + +def launch_device(cfg): + device_idx = get_device_index_by_id(cfg, cfg.own_device_id) + if cfg.experiment.load_single_batch_for_debugging: + train, val, test = single_batch_dataloaders(cfg.experiment.batch_size) + else: + data_partitioner = None + if cfg.experiment.partition: + fractions = None + if cfg.experiment.fractions: + fractions = cfg.experiment.fractions + distribution = None + if "distribution" in cfg.dataset.keys(): + distribution = cfg.dataset.distribution + data_partitioner = DataPartitioner( + device_index=device_idx, + num_devices=cfg.num_devices, + seed=cfg.seed.value, + fractions=fractions, + distribution=distribution, + ) + + train, val, test = get_dataloaders( + cfg.dataset.name, + cfg.experiment.batch_size, + data_partitioner=data_partitioner, + ) + + client_model, server_model = _get_models(cfg) + latency_factor = 0.0 + if cfg.experiment.latency is not None: + latency_factor = cfg.experiment.latency[device_idx] + client = DeviceClient( + client_model, + cfg, + train_dl=train, + val_dl=val, + test_dl=test, + latency_factor=latency_factor, + ) + loss_fn = instantiate(cfg.loss_fn) + server = DeviceServer( + server_model, + loss_fn, + cfg, + latency_factor=latency_factor, + ) + battery = Battery( + cfg.topology.devices[device_idx].battery_capacity, + cfg.battery.deduction_per_second, + cfg.battery.deduction_per_mflop, + cfg.battery.deduction_per_mbyte_received, + cfg.battery.deduction_per_mbyte_sent, + ) + logger = create_logger(cfg) + _start_device_server( + device_id=cfg.own_device_id, + address=get_device_address_by_id(cfg.own_device_id, cfg), + client=client, + server=server, + logger=logger, + battery=battery, + devices_info=cfg.topology.devices[: cfg.num_devices], + max_threads=cfg.grpc.max_threads, + max_message_length=cfg.grpc.max_message_length, + ) + + +def _get_models(cfg: DictConfig) -> tuple[nn.Module, nn.Module]: + if "model_provider" in cfg: + print(f">>> Using model provider: {cfg.model_provider}") + model_provider: ModelProvider = instantiate(cfg.model_provider) + return model_provider.models + + return get_models(cfg) diff --git a/edml/dataset_utils/__init__.py b/edml/dataset_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/dataset_utils/cifar/cifar.py b/edml/dataset_utils/cifar/cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..49bfdf8b18604f227bf338fcdb78c29c9c74bcb4 --- /dev/null +++ b/edml/dataset_utils/cifar/cifar.py @@ -0,0 +1,95 @@ +import os +from typing import Optional, Tuple + +from torch.utils.data import random_split, DataLoader +from torchvision import transforms, datasets + +from edml.helpers.data_partitioning import DataPartitioner + + +def _load_cifar100(train_transform, test_transform): + train_data = datasets.CIFAR100( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=True, + download=True, + transform=train_transform, + ) + test_data = datasets.CIFAR100( + 
os.path.join(os.path.dirname(__file__), "../../../data"), + train=False, + download=True, + transform=test_transform, + ) + return train_data, test_data + + +def _load_cifar10(train_transform, test_transform): + train_data = datasets.CIFAR10( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=True, + download=True, + transform=train_transform, + ) + test_data = datasets.CIFAR10( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=False, + download=True, + transform=test_transform, + ) + return train_data, test_data + + +def _get_transforms(): + # transformations from https://github.com/akamaster/pytorch_resnet_cifar10 + # However, in this repository the test data is used for validation + # Here, we use the test data for testing only and split the training data into train and validation data (90%/10%) as in the original resnet paper + train_transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32, 4), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + test_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + return train_transform, test_transform + + +def _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100): + """Returns the train, validation and test dataloaders for the CIFAR100 dataset""" + train_transform, test_transform = _get_transforms() + if dataset == 100: + train_data, test_data = _load_cifar100(train_transform, test_transform) + else: + train_data, test_data = _load_cifar10(train_transform, test_transform) + # partition data for device + if data_partitioner is not None: + train_data = data_partitioner.partition(train_data) + test_data = data_partitioner.partition(test_data) + train, val = random_split(train_data, split) + return ( + DataLoader(train, batch_size=batch_size), + DataLoader(val, batch_size=batch_size), + DataLoader(test_data, batch_size=batch_size), + ) + + +def cifar100_dataloaders( + batch_size: int, + split: Tuple[float, float] = (0.9, 0.1), + data_partitioner: Optional[DataPartitioner] = None, +): + return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=100) + + +def cifar10_dataloaders( + batch_size: int, + split: Tuple[float, float] = (0.9, 0.1), + data_partitioner: Optional[DataPartitioner] = None, +): + return _cifar_dataloaders(batch_size, data_partitioner, split, dataset=10) diff --git a/edml/dataset_utils/mnist/mnist.py b/edml/dataset_utils/mnist/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..f4854919bcfd1fa853f96ff914695c200c450b56 --- /dev/null +++ b/edml/dataset_utils/mnist/mnist.py @@ -0,0 +1,63 @@ +import os +from typing import Tuple, Optional + +from torchvision import transforms, datasets +from torch.utils.data import random_split, DataLoader, Subset + +from edml.helpers.data_partitioning import DataPartitioner + + +def _load_transformed_data(): + transform = transforms.Compose( + [ # preprocessing from https://github.com/pytorch/examples/blob/main/mnist/main.py + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ) + train_data = datasets.MNIST( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=True, + download=True, + transform=transform, + ) + test_data = datasets.MNIST( + os.path.join(os.path.dirname(__file__), "../../../data"), + train=False, + download=True, + transform=transform, + ) + return train_data, 
test_data + + +def mnist_dataloaders( + batch_size: int, + split: Tuple[float, float] = (0.8, 0.2), + data_partitioner: Optional[DataPartitioner] = None, +): + """Returns the train, validation and test dataloaders for the MNIST dataset""" + train_data, test_data = _load_transformed_data() + # partition data for device + if data_partitioner is not None: + train_data = data_partitioner.partition(train_data) + test_data = data_partitioner.partition(test_data) + train, val = random_split(train_data, split) + return ( + DataLoader(train, batch_size=batch_size), + DataLoader(val, batch_size=batch_size), + DataLoader(test_data, batch_size=batch_size), + ) + + +def single_batch_dataloaders(batch_size: int): + """Returns the train, validation and test dataloaders of the MNIST dataset with only a single batch each for testing and debugging.""" + train_data, _ = _load_transformed_data() + return ( + DataLoader(Subset(train_data, range(batch_size)), batch_size=batch_size), + DataLoader( + Subset(train_data, range(batch_size, 2 * batch_size)), batch_size=batch_size + ), + DataLoader( + Subset(train_data, range(2 * batch_size, 3 * batch_size)), + batch_size=batch_size, + ), + ) diff --git a/edml/dataset_utils/ptb_xl/mlb.pklmlb.pkl b/edml/dataset_utils/ptb_xl/mlb.pklmlb.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ea9bd43c4a5974d19e8d417b0095f5eb12465848 Binary files /dev/null and b/edml/dataset_utils/ptb_xl/mlb.pklmlb.pkl differ diff --git a/edml/dataset_utils/ptb_xl/ptb_xl.py b/edml/dataset_utils/ptb_xl/ptb_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..9fff4db2fd3f11f14454647cf6e1ec9f003ccb44 --- /dev/null +++ b/edml/dataset_utils/ptb_xl/ptb_xl.py @@ -0,0 +1,102 @@ +import os +import pickle +from typing import Optional + +from torch.utils.data import Dataset, DataLoader + +from edml.dataset_utils.ptb_xl import utils +from edml.helpers.data_partitioning import DataPartitioner +from edml.helpers.types import DatasetDataLoaders + +""" +Taken from https://github.com/a-ayad/Split_ECG_Classification/ +""" + + +def load_dataset(): + sampling_frequency = 100 + cwd = os.path.dirname(os.path.abspath(__file__)) + mlb_path = os.path.join(cwd, "mlb.pkl") + dataset_path = ( + os.path.normpath( + os.path.join( + cwd, + "../../../data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/", + ) + ) + + "/" + ) + + task = "superdiagnostic" + + # Load PTB-XL data + data, raw_labels = utils.load_dataset(dataset_path, sampling_frequency) + # Preprocess label data + labels = utils.compute_label_aggregations(raw_labels, dataset_path, task) + # Select relevant data and convert to one-hot + data, labels, Y, _ = utils.select_data( + data, labels, task, min_samples=0, outputfolder=mlb_path + ) + + # 1-8 for training + X_train = data[labels.strat_fold < 9] + y_train = Y[labels.strat_fold < 9].astype("float64") + # 9 for validation + X_val = data[labels.strat_fold == 9] + y_val = Y[labels.strat_fold == 9].astype("float64") + # 10 for testing + X_test = data[labels.strat_fold == 10] + y_test = Y[labels.strat_fold == 10].astype("float64") + + standard_scaler = pickle.load(open(cwd + "/standard_scaler.pkl", "rb")) + + X_train = utils.apply_standardizer(X_train, standard_scaler) + X_val = utils.apply_standardizer(X_val, standard_scaler) + X_test = utils.apply_standardizer(X_test, standard_scaler) + + return X_train, y_train, X_val, y_val, X_test, y_test + + +class PTB_XL(Dataset): + """ + Class to load sample-sets (train, val, test) + """ + + def __init__(self, 
data, labels, stage=None): + self.stage = stage + self.data = data + self.labels = labels + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return self.data[idx].transpose((1, 0)), self.labels[idx] + + +def ptb_xl_train_val_test( + batch_size: int, data_partitioner: Optional[DataPartitioner] = None +) -> DatasetDataLoaders: + """Returns train, val, test dataloaders for PTB-XL dataset""" + X_train, y_train, X_val, y_val, X_test, y_test = load_dataset() + train_data = PTB_XL(X_train, y_train, "train") + val_data = PTB_XL(X_val, y_val, "train") + test_data = PTB_XL(X_test, y_test, "test") + + if data_partitioner is not None: + if data_partitioner.distribution == "non-iid": + train_data = data_partitioner.partition(train_data) + # stage is lost during non-iid partitioning, set it again + train_data.stage = "train" + else: + train_data = data_partitioner.partition(train_data) + # set distribution to iid for validation and test data + data_partitioner.distribution = "iid" + val_data = data_partitioner.partition(val_data) + test_data = data_partitioner.partition(test_data) + + return ( + DataLoader(train_data, batch_size=batch_size), + DataLoader(val_data, batch_size=batch_size), + DataLoader(test_data, batch_size=batch_size), + ) diff --git a/edml/dataset_utils/ptb_xl/standard_scaler.pkl b/edml/dataset_utils/ptb_xl/standard_scaler.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ed51c42579126e7dd8b64d6d0ffdaf4f8c286dc1 Binary files /dev/null and b/edml/dataset_utils/ptb_xl/standard_scaler.pkl differ diff --git a/edml/dataset_utils/ptb_xl/utils.py b/edml/dataset_utils/ptb_xl/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fed1aa77e2500abb7324ad5bd7e29d89dfcb17 --- /dev/null +++ b/edml/dataset_utils/ptb_xl/utils.py @@ -0,0 +1,302 @@ +import ast +import os +import pickle + +import numpy as np +import pandas as pd +import wfdb +from sklearn.metrics import roc_curve +from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer +from tqdm import tqdm + +""" +Taken from https://github.com/a-ayad/Split_ECG_Classification/ +""" + + +def get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples): + samples = [] + while True: + ridxs = np.random.randint(0, len(y_true), len(y_true)) + if y_true[ridxs].sum(axis=0).min() != 0: + samples.append(ridxs) + if len(samples) == n_bootstraping_samples: + break + return samples + + +def find_optimal_cutoff_threshold(target, predicted): + """ + Find the optimal probability cutoff point for a classification model related to event rate + """ + fpr, tpr, threshold = roc_curve(target, predicted) + optimal_idx = np.argmax(tpr - fpr) + optimal_threshold = threshold[optimal_idx] + return optimal_threshold + + +def find_optimal_cutoff_thresholds(y_true, y_pred): + return [ + find_optimal_cutoff_threshold(y_true[:, i], y_pred[:, i]) + for i in range(y_true.shape[1]) + ] + + +def apply_thresholds(preds, thresholds): + """ + apply class-wise thresholds to prediction score in order to get binary format. + BUT: if no score is above threshold, pick maximum. This is needed due to metric issues. 
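+
+    For example, with thresholds (0.5, 0.5) a prediction of (0.7, 0.2) maps to (1, 0);
+    a prediction of (0.3, 0.1) has no score above its threshold, so only its maximum
+    entry is set, which also gives (1, 0).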
+ """ + tmp = [] + for p in preds: + tmp_p = (p > thresholds).astype(int) + if np.sum(tmp_p) == 0: + tmp_p[np.argmax(p)] = 1 + tmp.append(tmp_p) + tmp = np.array(tmp) + return tmp + + +# DATA PROCESSING STUFF + + +def load_dataset(path: str, sampling_rate, release=False): + Y = pd.read_csv(path + "ptbxl_database.csv", index_col="ecg_id") + Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) + X = load_raw_data_ptbxl(Y, sampling_rate, path) + if path.split("/")[-2] == "ptbxl": + # load and convert annotation data + Y = pd.read_csv(path + "ptbxl_database.csv", index_col="ecg_id") + Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) + + # Load raw signal data + X = load_raw_data_ptbxl(Y, sampling_rate, path) + + elif path.split("/")[-2] == "ICBEB": + # load and convert annotation data + Y = pd.read_csv(path + "icbeb_database.csv", index_col="ecg_id") + Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) + + # Load raw signal data + X = load_raw_data_icbeb(Y, sampling_rate, path) + return X, Y + + +def load_raw_data_icbeb(df, sampling_rate, path): + data = np.load(path + "raw100.npy", allow_pickle=True) + if sampling_rate == 100: + if os.path.exists(path + "raw100.npy"): + data = np.load(path + "raw100.npy", allow_pickle=True) + else: + data = [wfdb.rdsamp(path + "records100/" + str(f)) for f in tqdm(df.index)] + data = np.array([signal for signal, meta in data]) + pickle.dump(data, open(path + "raw100.npy", "wb"), protocol=4) + elif sampling_rate == 500: + if os.path.exists(path + "raw500.npy"): + data = np.load(path + "raw500.npy", allow_pickle=True) + else: + data = [wfdb.rdsamp(path + "records500/" + str(f)) for f in tqdm(df.index)] + data = np.array([signal for signal, meta in data]) + pickle.dump(data, open(path + "raw500.npy", "wb"), protocol=4) + return data + + +def load_raw_data_ptbxl(df, sampling_rate, path): + # data = np.load(path + 'raw100.npy', allow_pickle=True) + if sampling_rate == 100: + if os.path.exists(path + "raw100.npy"): + data = np.load(path + "raw100.npy", allow_pickle=True) + else: + data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_lr)] + data = np.array([signal for signal, meta in data]) + pickle.dump(data, open(path + "raw100.npy", "wb"), protocol=4) + elif sampling_rate == 500: + if os.path.exists(path + "raw500.npy"): + data = np.load(path + "raw500.npy", allow_pickle=True) + else: + data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_hr)] + data = np.array([signal for signal, meta in data]) + pickle.dump(data, open(path + "raw500.npy", "wb"), protocol=4) + return data + + +def compute_label_aggregations(df, folder, ctype): + + df["scp_codes_len"] = df.scp_codes.apply(lambda x: len(x)) + + aggregation_df = pd.read_csv(folder + "scp_statements.csv", index_col=0) + + if ctype in ["diagnostic", "subdiagnostic", "superdiagnostic"]: + + def aggregate_all_diagnostic(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in diag_agg_df.index: + tmp.append(key) + return list(set(tmp)) + + def aggregate_subdiagnostic(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in diag_agg_df.index: + c = diag_agg_df.loc[key].diagnostic_subclass + if str(c) != "nan": + tmp.append(c) + return list(set(tmp)) + + def aggregate_diagnostic(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in diag_agg_df.index: + c = diag_agg_df.loc[key].diagnostic_class + if str(c) != "nan": + tmp.append(c) + return list(set(tmp)) + + diag_agg_df = aggregation_df[aggregation_df.diagnostic == 1.0] + if ctype == "diagnostic": + df["diagnostic"] = 
df.scp_codes.apply(aggregate_all_diagnostic) + df["diagnostic_len"] = df.diagnostic.apply(lambda x: len(x)) + elif ctype == "subdiagnostic": + df["subdiagnostic"] = df.scp_codes.apply(aggregate_subdiagnostic) + df["subdiagnostic_len"] = df.subdiagnostic.apply(lambda x: len(x)) + elif ctype == "superdiagnostic": + df["superdiagnostic"] = df.scp_codes.apply(aggregate_diagnostic) + df["superdiagnostic_len"] = df.superdiagnostic.apply(lambda x: len(x)) + elif ctype == "form": + form_agg_df = aggregation_df[aggregation_df.form == 1.0] + + def aggregate_form(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in form_agg_df.index: + c = key + if str(c) != "nan": + tmp.append(c) + return list(set(tmp)) + + df["form"] = df.scp_codes.apply(aggregate_form) + df["form_len"] = df.form.apply(lambda x: len(x)) + elif ctype == "rhythm": + rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0] + + def aggregate_rhythm(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in rhythm_agg_df.index: + c = key + if str(c) != "nan": + tmp.append(c) + return list(set(tmp)) + + df["rhythm"] = df.scp_codes.apply(aggregate_rhythm) + df["rhythm_len"] = df.rhythm.apply(lambda x: len(x)) + elif ctype == "all": + df["all_scp"] = df.scp_codes.apply(lambda x: list(set(x.keys()))) + + return df + + +def select_data(XX, YY, ctype, min_samples, outputfolder): + # convert multilabel to multi-hot + mlb = MultiLabelBinarizer() + + if ctype == "diagnostic": + X = XX[YY.diagnostic_len > 0] + Y = YY[YY.diagnostic_len > 0] + mlb.fit(Y.diagnostic.values) + y = mlb.transform(Y.diagnostic.values) + elif ctype == "subdiagnostic": + counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts() + counts = counts[counts > min_samples] + YY.subdiagnostic = YY.subdiagnostic.apply( + lambda x: list(set(x).intersection(set(counts.index.values))) + ) + YY["subdiagnostic_len"] = YY.subdiagnostic.apply(lambda x: len(x)) + X = XX[YY.subdiagnostic_len > 0] + Y = YY[YY.subdiagnostic_len > 0] + mlb.fit(Y.subdiagnostic.values) + y = mlb.transform(Y.subdiagnostic.values) + elif ctype == "superdiagnostic": + counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts() + counts = counts[counts > min_samples] + YY.superdiagnostic = YY.superdiagnostic.apply( + lambda x: list(set(x).intersection(set(counts.index.values))) + ) + YY["superdiagnostic_len"] = YY.superdiagnostic.apply(lambda x: len(x)) + X = XX[YY.superdiagnostic_len > 0] + Y = YY[YY.superdiagnostic_len > 0] + mlb.fit(Y.superdiagnostic.values) + y = mlb.transform(Y.superdiagnostic.values) + elif ctype == "form": + # filter + counts = pd.Series(np.concatenate(YY.form.values)).value_counts() + counts = counts[counts > min_samples] + YY.form = YY.form.apply( + lambda x: list(set(x).intersection(set(counts.index.values))) + ) + YY["form_len"] = YY.form.apply(lambda x: len(x)) + # select + X = XX[YY.form_len > 0] + Y = YY[YY.form_len > 0] + mlb.fit(Y.form.values) + y = mlb.transform(Y.form.values) + elif ctype == "rhythm": + # filter + counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts() + counts = counts[counts > min_samples] + YY.rhythm = YY.rhythm.apply( + lambda x: list(set(x).intersection(set(counts.index.values))) + ) + YY["rhythm_len"] = YY.rhythm.apply(lambda x: len(x)) + # select + X = XX[YY.rhythm_len > 0] + Y = YY[YY.rhythm_len > 0] + mlb.fit(Y.rhythm.values) + y = mlb.transform(Y.rhythm.values) + elif ctype == "all": + # filter + counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts() + counts = counts[counts > 
min_samples] + YY.all_scp = YY.all_scp.apply( + lambda x: list(set(x).intersection(set(counts.index.values))) + ) + YY["all_scp_len"] = YY.all_scp.apply(lambda x: len(x)) + # select + X = XX[YY.all_scp_len > 0] + Y = YY[YY.all_scp_len > 0] + mlb.fit(Y.all_scp.values) + y = mlb.transform(Y.all_scp.values) + else: + pass + + # save LabelBinarizer + with open(outputfolder + "mlb.pkl", "wb") as tokenizer: + pickle.dump(mlb, tokenizer) + + return X, Y, y, mlb + + +def preprocess_signals(X_train, X_validation, X_test, outputfolder): + # Standardize data such that mean 0 and variance 1 + ss = StandardScaler() + ss.fit(np.vstack(X_train).flatten()[:, np.newaxis].astype(float)) + + # Save Standardizer data + with open(outputfolder + "standard_scaler.pkl", "wb") as ss_file: + pickle.dump(ss, ss_file) + + return ( + apply_standardizer(X_train, ss), + apply_standardizer(X_validation, ss), + apply_standardizer(X_test, ss), + ) + + +def apply_standardizer(X, ss): + X_tmp = [] + for x in X: + x_shape = x.shape + X_tmp.append(ss.transform(x.flatten()[:, np.newaxis]).reshape(x_shape)) + X_tmp = np.array(X_tmp) + return X_tmp diff --git a/edml/generated/__init__.py b/edml/generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..073748618f2d2ff06d1eb5f7f0c0c0b36775730a --- /dev/null +++ b/edml/generated/__init__.py @@ -0,0 +1,5 @@ +import sys +import os + +# Add the parent directory to the path to make the import statements inside the generated grpc code work +sys.path.append(os.path.join(os.path.dirname(__file__))) diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce271bbab4fa5d30e0e8c5a0a8b309d96eaccf07 --- /dev/null +++ b/edml/generated/connection_pb2.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: connection.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import datastructures_pb2 as datastructures__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 
\x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'connection_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _globals['_SETGRADIENTSREQUEST']._serialized_start=42 + _globals['_SETGRADIENTSREQUEST']._serialized_end=94 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=96 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=149 + _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_start=151 + _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_end=210 + _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_start=212 + _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=318 + _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=320 + _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=369 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=372 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=500 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=503 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=716 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=719 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=984 + _globals['_TRAINGLOBALREQUEST']._serialized_start=987 + _globals['_TRAINGLOBALREQUEST']._serialized_end=1121 + _globals['_TRAINGLOBALRESPONSE']._serialized_start=1124 + _globals['_TRAINGLOBALRESPONSE']._serialized_end=1368 + _globals['_SETWEIGHTSREQUEST']._serialized_start=1370 + _globals['_SETWEIGHTSREQUEST']._serialized_end=1435 + _globals['_SETWEIGHTSRESPONSE']._serialized_start=1437 + _globals['_SETWEIGHTSRESPONSE']._serialized_end=1523 + _globals['_TRAINEPOCHREQUEST']._serialized_start=1525 + _globals['_TRAINEPOCHREQUEST']._serialized_end=1609 + _globals['_TRAINEPOCHRESPONSE']._serialized_start=1611 + 
_globals['_TRAINEPOCHRESPONSE']._serialized_end=1724 + _globals['_TRAINBATCHREQUEST']._serialized_start=1726 + _globals['_TRAINBATCHREQUEST']._serialized_end=1806 + _globals['_TRAINBATCHRESPONSE']._serialized_start=1809 + _globals['_TRAINBATCHRESPONSE']._serialized_end=1954 + _globals['_EVALGLOBALREQUEST']._serialized_start=1956 + _globals['_EVALGLOBALREQUEST']._serialized_end=2014 + _globals['_EVALGLOBALRESPONSE']._serialized_start=2016 + _globals['_EVALGLOBALRESPONSE']._serialized_end=2129 + _globals['_EVALREQUEST']._serialized_start=2131 + _globals['_EVALREQUEST']._serialized_end=2193 + _globals['_EVALRESPONSE']._serialized_start=2195 + _globals['_EVALRESPONSE']._serialized_end=2275 + _globals['_EVALBATCHREQUEST']._serialized_start=2277 + _globals['_EVALBATCHREQUEST']._serialized_end=2356 + _globals['_EVALBATCHRESPONSE']._serialized_start=2358 + _globals['_EVALBATCHRESPONSE']._serialized_end=2470 + _globals['_FULLMODELTRAINREQUEST']._serialized_start=2472 + _globals['_FULLMODELTRAINREQUEST']._serialized_end=2531 + _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2534 + _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2740 + _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2742 + _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2766 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2768 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2859 + _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2861 + _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2883 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2885 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2974 + _globals['_BATTERYSTATUSREQUEST']._serialized_start=2976 + _globals['_BATTERYSTATUSREQUEST']._serialized_end=2998 + _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3000 + _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3121 + _globals['_DATASETMODELINFOREQUEST']._serialized_start=3123 + _globals['_DATASETMODELINFOREQUEST']._serialized_end=3148 + _globals['_DATASETMODELINFORESPONSE']._serialized_start=3151 + _globals['_DATASETMODELINFORESPONSE']._serialized_end=3350 + _globals['_DEVICE']._serialized_start=3353 + _globals['_DEVICE']._serialized_end=4497 +# @@protoc_insertion_point(module_scope) diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a9735505e808ee35b156158e972ab6b206e90b0a --- /dev/null +++ b/edml/generated/connection_pb2.pyi @@ -0,0 +1,258 @@ +import datastructures_pb2 as _datastructures_pb2 +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class SetGradientsRequest(_message.Message): + __slots__ = ["gradients"] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + gradients: _datastructures_pb2.Gradients + def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ... + +class UpdateWeightsRequest(_message.Message): + __slots__ = ["gradients"] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + gradients: _datastructures_pb2.Gradients + def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ... 
+ +class SingleBatchBackwardRequest(_message.Message): + __slots__ = ["gradients"] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + gradients: _datastructures_pb2.Gradients + def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ... + +class SingleBatchBackwardResponse(_message.Message): + __slots__ = ["metrics", "gradients"] + METRICS_FIELD_NUMBER: _ClassVar[int] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + metrics: _datastructures_pb2.Metrics + gradients: _datastructures_pb2.Gradients + def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ... + +class SingleBatchTrainingRequest(_message.Message): + __slots__ = ["batch_index"] + BATCH_INDEX_FIELD_NUMBER: _ClassVar[int] + batch_index: int + def __init__(self, batch_index: _Optional[int] = ...) -> None: ... + +class SingleBatchTrainingResponse(_message.Message): + __slots__ = ["smashed_data", "labels"] + SMASHED_DATA_FIELD_NUMBER: _ClassVar[int] + LABELS_FIELD_NUMBER: _ClassVar[int] + smashed_data: _datastructures_pb2.Activations + labels: _datastructures_pb2.Labels + def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ... + +class TrainGlobalParallelSplitLearningRequest(_message.Message): + __slots__ = ["round_no", "adaptive_learning_threshold", "optimizer_state"] + ROUND_NO_FIELD_NUMBER: _ClassVar[int] + ADAPTIVE_LEARNING_THRESHOLD_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] + round_no: int + adaptive_learning_threshold: float + optimizer_state: _datastructures_pb2.StateDict + def __init__(self, round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... + +class TrainGlobalParallelSplitLearningResponse(_message.Message): + __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"] + CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + METRICS_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + client_weights: _datastructures_pb2.Weights + server_weights: _datastructures_pb2.Weights + metrics: _datastructures_pb2.Metrics + optimizer_state: _datastructures_pb2.StateDict + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class TrainGlobalRequest(_message.Message): + __slots__ = ["epochs", "round_no", "optimizer_state"] + EPOCHS_FIELD_NUMBER: _ClassVar[int] + ROUND_NO_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] + epochs: int + round_no: int + optimizer_state: _datastructures_pb2.StateDict + def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... 
+ +class TrainGlobalResponse(_message.Message): + __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"] + CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + METRICS_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + client_weights: _datastructures_pb2.Weights + server_weights: _datastructures_pb2.Weights + metrics: _datastructures_pb2.Metrics + optimizer_state: _datastructures_pb2.StateDict + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class SetWeightsRequest(_message.Message): + __slots__ = ["weights", "on_client"] + WEIGHTS_FIELD_NUMBER: _ClassVar[int] + ON_CLIENT_FIELD_NUMBER: _ClassVar[int] + weights: _datastructures_pb2.Weights + on_client: bool + def __init__(self, weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., on_client: bool = ...) -> None: ... + +class SetWeightsResponse(_message.Message): + __slots__ = ["diagnostic_metrics"] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class TrainEpochRequest(_message.Message): + __slots__ = ["server", "round_no"] + SERVER_FIELD_NUMBER: _ClassVar[int] + ROUND_NO_FIELD_NUMBER: _ClassVar[int] + server: _datastructures_pb2.DeviceInfo + round_no: int + def __init__(self, server: _Optional[_Union[_datastructures_pb2.DeviceInfo, _Mapping]] = ..., round_no: _Optional[int] = ...) -> None: ... + +class TrainEpochResponse(_message.Message): + __slots__ = ["weights", "diagnostic_metrics"] + WEIGHTS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + weights: _datastructures_pb2.Weights + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class TrainBatchRequest(_message.Message): + __slots__ = ["smashed_data", "labels"] + SMASHED_DATA_FIELD_NUMBER: _ClassVar[int] + LABELS_FIELD_NUMBER: _ClassVar[int] + smashed_data: _datastructures_pb2.Activations + labels: _datastructures_pb2.Labels + def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ... + +class TrainBatchResponse(_message.Message): + __slots__ = ["gradients", "diagnostic_metrics", "loss"] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + LOSS_FIELD_NUMBER: _ClassVar[int] + gradients: _datastructures_pb2.Gradients + diagnostic_metrics: _datastructures_pb2.Metrics + loss: float + def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., loss: _Optional[float] = ...) -> None: ... 
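# --- Illustrative sketch (not part of the generated stubs): unpacking a
# --- TrainBatchResponse on the caller's side. diagnostic_metrics is an
# --- optional message field, so presence is checked with the standard
# --- protobuf HasField() before use; `response` stands for whatever the
# --- TrainBatch RPC returned.
import connection_pb2

def _unpack_train_batch_response(response: "connection_pb2.TrainBatchResponse"):
    grad_bytes = response.gradients.gradients.serialized  # opaque tensor bytes
    diagnostics = (
        response.diagnostic_metrics
        if response.HasField("diagnostic_metrics")
        else None
    )
    return grad_bytes, response.loss, diagnostics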
+ +class EvalGlobalRequest(_message.Message): + __slots__ = ["validation", "federated"] + VALIDATION_FIELD_NUMBER: _ClassVar[int] + FEDERATED_FIELD_NUMBER: _ClassVar[int] + validation: bool + federated: bool + def __init__(self, validation: bool = ..., federated: bool = ...) -> None: ... + +class EvalGlobalResponse(_message.Message): + __slots__ = ["metrics", "diagnostic_metrics"] + METRICS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + metrics: _datastructures_pb2.Metrics + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class EvalRequest(_message.Message): + __slots__ = ["server", "validation"] + SERVER_FIELD_NUMBER: _ClassVar[int] + VALIDATION_FIELD_NUMBER: _ClassVar[int] + server: _datastructures_pb2.DeviceInfo + validation: bool + def __init__(self, server: _Optional[_Union[_datastructures_pb2.DeviceInfo, _Mapping]] = ..., validation: bool = ...) -> None: ... + +class EvalResponse(_message.Message): + __slots__ = ["diagnostic_metrics"] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class EvalBatchRequest(_message.Message): + __slots__ = ["smashed_data", "labels"] + SMASHED_DATA_FIELD_NUMBER: _ClassVar[int] + LABELS_FIELD_NUMBER: _ClassVar[int] + smashed_data: _datastructures_pb2.Activations + labels: _datastructures_pb2.Labels + def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ... + +class EvalBatchResponse(_message.Message): + __slots__ = ["metrics", "diagnostic_metrics"] + METRICS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + metrics: _datastructures_pb2.Metrics + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class FullModelTrainRequest(_message.Message): + __slots__ = ["round_no"] + ROUND_NO_FIELD_NUMBER: _ClassVar[int] + round_no: int + def __init__(self, round_no: _Optional[int] = ...) -> None: ... + +class FullModelTrainResponse(_message.Message): + __slots__ = ["client_weights", "server_weights", "num_samples", "metrics", "diagnostic_metrics"] + CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int] + NUM_SAMPLES_FIELD_NUMBER: _ClassVar[int] + METRICS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + client_weights: _datastructures_pb2.Weights + server_weights: _datastructures_pb2.Weights + num_samples: int + metrics: _datastructures_pb2.Metrics + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., num_samples: _Optional[int] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... 
+ +class StartExperimentRequest(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class StartExperimentResponse(_message.Message): + __slots__ = ["diagnostic_metrics"] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class EndExperimentRequest(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class EndExperimentResponse(_message.Message): + __slots__ = ["diagnostic_metrics"] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class BatteryStatusRequest(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class BatteryStatusResponse(_message.Message): + __slots__ = ["status", "diagnostic_metrics"] + STATUS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + status: _datastructures_pb2.BatteryStatus + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, status: _Optional[_Union[_datastructures_pb2.BatteryStatus, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + +class DatasetModelInfoRequest(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class DatasetModelInfoResponse(_message.Message): + __slots__ = ["train_samples", "validation_samples", "client_model_flops", "server_model_flops", "diagnostic_metrics"] + TRAIN_SAMPLES_FIELD_NUMBER: _ClassVar[int] + VALIDATION_SAMPLES_FIELD_NUMBER: _ClassVar[int] + CLIENT_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int] + SERVER_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int] + DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + train_samples: int + validation_samples: int + client_model_flops: int + server_model_flops: int + diagnostic_metrics: _datastructures_pb2.Metrics + def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_model_flops: _Optional[int] = ..., server_model_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... diff --git a/edml/generated/connection_pb2_grpc.py b/edml/generated/connection_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b692413f8f3dcb6b06ba9e57d7aa31118ba740 --- /dev/null +++ b/edml/generated/connection_pb2_grpc.py @@ -0,0 +1,563 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import connection_pb2 as connection__pb2 +import datastructures_pb2 as datastructures__pb2 + + +class DeviceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.TrainGlobal = channel.unary_unary( + '/Device/TrainGlobal', + request_serializer=connection__pb2.TrainGlobalRequest.SerializeToString, + response_deserializer=connection__pb2.TrainGlobalResponse.FromString, + ) + self.SetWeights = channel.unary_unary( + '/Device/SetWeights', + request_serializer=connection__pb2.SetWeightsRequest.SerializeToString, + response_deserializer=connection__pb2.SetWeightsResponse.FromString, + ) + self.TrainEpoch = channel.unary_unary( + '/Device/TrainEpoch', + request_serializer=connection__pb2.TrainEpochRequest.SerializeToString, + response_deserializer=connection__pb2.TrainEpochResponse.FromString, + ) + self.TrainBatch = channel.unary_unary( + '/Device/TrainBatch', + request_serializer=connection__pb2.TrainBatchRequest.SerializeToString, + response_deserializer=connection__pb2.TrainBatchResponse.FromString, + ) + self.EvaluateGlobal = channel.unary_unary( + '/Device/EvaluateGlobal', + request_serializer=connection__pb2.EvalGlobalRequest.SerializeToString, + response_deserializer=connection__pb2.EvalGlobalResponse.FromString, + ) + self.Evaluate = channel.unary_unary( + '/Device/Evaluate', + request_serializer=connection__pb2.EvalRequest.SerializeToString, + response_deserializer=connection__pb2.EvalResponse.FromString, + ) + self.EvaluateBatch = channel.unary_unary( + '/Device/EvaluateBatch', + request_serializer=connection__pb2.EvalBatchRequest.SerializeToString, + response_deserializer=connection__pb2.EvalBatchResponse.FromString, + ) + self.FullModelTraining = channel.unary_unary( + '/Device/FullModelTraining', + request_serializer=connection__pb2.FullModelTrainRequest.SerializeToString, + response_deserializer=connection__pb2.FullModelTrainResponse.FromString, + ) + self.StartExperiment = channel.unary_unary( + '/Device/StartExperiment', + request_serializer=connection__pb2.StartExperimentRequest.SerializeToString, + response_deserializer=connection__pb2.StartExperimentResponse.FromString, + ) + self.EndExperiment = channel.unary_unary( + '/Device/EndExperiment', + request_serializer=connection__pb2.EndExperimentRequest.SerializeToString, + response_deserializer=connection__pb2.EndExperimentResponse.FromString, + ) + self.GetBatteryStatus = channel.unary_unary( + '/Device/GetBatteryStatus', + request_serializer=connection__pb2.BatteryStatusRequest.SerializeToString, + response_deserializer=connection__pb2.BatteryStatusResponse.FromString, + ) + self.GetDatasetModelInfo = channel.unary_unary( + '/Device/GetDatasetModelInfo', + request_serializer=connection__pb2.DatasetModelInfoRequest.SerializeToString, + response_deserializer=connection__pb2.DatasetModelInfoResponse.FromString, + ) + self.TrainGlobalParallelSplitLearning = channel.unary_unary( + '/Device/TrainGlobalParallelSplitLearning', + request_serializer=connection__pb2.TrainGlobalParallelSplitLearningRequest.SerializeToString, + response_deserializer=connection__pb2.TrainGlobalParallelSplitLearningResponse.FromString, + ) + self.TrainSingleBatchOnClient = channel.unary_unary( + '/Device/TrainSingleBatchOnClient', + request_serializer=connection__pb2.SingleBatchTrainingRequest.SerializeToString, + response_deserializer=connection__pb2.SingleBatchTrainingResponse.FromString, + ) + self.BackwardPropagationSingleBatchOnClient = channel.unary_unary( + '/Device/BackwardPropagationSingleBatchOnClient', + request_serializer=connection__pb2.SingleBatchBackwardRequest.SerializeToString, + response_deserializer=connection__pb2.SingleBatchBackwardResponse.FromString, + ) + 
self.SetGradientsAndFinalizeTrainingStep = channel.unary_unary( + '/Device/SetGradientsAndFinalizeTrainingStep', + request_serializer=connection__pb2.SetGradientsRequest.SerializeToString, + response_deserializer=datastructures__pb2.Empty.FromString, + ) + + +class DeviceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def TrainGlobal(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SetWeights(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TrainEpoch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TrainBatch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def EvaluateGlobal(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Evaluate(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def EvaluateBatch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def FullModelTraining(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StartExperiment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def EndExperiment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetBatteryStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDatasetModelInfo(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def 
TrainGlobalParallelSplitLearning(self, request, context): + """/ Invoked by the controller on the server device to start one round of parallel split learning. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TrainSingleBatchOnClient(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def BackwardPropagationSingleBatchOnClient(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SetGradientsAndFinalizeTrainingStep(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DeviceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'TrainGlobal': grpc.unary_unary_rpc_method_handler( + servicer.TrainGlobal, + request_deserializer=connection__pb2.TrainGlobalRequest.FromString, + response_serializer=connection__pb2.TrainGlobalResponse.SerializeToString, + ), + 'SetWeights': grpc.unary_unary_rpc_method_handler( + servicer.SetWeights, + request_deserializer=connection__pb2.SetWeightsRequest.FromString, + response_serializer=connection__pb2.SetWeightsResponse.SerializeToString, + ), + 'TrainEpoch': grpc.unary_unary_rpc_method_handler( + servicer.TrainEpoch, + request_deserializer=connection__pb2.TrainEpochRequest.FromString, + response_serializer=connection__pb2.TrainEpochResponse.SerializeToString, + ), + 'TrainBatch': grpc.unary_unary_rpc_method_handler( + servicer.TrainBatch, + request_deserializer=connection__pb2.TrainBatchRequest.FromString, + response_serializer=connection__pb2.TrainBatchResponse.SerializeToString, + ), + 'EvaluateGlobal': grpc.unary_unary_rpc_method_handler( + servicer.EvaluateGlobal, + request_deserializer=connection__pb2.EvalGlobalRequest.FromString, + response_serializer=connection__pb2.EvalGlobalResponse.SerializeToString, + ), + 'Evaluate': grpc.unary_unary_rpc_method_handler( + servicer.Evaluate, + request_deserializer=connection__pb2.EvalRequest.FromString, + response_serializer=connection__pb2.EvalResponse.SerializeToString, + ), + 'EvaluateBatch': grpc.unary_unary_rpc_method_handler( + servicer.EvaluateBatch, + request_deserializer=connection__pb2.EvalBatchRequest.FromString, + response_serializer=connection__pb2.EvalBatchResponse.SerializeToString, + ), + 'FullModelTraining': grpc.unary_unary_rpc_method_handler( + servicer.FullModelTraining, + request_deserializer=connection__pb2.FullModelTrainRequest.FromString, + response_serializer=connection__pb2.FullModelTrainResponse.SerializeToString, + ), + 'StartExperiment': grpc.unary_unary_rpc_method_handler( + servicer.StartExperiment, + request_deserializer=connection__pb2.StartExperimentRequest.FromString, + response_serializer=connection__pb2.StartExperimentResponse.SerializeToString, + ), + 'EndExperiment': grpc.unary_unary_rpc_method_handler( + servicer.EndExperiment, + request_deserializer=connection__pb2.EndExperimentRequest.FromString, + 
response_serializer=connection__pb2.EndExperimentResponse.SerializeToString, + ), + 'GetBatteryStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetBatteryStatus, + request_deserializer=connection__pb2.BatteryStatusRequest.FromString, + response_serializer=connection__pb2.BatteryStatusResponse.SerializeToString, + ), + 'GetDatasetModelInfo': grpc.unary_unary_rpc_method_handler( + servicer.GetDatasetModelInfo, + request_deserializer=connection__pb2.DatasetModelInfoRequest.FromString, + response_serializer=connection__pb2.DatasetModelInfoResponse.SerializeToString, + ), + 'TrainGlobalParallelSplitLearning': grpc.unary_unary_rpc_method_handler( + servicer.TrainGlobalParallelSplitLearning, + request_deserializer=connection__pb2.TrainGlobalParallelSplitLearningRequest.FromString, + response_serializer=connection__pb2.TrainGlobalParallelSplitLearningResponse.SerializeToString, + ), + 'TrainSingleBatchOnClient': grpc.unary_unary_rpc_method_handler( + servicer.TrainSingleBatchOnClient, + request_deserializer=connection__pb2.SingleBatchTrainingRequest.FromString, + response_serializer=connection__pb2.SingleBatchTrainingResponse.SerializeToString, + ), + 'BackwardPropagationSingleBatchOnClient': grpc.unary_unary_rpc_method_handler( + servicer.BackwardPropagationSingleBatchOnClient, + request_deserializer=connection__pb2.SingleBatchBackwardRequest.FromString, + response_serializer=connection__pb2.SingleBatchBackwardResponse.SerializeToString, + ), + 'SetGradientsAndFinalizeTrainingStep': grpc.unary_unary_rpc_method_handler( + servicer.SetGradientsAndFinalizeTrainingStep, + request_deserializer=connection__pb2.SetGradientsRequest.FromString, + response_serializer=datastructures__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'Device', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
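# --- Illustrative wiring sketch (not part of the generated module): a minimal
# --- server and client for the Device service, using only the standard grpc
# --- Python API together with DeviceStub / add_DeviceServicer_to_server defined
# --- above. The port and the concrete servicer implementation are placeholders,
# --- not taken from this repository; `grpc` and `connection__pb2` are already
# --- imported at the top of this module.
from concurrent import futures

def _serve_device(servicer, port: int = 50051):
    # Register a concrete DeviceServicer subclass and start listening.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    add_DeviceServicer_to_server(servicer, server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    return server

def _query_battery(address: str):
    # One unary-unary call through the client stub.
    with grpc.insecure_channel(address) as channel:
        stub = DeviceStub(channel)
        return stub.GetBatteryStatus(connection__pb2.BatteryStatusRequest())

# The Device class below is the experimental static-call API referred to by the
# comment above.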
+class Device(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def TrainGlobal(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/TrainGlobal', + connection__pb2.TrainGlobalRequest.SerializeToString, + connection__pb2.TrainGlobalResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SetWeights(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/SetWeights', + connection__pb2.SetWeightsRequest.SerializeToString, + connection__pb2.SetWeightsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TrainEpoch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/TrainEpoch', + connection__pb2.TrainEpochRequest.SerializeToString, + connection__pb2.TrainEpochResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TrainBatch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/TrainBatch', + connection__pb2.TrainBatchRequest.SerializeToString, + connection__pb2.TrainBatchResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def EvaluateGlobal(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/EvaluateGlobal', + connection__pb2.EvalGlobalRequest.SerializeToString, + connection__pb2.EvalGlobalResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Evaluate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/Evaluate', + connection__pb2.EvalRequest.SerializeToString, + connection__pb2.EvalResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def EvaluateBatch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/EvaluateBatch', + connection__pb2.EvalBatchRequest.SerializeToString, + connection__pb2.EvalBatchResponse.FromString, + options, 
channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def FullModelTraining(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/FullModelTraining', + connection__pb2.FullModelTrainRequest.SerializeToString, + connection__pb2.FullModelTrainResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def StartExperiment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/StartExperiment', + connection__pb2.StartExperimentRequest.SerializeToString, + connection__pb2.StartExperimentResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def EndExperiment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/EndExperiment', + connection__pb2.EndExperimentRequest.SerializeToString, + connection__pb2.EndExperimentResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetBatteryStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/GetBatteryStatus', + connection__pb2.BatteryStatusRequest.SerializeToString, + connection__pb2.BatteryStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetDatasetModelInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/GetDatasetModelInfo', + connection__pb2.DatasetModelInfoRequest.SerializeToString, + connection__pb2.DatasetModelInfoResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TrainGlobalParallelSplitLearning(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/TrainGlobalParallelSplitLearning', + connection__pb2.TrainGlobalParallelSplitLearningRequest.SerializeToString, + connection__pb2.TrainGlobalParallelSplitLearningResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TrainSingleBatchOnClient(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + 
metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/TrainSingleBatchOnClient', + connection__pb2.SingleBatchTrainingRequest.SerializeToString, + connection__pb2.SingleBatchTrainingResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def BackwardPropagationSingleBatchOnClient(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/BackwardPropagationSingleBatchOnClient', + connection__pb2.SingleBatchBackwardRequest.SerializeToString, + connection__pb2.SingleBatchBackwardResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SetGradientsAndFinalizeTrainingStep(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Device/SetGradientsAndFinalizeTrainingStep', + connection__pb2.SetGradientsRequest.SerializeToString, + datastructures__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/edml/generated/datastructures_pb2.py b/edml/generated/datastructures_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..c89e114e6d761302594d9e5cee049be33f783ad2 --- /dev/null +++ b/edml/generated/datastructures_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: datastructures.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x64\x61tastructures.proto\"\x1c\n\x06Tensor\x12\x12\n\nserialized\x18\x01 \x01(\x0c\"\x1f\n\tStateDict\x12\x12\n\nserialized\x18\x01 \x01(\x0c\"&\n\x07Weights\x12\x1b\n\x07weights\x18\x01 \x01(\x0b\x32\n.StateDict\"!\n\x06Labels\x12\x17\n\x06labels\x18\x01 \x01(\x0b\x32\x07.Tensor\"+\n\x0b\x41\x63tivations\x12\x1c\n\x0b\x61\x63tivations\x18\x01 \x01(\x0b\x32\x07.Tensor\"\'\n\tGradients\x12\x1a\n\tgradients\x18\x01 \x01(\x0b\x32\x07.Tensor\"\x1a\n\x07Metrics\x12\x0f\n\x07metrics\x18\x01 \x01(\x0c\"+\n\x0bPredictions\x12\x1c\n\x0bpredictions\x18\x01 \x01(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"0\n\nDeviceInfo\x12\x11\n\tdevice_id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\"M\n\rBatteryStatus\x12\x1d\n\x15initial_battery_level\x18\x01 \x01(\x01\x12\x1d\n\x15\x63urrent_battery_level\x18\x02 \x01(\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'datastructures_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _globals['_TENSOR']._serialized_start=24 + _globals['_TENSOR']._serialized_end=52 + _globals['_STATEDICT']._serialized_start=54 + _globals['_STATEDICT']._serialized_end=85 + _globals['_WEIGHTS']._serialized_start=87 + _globals['_WEIGHTS']._serialized_end=125 + _globals['_LABELS']._serialized_start=127 + _globals['_LABELS']._serialized_end=160 + _globals['_ACTIVATIONS']._serialized_start=162 + _globals['_ACTIVATIONS']._serialized_end=205 + _globals['_GRADIENTS']._serialized_start=207 + _globals['_GRADIENTS']._serialized_end=246 + _globals['_METRICS']._serialized_start=248 + _globals['_METRICS']._serialized_end=274 + _globals['_PREDICTIONS']._serialized_start=276 + _globals['_PREDICTIONS']._serialized_end=319 + _globals['_EMPTY']._serialized_start=321 + _globals['_EMPTY']._serialized_end=328 + _globals['_DEVICEINFO']._serialized_start=330 + _globals['_DEVICEINFO']._serialized_end=378 + _globals['_BATTERYSTATUS']._serialized_start=380 + _globals['_BATTERYSTATUS']._serialized_end=457 +# @@protoc_insertion_point(module_scope) diff --git a/edml/generated/datastructures_pb2.pyi b/edml/generated/datastructures_pb2.pyi new file mode 100644 index 0000000000000000000000000000000000000000..32831a94c51565f9eea3160757d6bdf6fd15562d --- /dev/null +++ b/edml/generated/datastructures_pb2.pyi @@ -0,0 +1,73 @@ +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Tensor(_message.Message): + __slots__ = ["serialized"] + SERIALIZED_FIELD_NUMBER: _ClassVar[int] + serialized: bytes + def __init__(self, serialized: _Optional[bytes] = ...) -> None: ... + +class StateDict(_message.Message): + __slots__ = ["serialized"] + SERIALIZED_FIELD_NUMBER: _ClassVar[int] + serialized: bytes + def __init__(self, serialized: _Optional[bytes] = ...) -> None: ... 
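# --- Illustrative sketch (not part of the generated stubs): Tensor and StateDict
# --- only carry opaque bytes, so the (de)serialization scheme is chosen by the
# --- caller. torch.save / torch.load is assumed here for illustration; the
# --- actual helper used by edml is not part of this file.
import io
import torch
import datastructures_pb2

def _state_dict_to_weights(state_dict: dict) -> "datastructures_pb2.Weights":
    buf = io.BytesIO()
    torch.save(state_dict, buf)  # assumed byte format
    return datastructures_pb2.Weights(
        weights=datastructures_pb2.StateDict(serialized=buf.getvalue())
    )

def _weights_to_state_dict(weights: "datastructures_pb2.Weights") -> dict:
    return torch.load(io.BytesIO(weights.weights.serialized))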
+ +class Weights(_message.Message): + __slots__ = ["weights"] + WEIGHTS_FIELD_NUMBER: _ClassVar[int] + weights: StateDict + def __init__(self, weights: _Optional[_Union[StateDict, _Mapping]] = ...) -> None: ... + +class Labels(_message.Message): + __slots__ = ["labels"] + LABELS_FIELD_NUMBER: _ClassVar[int] + labels: Tensor + def __init__(self, labels: _Optional[_Union[Tensor, _Mapping]] = ...) -> None: ... + +class Activations(_message.Message): + __slots__ = ["activations"] + ACTIVATIONS_FIELD_NUMBER: _ClassVar[int] + activations: Tensor + def __init__(self, activations: _Optional[_Union[Tensor, _Mapping]] = ...) -> None: ... + +class Gradients(_message.Message): + __slots__ = ["gradients"] + GRADIENTS_FIELD_NUMBER: _ClassVar[int] + gradients: Tensor + def __init__(self, gradients: _Optional[_Union[Tensor, _Mapping]] = ...) -> None: ... + +class Metrics(_message.Message): + __slots__ = ["metrics"] + METRICS_FIELD_NUMBER: _ClassVar[int] + metrics: bytes + def __init__(self, metrics: _Optional[bytes] = ...) -> None: ... + +class Predictions(_message.Message): + __slots__ = ["predictions"] + PREDICTIONS_FIELD_NUMBER: _ClassVar[int] + predictions: Tensor + def __init__(self, predictions: _Optional[_Union[Tensor, _Mapping]] = ...) -> None: ... + +class Empty(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class DeviceInfo(_message.Message): + __slots__ = ["device_id", "address"] + DEVICE_ID_FIELD_NUMBER: _ClassVar[int] + ADDRESS_FIELD_NUMBER: _ClassVar[int] + device_id: str + address: str + def __init__(self, device_id: _Optional[str] = ..., address: _Optional[str] = ...) -> None: ... + +class BatteryStatus(_message.Message): + __slots__ = ["initial_battery_level", "current_battery_level"] + INITIAL_BATTERY_LEVEL_FIELD_NUMBER: _ClassVar[int] + CURRENT_BATTERY_LEVEL_FIELD_NUMBER: _ClassVar[int] + initial_battery_level: float + current_battery_level: float + def __init__(self, initial_battery_level: _Optional[float] = ..., current_battery_level: _Optional[float] = ...) -> None: ... diff --git a/edml/generated/datastructures_pb2_grpc.py b/edml/generated/datastructures_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9393943bdf46e1eb66d2d033485fdabac096ca --- /dev/null +++ b/edml/generated/datastructures_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc diff --git a/edml/helpers/__init__.py b/edml/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/helpers/config_helpers.py b/edml/helpers/config_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4ec87e06ee3542cb28ce3e857880795164188d --- /dev/null +++ b/edml/helpers/config_helpers.py @@ -0,0 +1,245 @@ +from copy import deepcopy +from inspect import signature +from typing import TYPE_CHECKING + +import torch +from hydra.utils import get_class, instantiate +from omegaconf import OmegaConf, DictConfig, ListConfig +from omegaconf.errors import ConfigAttributeError + +if TYPE_CHECKING: + from edml.controllers.base_controller import BaseController + + +def get_device_address_by_id(device_id: str, cfg: DictConfig) -> str: + """ + Returns the binding address of the device with the given id. + + Args: + device_id (str): The device id. + cfg (DictConfig): The config loaded from YAML files. + + Returns: + The device's binding address. 
+ + Raises: + StopIteration: If the device with the given ID cannot be found. + """ + return next( + device.address + for device in cfg.topology.devices + if device.device_id == device_id + ) + + +def get_device_id_by_index(cfg: DictConfig, index: int) -> str: + """ + Returns the device info of the device with the given index in the network topology. + + Args: + cfg (DictConfig): The config loaded from YAML files. + index (int): The index of the device inside the configuration file. + + Returns: + The device's ID. + + Raises: + StopIteration: If the device with the given ID cannot be found. + """ + return cfg.topology.devices[index].device_id + + +def get_device_index_by_id(cfg: DictConfig, device_id: str) -> int: + """ + Returns the index of the device with the given id in the network topology. + + Args: + cfg (DictConfig): The config loaded from YAML files. + device_id (str): The device's ID. + + Returns: + The index of the device inside the configuration file. + + Raises: + StopIteration: If the device with the given ID cannot be found. + """ + return next( + i + for i, device in enumerate(cfg.topology.devices) + if device.device_id == device_id + ) + + +def _group_resolver(cfg: DictConfig, group_by: DictConfig): + """ + Resolver for the group_by attribute in the config. This attribute specifies which values to include in the group name. + Therefore, the values of group_by are parsed to retrieve the key paths to the desired values. + E.g. grouping by controller and scheduler name: + group_by: + - controller: [ name, scheduler: name] + yields the paths: ["controller", "name"] and ["controller", "scheduler", "name"] which are read from the cfg then. + + Args: + cfg (DictConfig): The full dict config. + group_by (DictConfig): The part of the config specifying which attributes to use for experiment grouping. + + Returns: + The name of the config group with underscores in between each value. + """ + + def __recurse__(group_by, attr_path: list): + """Retrieves the key paths to the desired attributes.""" + attr_paths = [] + if isinstance(group_by, DictConfig): + for k, v in group_by.items(): + if isinstance(v, str): + attr_paths.append(attr_path + [k] + [v]) + else: + attr_paths.extend(__recurse__(group_by[k], attr_path + [k])) + elif isinstance(group_by, ListConfig): + for idx, item in enumerate(group_by): + if isinstance(item, str): + attr_paths.append(attr_path + [item]) + else: + attr_paths.extend(__recurse__(group_by[idx], attr_path)) + + return attr_paths + + attr_paths = __recurse__(group_by, []) + # resolve each attribute + values = [] + for path in attr_paths: + value = cfg + for key in path: + if isinstance( + value, str + ): # if previous key was not found, value is the empty string + break + value = value.get(key, "") + values.append(value) + # concatenate and return + return "_".join(values) + + +def preprocess_config(cfg: DictConfig): + """ + Configures `OmegaConf` and registers custom resolvers. Additionally, normalizes the configuration file for command + line usage: + + - If `own_device_id` is an integer, the value is treated as an index into the list of available devices; it is + treated as the i-th device inside the configured topology. This functions then looks up the device_id by index + and sets `own_device_id`. + - resolves the group_name attribute specifying the composition of the experiment group name. 
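    Example:
        With ``group_by`` set to ``controller: [ name, scheduler: name ]`` (as in the
        resolver's docstring above), the registered ``group_name`` resolver reads
        ``cfg.controller.name`` and ``cfg.controller.scheduler.name`` and returns them
        joined by underscores; the concrete values depend on the project's YAML files
        and are not shown here.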
+ """ + OmegaConf.register_new_resolver("len", lambda x: len(x), replace=True) + OmegaConf.register_new_resolver( + "group_name", lambda group_by: _group_resolver(cfg, group_by), replace=True + ) + OmegaConf.resolve(cfg) + + # In case someone specified an integer instead of a proper device_id (str), we look up the proper device by indexing + # the list of all available devices using said integer. + if isinstance(cfg.own_device_id, int): + cfg.own_device_id = get_device_id_by_index(cfg, cfg.own_device_id) + + +def __drop_irrelevant_keys__(cfg: DictConfig) -> DictConfig: + """ + Removes keys from config not needed to instantiate the specified _target_ class. + Assumes that cfg has a key _target_. Hydra keys _recursive_ and _partial_ are not removed. + + Args: + cfg: The controller configuration. + + Returns: + A DictConfig without unnecessary keys. + """ + controller_class = get_class(cfg._target_) + controller_signature = signature(controller_class.__init__) + controller_args = controller_signature.parameters.keys() + + # These are special hydra keywords that we do not want to filter out. + special_keys = ["_target_", "_recursive_", "_partial_"] + cfg = {k: v for k, v in cfg.items() if k in controller_args or k in special_keys} + return cfg + + +def drop_irrelevant_keys_recursively(cfg: DictConfig) -> DictConfig: + """ + Removes parameters that are not necessary to instantiate the specified classes. + This is done for the controller class as well as for scheduler and adaptive threshold if present. + This is needed because hydra's instantiation mechanism expects that all given parameters are actually needed. + + Args: + cfg: The controller configuration. + + Returns: + A DictConfig that contains only the parameters actually needed to instantiate the specified classes. + """ + cfg.controller = __drop_irrelevant_keys__(cfg.controller) + if cfg.controller.get("scheduler", False): + cfg.controller.scheduler = __drop_irrelevant_keys__(cfg.controller.scheduler) + if cfg.controller.get("adaptive_threshold_fn", False): + cfg.controller.adaptive_threshold_fn = __drop_irrelevant_keys__( + cfg.controller.adaptive_threshold_fn + ) + return cfg + + +def instantiate_controller(cfg: DictConfig): # -> BaseController: + """ + Instantiates a controller based on the configuration. This method filters out extra parameters defined through hydra + but not required by the controller's init method. This allows for hydra's multirun feature to work even if + controllers have different parameters (like next server schedulers). + + Args: + cfg: The controller configuration. + + Returns: + An instance of `BaseController`. + """ + original_cfg = deepcopy(cfg) + # Filter out any arguments not present in the controller constructor. This is a hack required to make multirun work. + # We want to be able to use different scheduling strategies combined with different controllers. But hydra's + # `instantiate` method is strict and fails if it receives any extra arguments. + cfg = drop_irrelevant_keys_recursively(cfg) + + # Update the device ID and set it to controller. + cfg.own_device_id = "controller" + + # Instantiate the controller. + controller: BaseController = instantiate(cfg.controller)(cfg=original_cfg) + return controller + + +def get_torch_device_id(cfg: DictConfig) -> str: + """ + Returns the configured torch_device for the current device. + Resorts to default if no torch_device is configured. + + Args: + cfg (DictConfig): The config loaded from YAML files. 
+ + Returns: + The id of the configured torch_device for the current device. + + Raises: + StopIteration: If the device with the given ID cannot be found. + ConfigAttributeError: If no device id is present in the config. + """ + own_device_id = cfg.own_device_id + try: + return next( + device_cfg.torch_device + for device_cfg in cfg.topology.devices + if device_cfg.device_id == own_device_id + ) + except ConfigAttributeError: + return _default_torch_device() + + +def _default_torch_device(): + """ + Returns the default torch devices, depending on whether cuda is available. + """ + return "cuda:0" if torch.cuda.is_available() else "cpu" diff --git a/edml/helpers/data_partitioning.py b/edml/helpers/data_partitioning.py new file mode 100644 index 0000000000000000000000000000000000000000..16eaff2af615608369d71d9f463a5fd1227c7aee --- /dev/null +++ b/edml/helpers/data_partitioning.py @@ -0,0 +1,159 @@ +import itertools +import math +from collections import defaultdict +from typing import Optional, List, Tuple, Any + +import torch +from torch.utils.data import random_split, Subset +from torch.utils.data.dataset import Dataset + + +def __get_partitioned_data_for_device__( + dataset: Dataset, + device_index: int, + num_devices: int, + fractions: Optional[List[float]] = None, + seed: int = 42, + distribution: Optional[str] = None, +) -> Subset: + """ + Returns the data partition for the given device index. + + Args: + dataset(Dataset): The dataset to partition. + device_index(int): The index of the device to get the partition. + num_devices(int): The total number of devices. + fractions (List[float], optional): The fractions of the dataset to assign to each device. + seed (int, optional): The seed for the random number generator. + distribution (str, optional): The distribution of the data. Set to "non-iid" for non-iid, otherwise defaults to + iid. + + Returns: + The partition of the dataset for the given device index. + + Raises: + ValueError: If the sum of the fractions is greater than 1. + + Notes: + If fractions is not given or too short, it will be set to 1/num_devices for each device. + If the sum of the fractions is less than 1, a dummy subset is added to sum up to 1. + If the distribution is non-iid, the data is grouped by label and split in half. The first half is shuffled + while the second half remains ordered. Then the halfs are merged and the dataset is split according to the + fractions. Not supported by every dataset since it has to be reinitialized, so test beforehand with new + datasets. + """ + if fractions is None or len(fractions) < num_devices: + quotient, remainder = divmod(len(dataset), num_devices) + fractions = [quotient] * num_devices + if remainder > 0: + for i in range(remainder): + fractions[i] += 1 + elif sum(fractions) != 1: + if sum(fractions) > 1: + if not math.isclose(sum(fractions), 1): + raise ValueError("Sum of lengths must be <= 1.") + else: + # to make sampling subsets possible, add a dummy subset to sum lengths up to 1 + # use round to avoid floating point errors leading to remainders that are added to the other subsets + fractions += [1 - round(sum(fractions), 4)] + if distribution == "non-iid": + split_position = 0.6 # split at 60% of the data, i.e. 
shuffle the first 60% and keep the rest ordered + # group by label + partitions_by_label = defaultdict(list) + dataset_class = type(dataset) + for sample, label in dataset: + partitions_by_label[str(label)].append((sample, label)) + # create lists sorted by label + data_list = [] + label_list = [] + for key, partition in partitions_by_label.items(): + for sample, label in partition: + data_list.append(sample) + label_list.append(label) + # split both lists at the split position + split_idx = int(len(data_list) // (1 / split_position)) + data_list1, data_list2 = data_list[:split_idx], data_list[split_idx:] + label_list1, label_list2 = label_list[:split_idx], label_list[split_idx:] + # shuffle first half + generator = torch.Generator().manual_seed(seed) + indices = torch.randperm(len(data_list1), generator=generator) + data_list1 = [data_list1[i] for i in indices] + label_list1 = [label_list1[i] for i in indices] + # merge lists again + data_list = data_list1 + data_list2 + label_list = label_list1 + label_list2 + reordered_dataset = SimpleDataset(data_list, label_list) + # split in order according to the fractions + splits = [] + for i, f in enumerate(fractions): + start = int(sum(fractions[:i]) * len(data_list)) + end = int(start + len(data_list) * f) + splits.append(Subset(reordered_dataset, range(start, end))) + # transform to dataset + return splits[device_index] + else: + generator = torch.Generator().manual_seed(seed) + splits = random_split(dataset, fractions, generator=generator) + return splits[device_index] + + +class DataPartitioner: + """ + Deterministically splits datasets into chunks of specified size based on seed and device index. + + Instead of sending training and test data to each device over the network, this helper class is used to split the + dataset in non-overlapping chunks. Thus, an instance of this class is instantiated by each device. + + Given the same parameters, the class will always split the data in the same way. + + Attributes: + device_index (int): The device the data partitioner will split data for. + num_devices (int): The total number of devices that data needs to be split for. + fractions (List[float], optional): List of fractions that represent the amount of data to split for a device. + The fraction at position `i` is used to partition the data for a device at index `i`. Defaults to `None`, + meaning that all data is partitioned into same-sized chunks. + seed (int): The seed of the PRNG. Defaults to 42. + distribution (str, optional): The distribution to use for partitioning the data. Defaults to `None`, meaning + that the data is split randomly. Possible values are `iid` and `non-iid`. + """ + + def __init__( + self, + device_index: int, + num_devices: int, + fractions: Optional[List[float]] = None, + seed: int = 42, + distribution: Optional[str] = None, + ): + self.device_index = device_index + self.num_devices = num_devices + self.fractions = fractions + self.seed = seed + self.distribution = distribution + + def partition(self, data: Dataset): + return __get_partitioned_data_for_device__( + data, + self.device_index, + self.num_devices, + self.fractions, + self.seed, + self.distribution, + ) + + +class SimpleDataset(Dataset): + """ + Dataset for wrapping data and labels after non-iid partitioning. Possible transformations are already applied at the + partitioning, so no further transformations are needed. 
+ """ + + def __init__(self, data, labels): + self.data = data + self.labels = labels + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + return self.data[idx], self.labels[idx] diff --git a/edml/helpers/decorators.py b/edml/helpers/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..5623648311abd9e63c51281af657649ea950459e --- /dev/null +++ b/edml/helpers/decorators.py @@ -0,0 +1,251 @@ +import time +import types +from functools import wraps + +from edml.helpers.metrics import DiagnosticMetricResult, DiagnosticMetricResultContainer + + +def check_device_set(): + """ + The decorator checks if the `node_device` attribute has already been initialized. + + The decorator is used on functions inside the py:class:`edml.core.device.Device` class. + + Raises: + ValueError: If `node_device` is `None`. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.node_device is None: + raise ValueError("Device not set") + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +def log_execution_time(logger: str, log_key: str = "execution_time"): + """ + A decorator factory for methods that returns a decorator that measures the wrapped function's execution time. + + The execution time is saved in a dictionary with key `log_key` and a dictionary as value. The value dictionary holds + the start time, end time and duration of the function execution. + + Args: + logger (str): The name of the attribute that provides the `log` method used to log the execution time. + log_key (str): The key under which the execution time data is stored. Defaults to `execution_time`. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + start = time.time() + result = func(self, *args, **kwargs) + end = time.time() + if getattr(self, logger) is not None: + getattr(self, logger).log( + { + f"{log_key}": { + "start": start, + "end": end, + "duration": end - start, + } + } + ) + else: + print(f"{log_key}: {end - start}") + return result + + return wrapper + + return decorator + + +def battery_updater(cls): + """ + A decorator for classes that should update the battery of the device. Attaches the + py:meth:`edml.helpers.decorators.update_battery` decorator to all methods of the class. + + Args: + cls (class): Class to attach the py:meth:`edml.helpers.decorators.update_battery` decorators to. + """ + for key in dir(cls): + value = getattr(cls, key) + if callable(value) and isinstance(value, types.FunctionType): + setattr(cls, key, update_battery(value)) + return cls + + +def update_battery(func, battery_attr: str = "battery"): + """ + Decorator for methods that should update the battery of the device. + + Args: + func (function): The function to decorate. + battery_attr (str): The name of the battery attribute. Defaults to `battery`. + + Notes: + Requires an attribute of type py:class:`edml.core.battery.Battery` to be set on the decorated object. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + if hasattr(self, battery_attr): + battery = getattr(self, battery_attr) + battery.update_time() + result = func(self, *args, **kwargs) + if hasattr(self, battery_attr): + battery = getattr(self, battery_attr) + battery.update_time() + return result + + return wrapper + + +def simulate_latency_decorator(latency_factor_attr): + """ + Simulates latency by sleeping for the given number of seconds. 
+ + Args: + latency_factor_attr (str): the class attribute to determine the computational latency + + Returns: + A decorator that sleeps for the time of its wrapped method execution times the latency_factor. + + Raises: + None + + Notes: + Hence the resulting execution time is [MethodExecutionTime] * (1 + latency_factor). + Use only for methods that do not call other functions using this decorator. Otherwise, the latency is increases by more than the given factor. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + latency = getattr(self, latency_factor_attr) + with LatencySimulator(latency_factor=latency): + res = func(self, *args, **kwargs) + return res + + return wrapper + + return decorator + + +class LatencySimulator: + def __init__(self, latency_factor: float = 0.0): + """ + Simulates latency by sleeping for the given number of seconds. + + Args: + latency_factor (float): the class attribute to determine the computational latency + + Returns: + A decorator that sleeps for the time of its wrapped method execution times the latency_factor. + + Notes: + Hence the resulting execution time is [MethodExecutionTime] * (1 + latency_factor) + Use only for methods that do not call other functions using this decorator. Otherwise, the latency is increases by more than the given factor. + """ + self.latency_factor = latency_factor + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.latency_factor is None or self.latency_factor <= 0: + return + self.end_time = time.time() + self.execution_time = self.end_time - self.start_time + time.sleep(self.execution_time * self.latency_factor) + + +def add_time_to_diagnostic_metrics(method_name: str): + """ + A decorator factory that measures the execution time of the wrapped method. It then creates a diagnostic metric + result container to store the execution time. + + The decorator is smart enough to discover if functions already return an instance of type + `DiagnosticMetricResultContainer`. If that is the case, the current metric result originating from `func` and the + existing ones are merged together. + + More specifically, the decorator does the following: + + - If `func` has a return value of `None`, then the wrapped function returns an instance of + `DiagnosticMetricResultContainer`. + - if `func` returns an instance of `DiagnosticMetricResultContainer`, then the results are merged together and + the wrapped function returns an instance of `DiagnosticMetricResultContainer`. + - If `func` returns a tuple `t`, we analyze its value types: + - If its last value is an instance of `DiagnosticMetricResultContainer`, we merge them together and return + the same tuple, but its last value is changed to the new `DiagnosticMetricResultContainer` instance. + - Else we return a new tuple of length `len(t) + 1`, where its last value will be the current + `DiagnosticMetricResultContainer` instance. + - Else the return value of `func` is a single value. We create a tuple that holds the original return value at + position 0 and the new `DiagnosticMetricResultContainer` at position 1. + + Args: + method_name (str): The name of the method to add the computation time to. Will be included in the metric. + + Returns: + A decorator that adds the computation time to the diagnostic metrics of the wrapped method. + + Notes: + The returned decorator can only be used for class functions. It expects that `func` is a class method with + `self` parameter. 
It also expects that an attribute named `device_id` exists. This is required because the + diagnostic metrics are bound to specific devices. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + start_time = time.time() + res = func(self, *args, **kwargs) + end_time = time.time() + + diagnostic_metric_result = DiagnosticMetricResult( + device_id=self.device_id, + name="comp_time", + value=end_time - start_time, + method=method_name, + ) + + # The wrapped function does not have a return value. In that case we return the metrics. + if res is None: + return DiagnosticMetricResultContainer([diagnostic_metric_result]) + + # The wrapped function returns a metrics container. In that case we merge the current metrics instance with + # the existing container and return the result. + if isinstance(res, DiagnosticMetricResultContainer): + res.add_result(diagnostic_metric_result) + return res + + # The wrapped function returns a tuple. In that case we check if the last tuple value is a metrics + # container and add the current metrics instance to it if that is the case. If not, we simply create a new + # container and append it to the tuple. + if isinstance(res, tuple): + potential_metrics_container = res[-1] + if isinstance( + potential_metrics_container, DiagnosticMetricResultContainer + ): + potential_metrics_container.add_result(diagnostic_metric_result) + else: + res = ( + *res, + DiagnosticMetricResultContainer([diagnostic_metric_result]), + ) + return res + + # The wrapped function returns some kind of value. But none of the special cases above. We create a new + # tuple that holds the wrapped function's return value at position 0 and the metrics container at position + # 1. + + return res, DiagnosticMetricResultContainer([diagnostic_metric_result]) + + return wrapper + + return decorator diff --git a/edml/helpers/executor.py b/edml/helpers/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a989abf444ec02091db1fd8aa44d31bc3577ea --- /dev/null +++ b/edml/helpers/executor.py @@ -0,0 +1,5 @@ +import concurrent.futures + + +def create_executor_with_threads(threads: int, min_threads: int = 1): + return concurrent.futures.ThreadPoolExecutor(max_workers=max(threads, min_threads)) diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py new file mode 100644 index 0000000000000000000000000000000000000000..4388c9754e4c42e33a958c309f8a73a2c21f4265 --- /dev/null +++ b/edml/helpers/flops.py @@ -0,0 +1,21 @@ +from typing import Union, Tuple + +from fvcore.nn import FlopCountAnalysis +from torch import Tensor, nn + + +def estimate_model_flops( + model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]] +) -> int: + """ + Estimates the FLOPs of one forward pass of the model using the sample data provided. + + Args: + model (nn.Module): the neural network model to calculate the FLOPs for. + sample: The data used to calculate the FLOPs. + + Returns: + int: the number of FLOPs. 
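+
+    Example (illustrative sketch; assumes the MNIST client model shipped in
+    ``edml.models.mnist_models`` and a single 1x28x28 input sample):
+
+        >>> import torch
+        >>> from edml.models.mnist_models import ClientNet
+        >>> flops = estimate_model_flops(ClientNet(), torch.randn(1, 1, 28, 28))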
+ + """ + return FlopCountAnalysis(model, sample).total() diff --git a/edml/helpers/interceptors.py b/edml/helpers/interceptors.py new file mode 100644 index 0000000000000000000000000000000000000000..f954c19783e93d151b385d58e637eda0d5b98abd --- /dev/null +++ b/edml/helpers/interceptors.py @@ -0,0 +1,113 @@ +import threading +from typing import Callable, Any + +import grpc +from grpc_interceptor import ServerInterceptor, ClientInterceptor + +from edml.core.battery import BatteryEmptyException, Battery +from edml.helpers.logging import SimpleLogger + + +class DeviceServerInterceptor(ServerInterceptor): + + def __init__( + self, + logger: SimpleLogger, + battery: Battery = None, + stop_event: threading.Event = None, + ): + """ + Intercepts all calls made to its gRPC server. Handles the battery and logs the request and response sizes. + + Args: + logger (SimpleLogger): The logger to use. + battery (Battery): The battery to monitor. + stop_event (threading.Event): The stop event to set if the battery is empty to stop the gRPC server. + Returns: + None + Raises: + grpc.StatusCode.RESOURCE_EXHAUSTED when the battery is empty. + Notes: + """ + self.logger = logger + self.battery = battery + self.stop_event = stop_event + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + request_size = _proto_serialized_size(request_or_iterator) + self.logger.log({f"{method_name}_request_size": request_size}) + # check battery before handling request + if self.battery is not None: + try: + self.battery.update_communication_received(request_size) + except BatteryEmptyException as e: + self.logger.log("Battery empty while receiving request") + self.stop_event.set() + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "Battery empty") + # handle request + try: + response = method(request_or_iterator, context) + response_size = _proto_serialized_size(response) + self.logger.log({f"{method_name}_response_size": response_size}) + if self.battery is not None: + self.battery.update_communication_sent(response_size) + return response + except BatteryEmptyException as e: + self.logger.log("Battery empty while handling request or sending response") + self.stop_event.set() + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "Battery empty") + + +class DeviceClientInterceptor(ClientInterceptor): + + def __init__(self, logger, battery=None, stop_event=None): + self.logger = logger + self.battery = battery + self.stop_event = stop_event + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + call_details: grpc.ClientCallDetails, + ): + request_size = _proto_serialized_size(request_or_iterator) + if self.battery is not None: + try: + self.battery.update_communication_sent(request_size) + future = method(request_or_iterator, call_details) + result = future.result() + response_size = _proto_serialized_size(result) + self.battery.update_communication_received(response_size) + return future + except BatteryEmptyException as e: + # client exception, i.e. 
local device + # catch to stop the server by setting the stop event + self.logger.log("RPC client battery empty during request") + self.stop_event.set() + # now reraise exception to abort the RPC + # any action of the device is triggered by some RPC, thus this exception will be handled by the next server interceptor in the cascade + # as this device does not have any battery left for any other action, let it fail as fast as possible + # this exception is not meant to be handled by the request dispatcher that started the RPC + raise e + except grpc.RpcError as e: + # server exception, i.e. remote device + # exceptions do not have an associated byte size and are hence not counted towards the battery + self.logger.log("RPC server battery empty during request") + raise e # forward exception to dispatcher to handle + + +def _proto_serialized_size(proto_object): + """Returns the size of the length-prefixed message in bytes. + Includes the grpc header in the calculation since messages without payload are allowed (content size 0) + """ + byte_size = proto_object.ByteSize() + length_prefix = 4 # length prefix is 4 bytes to determine the message payload size + compressed_flag = 1 + return byte_size + length_prefix + compressed_flag diff --git a/edml/helpers/load_dataset.py b/edml/helpers/load_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..242bfed6c38c182281edcf19fbc2cf732db7ba34 --- /dev/null +++ b/edml/helpers/load_dataset.py @@ -0,0 +1,44 @@ +from typing import Optional + +from edml.dataset_utils.cifar.cifar import cifar100_dataloaders, cifar10_dataloaders +from edml.dataset_utils.mnist.mnist import mnist_dataloaders +from edml.dataset_utils.ptb_xl.ptb_xl import ptb_xl_train_val_test +from edml.helpers.data_partitioning import DataPartitioner +from edml.helpers.types import DatasetDataLoaders + + +def get_dataloaders( + name: str, + batch_size: int, + data_partitioner: Optional[DataPartitioner] = None, +) -> DatasetDataLoaders: + """ + Returns the :class:`DataLoader`s for the given dataset name. In total, the function should return three + :class:`DataLoader` instances for training, validation and testing. + + Args: + name (str): The name of the dataset to create the :class:`DataLoader`s for. Currently supported values are + `mnist`, `ptbxl` and `cifar100`. + batch_size (int): The batch size. + data_partitioner (DataPartitioner, optional): A custom data partitioner to use for splitting the data. Defaults + to `None`. + + Raises: + ValueError: If the dataset name is unknown. + + Notes: + If the data partitioner is not set explicitly, the data should be split randomly. + """ + + # To add your own datasets, you can simply introduce a new name check that returns + # the appropriate data loaders. 
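+    # A hypothetical extension could look like this (the dataset name and loader
+    # function are placeholders, not part of this repository):
+    #
+    #   if name == "fashion_mnist":
+    #       return fashion_mnist_dataloaders(batch_size, data_partitioner=data_partitioner)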
+ if name == "mnist": + return mnist_dataloaders(batch_size, data_partitioner=data_partitioner) + elif name == "ptbxl": + return ptb_xl_train_val_test(batch_size, data_partitioner=data_partitioner) + elif name == "cifar10": + return cifar10_dataloaders(batch_size, data_partitioner=data_partitioner) + elif name == "cifar100": + return cifar100_dataloaders(batch_size, data_partitioner=data_partitioner) + else: + raise ValueError(f"Dataset {name} not known.") diff --git a/edml/helpers/load_model.py b/edml/helpers/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f37ab80fe6f2362682f07f53a44d53f57a4b7f80 --- /dev/null +++ b/edml/helpers/load_model.py @@ -0,0 +1,79 @@ +from os.path import exists +from typing import Tuple + +import torch +from omegaconf import DictConfig +from torch import nn + +from edml.models import mnist_models, tcn_models, resnet_models + + +def _load_weights(model: nn.Module, path: str): + """Loads the weights for the given model if the file at `path` exists""" + if exists(path): + model.load_state_dict(torch.load(path)) + + +def get_models(cfg: DictConfig) -> Tuple[nn.Module, nn.Module]: + """ + Returns the client and server models for the configured name. If `load_weights` has been set to `True`, the weights + are loaded from the provided file paths. + + Args: + cfg (DictConfig): The experiment's configuration. + + Returns: + A tuple with the client model at position 0 and the server model at position 1. + + Raises: + ValueError: If the configured model name is unknown. + """ + name = cfg.model.name + load_weights = cfg.experiment.load_weights + client_model_load_path = cfg.experiment.client_model_load_path + server_model_load_path = cfg.experiment.server_model_load_path + + if name == "simple_conv": + client_model = mnist_models.ClientNet() + server_model = mnist_models.ServerNet() + elif name == "tcn": + client_model = tcn_models.Small_TCN_5_Client( + classes=cfg.dataset.num_classes, n_inputs=cfg.dataset.n_inputs + ) + server_model = tcn_models.Small_TCN_5_Server( + classes=cfg.dataset.num_classes, n_inputs=cfg.dataset.n_inputs + ) + elif name == "resnet20": + client_model, server_model = resnet_models.resnet20( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + elif name == "resnet32": + client_model, server_model = resnet_models.resnet32( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + elif name == "resnet44": + client_model, server_model = resnet_models.resnet44( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + elif name == "resnet56": + client_model, server_model = resnet_models.resnet56( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + elif name == "resnet110": + client_model, server_model = resnet_models.resnet110( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + elif name == "resnet1202": + client_model, server_model = resnet_models.resnet1202( + cfg.model.cut_layer, num_classes=cfg.dataset.num_classes + ) + else: + raise ValueError(f"Unknown model name {name}") + + if load_weights: + if exists(client_model_load_path): + _load_weights(client_model, client_model_load_path) + if exists(server_model_load_path): + _load_weights(server_model, server_model_load_path) + + return client_model, server_model diff --git a/edml/helpers/load_optimizer.py b/edml/helpers/load_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7987e4fa84c0d559952fdac750b58d642d5a8d18 --- /dev/null +++ b/edml/helpers/load_optimizer.py @@ -0,0 +1,34 @@ 
+from typing import Any, Tuple, Optional + +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.optim import Optimizer + + +def get_optimizer_and_scheduler( + cfg: DictConfig, model_params: Any +) -> Tuple[Optimizer, Optional[Any]]: + """ + Returns the optimizer for the given configuration. Optionally, a learning rate scheduler is returned as well. + + Args: + cfg (DictConfig): The experiment's configuration. + model_params (Any): Model parameters passed to the optimizer. + + Returns: + An optimizer instance with the given parameters. + + Notes: + Assumes the all optimizer-related config parameters to be present. + If scheduler should be used, milestones and gamma must be present in the config. + """ + + # Instantiate the optimizer with the given parameters. + optimizer = instantiate(cfg.optimizer, model_params) + + # Optionally, instantiate the scheduler with the given parameters. + scheduler = None + if "scheduler" in cfg: + scheduler = instantiate(cfg.scheduler, optimizer) + + return optimizer, scheduler diff --git a/edml/helpers/logging.py b/edml/helpers/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e6af989d8038f06ef197bfb0c860a81200c420cb --- /dev/null +++ b/edml/helpers/logging.py @@ -0,0 +1,101 @@ +import abc +from typing import Any + +import wandb +from omegaconf import DictConfig + + +class SimpleLogger(abc.ABC): + """ + A simple logger abstraction. + """ + + @abc.abstractmethod + def log(self, message: Any): + """ + Logs the given data. + """ + pass + + def start_experiment(self): + """ + Event hook called at the start of an experiment. + """ + print("Starting experiment...") + + def end_experiment(self): + """ + Event hook called at the end of an experiment. + """ + print("Ending experiment...") + + +class ConsoleLogger(SimpleLogger): + """ + A logger implementation that logs everything to the console. + + Attributes: + device_id (str): The device id of the device that logs the data. + """ + + def __init__(self, device_id: str): + self.device_id = device_id + + def log(self, message: Any): + print(f"Device {self.device_id}: {message}") + + +class WandbLogger(SimpleLogger): + """ + A logger implementation logging evaluation metrics to wandb.ai. + """ + + def __init__(self, cfg: DictConfig, device_id: str): + self.device_id = device_id + self.wandb_enabled = False + self.cfg = cfg + + def start_experiment(self): + """ + Override event hook to start login and initialize wandb client. + """ + with open(self.cfg.wandb.key_path, "r") as f: + key = f.read().strip() + wandb.login(key=key) + wandb.init( + entity=self.cfg.wandb.entity, + project=self.cfg.experiment.project, # project = set of experiments + job_type=self.cfg.experiment.job, # train or test + group=self.cfg.group, + name=self.cfg.own_device_id, # name runs by device id + config=dict(self.cfg), + ) + self.wandb_enabled = True + + def end_experiment(self): + """Ends the wandb run.""" + wandb.finish() + self.wandb_enabled = False + + def log(self, message: Any): + """ + Override to differentiate between dictionary messages and normal string messages. + + Dict values are logged to wandb as captured metrics whereas strings are logged to the console. 
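+
+        For example, ``self.log({"train_loss": 0.42})`` is forwarded to ``wandb.log`` when wandb is
+        enabled, whereas ``self.log("starting round 3")`` is printed to the console.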
+ """ + if type(message) is dict and self.wandb_enabled: + wandb.log(message) + else: + if not self.wandb_enabled and type(message) is dict: + print("Wandb not running, printing to stdout instead.") + print(f"Device {self.device_id}: {message}") + + +def create_logger(cfg: DictConfig): + """ + Creates the right logger based on the provided experiment configuration. + """ + if cfg.wandb.enabled: + return WandbLogger(cfg=cfg, device_id=cfg.own_device_id) + else: + return ConsoleLogger(device_id=cfg.own_device_id) diff --git a/edml/helpers/metrics.py b/edml/helpers/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ba068ad4d802aa5bfcb825ae92689272f17132 --- /dev/null +++ b/edml/helpers/metrics.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import abc +from typing import List, Optional, Dict, Tuple, Any + +import torchmetrics +from torch import Tensor +from torchmetrics import Accuracy, F1Score, AUROC + + +def create_metrics( + metric_names: List[str], num_classes: int, average_setting: str +) -> ModelMetricContainer: + """ + Returns the torchmetrics metrics instance based on the name and settings. + + Args: + metric_names (List[str]): list of metric names to instantiate. + num_classes (int): number of classes. + average_setting (str): defines the reduction to be applied. + + Returns: + ModelMetricContainer: A container holding an instance for each passed metric name. + """ + result = ModelMetricContainer() + if "accuracy" in metric_names: + acc = Accuracy(num_classes=num_classes, average=average_setting) + acc.custom_name = "accuracy" + acc.num_samples = 0 + result.add_metric(acc, "accuracy") + if "f1" in metric_names: + f1 = F1Score(num_classes=num_classes, average=average_setting) + f1.custom_name = "f1" + f1.num_samples = 0 + result.add_metric(f1, "f1") + if "auc" in metric_names: + auc = AUROC(num_classes=num_classes, average=average_setting) + auc.custom_name = "auc" + auc.num_samples = 0 # doesn't seem to be an official attribute. + result.add_metric(auc, "auc") + return result + + +class MetricResult(abc.ABC): + """ + Base class for metric results. Defines shared attributes and required methods. + + Args: + device_id (str): The device's ID the metric has been measured on. + name (str): The name of the metric. + """ + + def __init__(self, device_id: str, name: str): + self.device_id = device_id + self.name = name + + @abc.abstractmethod + def as_loggable_dict(self): + """ + Returns the result in a convenient dictionary format. Used for structural logging. + """ + pass + + @abc.abstractmethod + def __eq__(self, other: Any) -> bool: + """For unit testing purposes.""" + pass + + +class ModelMetricResult(MetricResult): + """ + A class representing a single metric result for a given device and experiment phase. + + Attributes: + device_id (str): The device's ID the metric has been measured on. + name (str): The name of the metric. + phase (str): The phase during which the metric has been measured. Can be 'train' or 'test' or similar values. + num_samples (int): The number of samples that were used to compute the metric. + """ + + def __init__(self, device_id: str, name: str, phase: str, value, num_samples: int): + super().__init__(device_id, name) + self.phase = phase + self.value = value + self.num_samples = num_samples + + def as_loggable_dict(self, round_no: Optional[int] = None): + """ + Returns the result in a convenient dictionary format for logging. + + Args: + round_no (int, optional): The number of the current round. 
If given, the round number is added to the + dictionary. + + Returns: + A dictionary with the metric name as key and a dictionary with the value and the number of samples as value. + + Notes: + Does not include the device id, as this information is added by the logger. + """ + if round_no is None: + return { + f"{self.phase}_{self.name}": { + "value": self.value, + "num_samples": self.num_samples, + } + } + return { + f"{self.phase}_{self.name}": { + "value": self.value, + "num_samples": self.num_samples, + "round": round_no, + } + } + + def __eq__(self, other): + """For unit testing purposes.""" + return ( + self.device_id == other.device_id + and self.name == other.name + and self.phase == other.phase + and self.value == other.value + and self.num_samples == other.num_samples + ) + + +class DiagnosticMetricResult(MetricResult): + + def __init__(self, device_id: str, name: str, method: str, value: float): + """ + Initializes the metric result with the given values. + + Attributes: + device_id (str): The id of the device. + name (str): The name of the metric such as time, flops or bytes. + method (str): The method of the metric such as TrainEpoch. + value (float): The value of the metric. + + Notes: + Diagnostic metrics are not aggregated over multiple devices. + Furthermore, there is no specific container, since the metrics are already finalized upon initialization. + """ + super().__init__(device_id, name) + self.value = value + self.method = method + + def as_loggable_dict(self): + return {f"{self.method}_{self.name}": {"value": self.value}} + + def __eq__(self, other) -> bool: + """For unit testing purposes.""" + return ( + self.device_id == other.device_id + and self.name == other.name + and self.method == other.method + and self.value == other.value + ) + + +class ModelMetricContainer: + """ + A container class for multiple `ModelMetricResult`s. + + Attributes: + metrics (Dict[str, Metric]): A dictionary of metric name and instance pairs. + """ + + def __init__(self): + self.metrics: Dict[str, torchmetrics.Metric] = {} + + def add_metric(self, metric: torchmetrics.Metric, name: str): + """ + Adds a metric to the metric container. + + Args: + metric (torchmetrics.Metric): The metric to add. + name (str): The name of the metric. + """ + self.metrics[name] = metric + + def metrics_on_batch(self, prediction: Tensor, labels: Tensor) -> list: + """ + Appends the predictions and labels to the metric objects for the current batch. + + Args: + prediction (torch.Tensor): The prediction. + labels (torch.Tensor): The labels. + + Returns: + The results of the batch a list of the outputs of the torchmetrics. + + Notes: + Assumes that the first dimension of the predictions is the batch dimension. + """ + result = [] + for metric in self.metrics.values(): + result.append(metric(prediction, labels)) + metric.num_samples += prediction.shape[ + 0 + ] # add size of batch dimension to number of samples + return result + + def compute_metrics(self, phase: str, device_id: str) -> List[ModelMetricResult]: + """ + Computes the overall metrics based on all accumulated values so far. + + Args: + phase (str): The phase to compute the metrics for. + device_id (str): The device ID to filter the collected metrics by. + + Returns: + List[ModelMetricResult]: A list of metrics inside the container, filtered. The list is empty if no metrics + are available or no predictions have been made so far. 
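+
+        Example (illustrative): after calling ``metrics_on_batch`` for every batch of an epoch,
+        ``container.compute_metrics(phase="train", device_id="d0")`` yields one ``ModelMetricResult``
+        per tracked metric (e.g. the accuracy over all seen samples), tagged with the given phase and device id.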
+ """ + result = [] + for name, metric in self.metrics.items(): + if metric.num_samples > 0: + if hasattr(metric, "custom_name"): + result.append( + ModelMetricResult( + device_id, + metric.custom_name, + phase, + metric.compute(), + metric.num_samples, + ) + ) + else: + result.append( + ModelMetricResult( + device_id, + type(metric), # FIXME: should probably be str(type(metric)) + phase, + metric.compute(), + metric.num_samples, + ) + ) + return result + + def reset_metrics(self): + """ + Resets the given metric objects by calling their `reset` method. Sets the `num_samples` attribute to 0. + """ + for metric in self.metrics.values(): + metric.reset() + metric.num_samples = 0 + + +class ModelMetricResultContainer: + """ + A container class for holding multiple `ModelMetricResult` instances. + + Attributes: + results (Dict[Tuple[str, str], List[ModelMetricResult]]): A dictionary holding the `ModelMetricResult` + instances. The key is a tuple that holds the name and the phase of the model metric result. The value is the + result itself. + """ + + def __init__(self, results: Optional[List[ModelMetricResult]] = None): + self.results: Dict[Tuple[str, str], List[ModelMetricResult]] = {} + if results is not None: + self.add_results(results) + + def get(self, key: Tuple[str, str]): + return self.results[key] + + def add_results(self, results: List[ModelMetricResult]): + """ + Adds the `ModelMetricResult`s to the container. + """ + for result in results: + self.add_result(result) + + def add_result(self, result: ModelMetricResult): + """ + Adds a single `ModelMetricResult` to the container. + + Args: + result (ModelMetricResult): The result to add. + """ + name = result.name + phase = result.phase + if (name, phase) in self.results: + self.results[(name, phase)] += [result] + else: + self.results[(name, phase)] = [result] + + def merge(self, other): + """Merges the results of the other MetricResultContainer into this one.""" + for key, value in other.results.items(): + if key in self.results: + self.results[key] += value + else: + self.results[key] = value + + def get_aggregated_metrics(self) -> ModelMetricResultContainer: + """Returns a new MetricResultContainer with the aggregated metrics.""" + aggregates = ModelMetricResultContainer() + for key, result_list in self.results.items(): + aggregates.add_result( + ModelMetricResult( + device_id="aggregated", + name=key[0], + phase=key[1], + value=sum( + [result.value * result.num_samples for result in result_list] + ) + / sum([result.num_samples for result in result_list]), + num_samples=sum([result.num_samples for result in result_list]), + ) + ) + return aggregates + + def get_as_list(self) -> List[ModelMetricResult]: + """Returns the results as a list.""" + result = [] + for key, result_list in self.results.items(): + result += result_list + return result + + def get_raw_metrics(self): + return self.results + + def __eq__(self, other) -> bool: + """For unit testing purposes.""" + for key, value in self.results.items(): + if key not in other.results: + return False + if len(value) != len(other.results[key]): + return False + for i in range(len(value)): + if value[i] != other.results[key][i]: + return False + return True + + +class DiagnosticMetricResultContainer: + + def __init__(self, results: Optional[List[DiagnosticMetricResult]] = None): + self.results: Dict[(str, str), List[DiagnosticMetricResult]] = {} + if results is not None: + self.add_results(results) + + def add_results(self, results: List[DiagnosticMetricResult]): + for result in 
results: + self.add_result(result) + + def add_result(self, result: DiagnosticMetricResult): + name = result.name + method = result.method + if (name, method) in self.results: + self.results[(name, method)] += [result] + else: + self.results[(name, method)] = [result] + + def merge(self, other): + """Merges the results of the other MetricResultContainer into this one.""" + if other is not None and type(other) == DiagnosticMetricResultContainer: + for key, value in other.results.items(): + if key in self.results: + self.results[key] += value + else: + self.results[key] = value + + def get_as_list(self) -> List[DiagnosticMetricResult]: + """Returns the results as a list.""" + result = [] + for key, result_list in self.results.items(): + result += result_list + return result + + def get_raw_metrics(self): + return self.results + + def __eq__(self, other): + """For unit testing purposes.""" + for key, value in self.results.items(): + if key not in other.results: + return False + if len(value) != len(other.results[key]): + return False + for i in range(len(value)): + if value[i] != other.results[key][i]: + return False + return True + + +def compute_metrics_for_optimization( + metrics: DiagnosticMetricResultContainer, + samples_per_device: Dict[str, Tuple[int, int]], + batch_size: int, +): + """ + Computes the metrics for the optimization. + + Args: + metrics (DiagnosticMetricResultContainer): The diagnostic metrics of one round. + samples_per_device (dict[str, (int, int)]): The number of samples per device. The key is the device id, the value is the number of samples (train, validation). + batch_size (int): The batch size. + + Returns: + A dictionary with the metrics. + + Raises: + KeyError: if keys are missing which may be the case if not all diagnostic metrics were reported e.g. because a device ran out of battery. + + Notes: + Assumes that the metrics result from exactly one round. I.e. we assume that there is only one server device and e.g. all train_batch metrics are from this device. + Also, we assume that the client train time per sample is larger than the client eval time per sample. + This should be given, due the additional time for the backpropagation during training. Otherwise, the results are not meaningful, for example if there is a huge lag during evaluation. 
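+
+        For illustration, ``samples_per_device`` might look like ``{"d0": (4000, 800), "d1": (6000, 1200)}``,
+        i.e. device ``"d0"`` holds 4000 training samples and 800 validation samples.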
+ """ + raw_metrics = metrics.get_raw_metrics() + result = {} + num_devices = len(samples_per_device.keys()) + + # byte size per sample + result["gradient_size"] = max( + [metric.value / batch_size for metric in raw_metrics[("size", "gradients")]] + ) + result["label_size"] = max( + [metric.value / batch_size for metric in raw_metrics[("size", "labels")]] + ) + result["smashed_data_size"] = max( + [metric.value / batch_size for metric in raw_metrics[("size", "smashed_data")]] + ) + + # total byte size + result["client_weight_size"] = max( + [metric.value for metric in raw_metrics[("size", "client_weights")]] + ) # all values should be equal anyway + result["server_weight_size"] = max( + [metric.value for metric in raw_metrics[("size", "server_weights")]] + ) # all values should be equal anyway + result["optimizer_state_size"] = max( + [metric.value for metric in raw_metrics[("size", "optimizer_state")]] + ) # all values should be equal anyway + + # time + result["train_global_time"] = max( + [metric.value for metric in raw_metrics[("comp_time", "train_global")]] + ) # should be only one value anyway + avg_server_model_train_time_per_sample = ( + sum([metric.value for metric in raw_metrics[("comp_time", "train_batch")]]) + / len(raw_metrics[("comp_time", "train_batch")]) + / batch_size + ) + avg_server_model_evaluate_time_per_sample = ( + sum([metric.value for metric in raw_metrics[("comp_time", "evaluate_batch")]]) + / len(raw_metrics[("comp_time", "evaluate_batch")]) + / batch_size + ) + client_train_time_per_sample = {} + client_eval_time_per_sample = {} + for device_id in samples_per_device.keys(): + # each device: compute the average training time on the client model per sample + train_epoch_time_list = [ + metric.value + for metric in raw_metrics[("comp_time", "client_train_epoch_time")] + if metric.device_id == device_id + ] # in case of multiple epochs on one device per round + client_train_time_per_sample[device_id] = sum(train_epoch_time_list) / ( + samples_per_device[device_id][0] * len(train_epoch_time_list) + ) + # analogously for the client eval time + eval_time_list = [ + metric.value + for metric in raw_metrics[("comp_time", "client_eval_epoch_time")] + if metric.device_id == device_id + ] + client_eval_time_per_sample[device_id] = sum(eval_time_list) / ( + samples_per_device[device_id][1] * len(eval_time_list) + ) + # min train time per sample on client + # use eval time as estimate for the forward pass + result["client_norm_fw_time"] = min( + [ + client_eval_time_per_sample[device_id] + for device_id in client_train_time_per_sample.keys() + ] + ) + # backward pass = train time - forward pass + bw_estimate = [ + client_train_time_per_sample[device_id] - client_eval_time_per_sample[device_id] + for device_id in client_train_time_per_sample.keys() + if client_train_time_per_sample[device_id] + - client_eval_time_per_sample[device_id] + > 0 + ] + # make sure that training time is larger than eval time + if len(bw_estimate) > 0: + result["client_norm_bw_time"] = min(bw_estimate) + # otherwise estimate the bw time with the training time per sample (which is still smaller than the eval time then) + else: + result["client_norm_bw_time"] = min( + [ + client_train_time_per_sample[device_id] + for device_id in client_train_time_per_sample.keys() + ] + ) + + # comp speed, normalized to the fastest device + min_client_train_time = min(client_train_time_per_sample.values()) + result["comp_latency_factor"] = { + device_id: client_train_time_per_sample[device_id] / min_client_train_time + 
for device_id in client_train_time_per_sample.keys() + } + + # weight train time per sample on server with the server device's comp speed + server_device_id = raw_metrics[("comp_time", "train_batch")][0].device_id + result["server_norm_fw_time"] = ( + avg_server_model_evaluate_time_per_sample + * result["comp_latency_factor"][server_device_id] + ) + # again make sure that training time is larger than eval time + server_bw_estimate = ( + avg_server_model_train_time_per_sample + - avg_server_model_evaluate_time_per_sample + ) * result["comp_latency_factor"][server_device_id] + if server_bw_estimate > 0: + result["server_norm_bw_time"] = server_bw_estimate + else: + result["server_norm_bw_time"] = ( + avg_server_model_train_time_per_sample + * result["comp_latency_factor"][server_device_id] + ) + + return result diff --git a/edml/helpers/model_splitting.py b/edml/helpers/model_splitting.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce22bc9790a991c6f6e77641ec7e83533ab8bcb --- /dev/null +++ b/edml/helpers/model_splitting.py @@ -0,0 +1,56 @@ +from typing import List, Tuple + +from torch import nn as nn + + +class Part(nn.Module): + """ + A part of a model that has been split after a specific layer. + + Instances of this class are usually returned by calling py:meth:`~split_network_at_layer`. + + Attributes: + layers (List[nn.Module]): The layers of this part. + """ + + def __init__(self, layers: List[nn.Module]): + super(Part, self).__init__() + for idx, layer in enumerate(layers): + self.add_module(f"layer{idx}", layer) + self.layers = layers + + def forward(self, x): + """Calls each of its layers one after the other.""" + for layer in self.layers: + x = layer(x) + return x + + +def split_network_at_layer(network: nn.Module, cut_layer: int) -> Tuple[Part, Part]: + """ + Splits a given network at the given layer. + + Args: + network (nn.Module): The network to split. + cut_layer (int): The index of the layer to cut the network at. + + Returns: + Tuple[Part, Part]: The two parts of the network. + + Raises: + ValueError: If the `cut_layer` is out of the model's range. I.e., if the model has less than `cut_layer + 2` + layers. + + Notes: + The cut_layer is included in the first part. + Assumes that the network is a sequential model and all layers are defined in the correct order as children of + the given network. 
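+
+    Example (illustrative; uses the ResNet defined in ``edml.models.resnet_models``):
+
+        >>> from edml.models.resnet_models import ResNet, BasicBlock
+        >>> net = ResNet(BasicBlock, [3, 3, 3], num_classes=10)
+        >>> client, server = split_network_at_layer(net, cut_layer=4)
+        >>> # conv1, bn1, relu and layer1 end up on the client, the remaining layers on the server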
+ """ + children = list(network.children()) + if cut_layer > len(children) - 1: + raise ValueError("cut_layer is out of the model's range.") + + part1 = Part(children[:cut_layer]) + part2 = Part(children[cut_layer:]) + + return part1, part2 diff --git a/edml/helpers/proto_helpers.py b/edml/helpers/proto_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a5ad26e8276bf3ddabede6128ad0373335a1a9 --- /dev/null +++ b/edml/helpers/proto_helpers.py @@ -0,0 +1,139 @@ +import io +import pickle +from typing import Union, Any, Tuple + +import torch + +from edml.generated import datastructures_pb2 +from edml.helpers.metrics import ( + ModelMetricResultContainer, + DiagnosticMetricResultContainer, + DiagnosticMetricResult, +) +from edml.helpers.types import StateDict + + +class CpuUnpickler(pickle.Unpickler): + # quickfix from https://github.com/pytorch/pytorch/issues/16797 + def find_class(self, module, name): + if module == "torch.storage" and name == "_load_from_bytes": + return lambda b: torch.load(io.BytesIO(b), map_location="cpu") + else: + return super().find_class(module, name) + + +def _tensor_to_bytes(tensor: torch.Tensor) -> bytes: + return pickle.dumps(tensor) + + +def _bytes_to_tensor(raw_bytes: bytes) -> torch.Tensor: + bs = io.BytesIO(raw_bytes) + return CpuUnpickler(bs).load() + + +def _state_dict_to_bytes(state_dict: StateDict) -> bytes: + return pickle.dumps(state_dict) + + +def _bytes_to_state_dict(raw_bytes: bytes) -> StateDict: + bs = io.BytesIO(raw_bytes) + return CpuUnpickler(bs).load() + + +def _metrics_to_bytes( + metrics: Union[ModelMetricResultContainer, DiagnosticMetricResultContainer] +) -> bytes: + return pickle.dumps(metrics) + + +def _bytes_to_metrics(raw_bytes: bytes): + bs = io.BytesIO(raw_bytes) + return CpuUnpickler(bs).load() + + +def tensor_to_proto(tensor: torch.Tensor) -> datastructures_pb2.Tensor: + return datastructures_pb2.Tensor(serialized=_tensor_to_bytes(tensor)) + + +def proto_to_tensor(proto: datastructures_pb2.Tensor) -> torch.Tensor: + return _bytes_to_tensor(proto.serialized) + + +def state_dict_to_proto(state_dict: StateDict) -> datastructures_pb2.StateDict: + return datastructures_pb2.StateDict(serialized=_state_dict_to_bytes(state_dict)) + + +def proto_to_state_dict(proto: datastructures_pb2.StateDict) -> StateDict: + return _bytes_to_state_dict(proto.serialized) + + +def weights_to_proto(weights: dict) -> datastructures_pb2.Weights: + return datastructures_pb2.Weights(weights=state_dict_to_proto(weights)) + + +def proto_to_weights(proto: datastructures_pb2.Weights): + return proto_to_state_dict(proto.weights) + + +def activations_to_proto(activations: torch.Tensor) -> datastructures_pb2.Activations: + return datastructures_pb2.Activations(activations=tensor_to_proto(activations)) + + +def proto_to_activations(proto: datastructures_pb2.Activations) -> torch.Tensor: + return proto_to_tensor(proto.activations) + + +def labels_to_proto(labels: torch.Tensor) -> datastructures_pb2.Labels: + return datastructures_pb2.Labels(labels=tensor_to_proto(labels)) + + +def proto_to_labels(proto: datastructures_pb2.Labels) -> torch.Tensor: + return proto_to_tensor(proto.labels) + + +def proto_to_device_info(proto: datastructures_pb2.DeviceInfo) -> Tuple[str, str]: + return proto.device_id, proto.address + + +def device_info_to_proto(device_id: str, address: str) -> datastructures_pb2.DeviceInfo: + return datastructures_pb2.DeviceInfo(device_id=device_id, address=address) + + +def proto_to_gradients(proto: datastructures_pb2.Gradients) 
-> torch.Tensor: + return proto_to_tensor(proto.gradients) + + +def gradients_to_proto(gradients: torch.Tensor) -> datastructures_pb2.Gradients: + return datastructures_pb2.Gradients(gradients=tensor_to_proto(gradients)) + + +def proto_to_metrics(proto: datastructures_pb2.Metrics): + return _bytes_to_metrics(proto.metrics) + + +def metrics_to_proto( + metrics: Union[ModelMetricResultContainer, DiagnosticMetricResultContainer] +) -> datastructures_pb2.Metrics: + return datastructures_pb2.Metrics(metrics=_metrics_to_bytes(metrics)) + + +def _proto_size_per_field( + proto_object: Any, device_id: str +) -> DiagnosticMetricResultContainer: + metrics = DiagnosticMetricResultContainer() + for attribute in proto_object.DESCRIPTOR.fields: + if attribute.name != "diagnostic_metrics": + try: + byte_size = getattr(proto_object, attribute.name).ByteSize() + metrics.add_result( + DiagnosticMetricResult( + device_id=device_id, + name="size", + method=attribute.name, + value=byte_size, + ) + ) + except AttributeError: + # ignore primitive types without ByteSize() method + pass + return metrics diff --git a/edml/helpers/types.py b/edml/helpers/types.py new file mode 100644 index 0000000000000000000000000000000000000000..d7de75bc473ac6c7d31ab32dbaa03365d6cb263a --- /dev/null +++ b/edml/helpers/types.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass +from typing import Mapping, Any, Tuple, Protocol, Union, Optional, Sequence, Callable + +from torch import Tensor +from torch.utils.data import DataLoader + +from edml.controllers.scheduler.base import NextServerScheduler +from edml.generated import datastructures_pb2 + +# TODO: use `type` keyword once Python>3.9 approved. + +# Return type of the `get_dataloaders` function. +DatasetDataLoaders = Tuple[DataLoader, DataLoader, DataLoader] + +# pytorch type for working with state dictionaries. +StateDict = Mapping[str, Any] + +LossFn = Callable[[Tensor, Tensor], Tensor] + + +@dataclass +class HasMetrics(Protocol): + """ + A subtype that has access to a pickled diagnostic metrics structure. + + Attributes: + diagnostic_metrics (datastructures_pb2.DiagnosticMetrics): The pickled metrics. + """ + + diagnostic_metrics: datastructures_pb2.Metrics + + +class KeyedNextServerScheduler(Protocol): + KEY: str + + # From the `NextServerScheduler` interface. + def next_server(self, active_devices: Sequence[str]) -> Optional[str]: ... + + +@dataclass +class DeviceBatteryStatus: + current_capacity: float + initial_capacity: float + + @staticmethod + def from_tuple(t: tuple[float, float]) -> "DeviceBatteryStatus": + return DeviceBatteryStatus(initial_capacity=t[0], current_capacity=t[1]) + + +DeviceBatteryStatusReport = Union[DeviceBatteryStatus, bool] + + +@dataclass +class SLTrainBatchResult: + smashed_data: Any + labels: Any diff --git a/edml/helpers/units.py b/edml/helpers/units.py new file mode 100644 index 0000000000000000000000000000000000000000..33e39a13112b7f6fa0c517a2c182755f9dcc5eda --- /dev/null +++ b/edml/helpers/units.py @@ -0,0 +1,7 @@ +from typing import Final + +#: Factor to divide by to convert `u` into mega-u. +MEGA_FACTOR: Final[int] = 1_000_000 + +#: Factor to divide by to convert `u` into giga-u. 
+GIGA_FACTOR: Final[int] = MEGA_FACTOR * 1000 diff --git a/edml/models/.gitignore b/edml/models/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d301f7ed45f2ddbf169a9de487e28c286c9f2be5 --- /dev/null +++ b/edml/models/.gitignore @@ -0,0 +1,2 @@ +weights/* +!weights/initial diff --git a/edml/models/__init__.py b/edml/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ed16f3000d694c53856e901a2634776fcf488433 --- /dev/null +++ b/edml/models/autoencoder.py @@ -0,0 +1,24 @@ +import torch +from torch import nn + + +class ClientWithAutoencoder(nn.Module): + def __init__(self, model: nn.Module, autoencoder: nn.Module): + super().__init__() + self.model = model + self.autoencoder = autoencoder.requires_grad_(False) + + def forward(self, x): + x = self.model(x) + return self.autoencoder(x) + + +class ServerWithAutoencoder(nn.Module): + def __init__(self, model: nn.Module, autoencoder: nn.Module): + super().__init__() + self.model = model + self.autoencoder = autoencoder.requires_grad_(False) + + def forward(self, x): + x = self.autoencoder(x) + return self.model(x) diff --git a/edml/models/mnist_models.py b/edml/models/mnist_models.py new file mode 100644 index 0000000000000000000000000000000000000000..756979cc1b9962b37c7c31b1f61f9d874e1bdf2a --- /dev/null +++ b/edml/models/mnist_models.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Simple example adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py + + +class ClientNet(nn.Module): + """ + Client-side neural network for the MNIST dataset. + + Implementation based on https://github.com/pytorch/examples/blob/main/mnist/main.py. + """ + + def __init__(self): + super(ClientNet, self).__init__() + self.conv1 = nn.Conv2d(1, 16, 3, 1) + self.conv2 = nn.Conv2d(16, 64, 3, 1) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + return x + + +class ServerNet(nn.Module): + """ + Server-side neural network for the MNIST dataset. + + Implementation based on https://github.com/pytorch/examples/blob/main/mnist/main.py. 
+ """ + + def __init__(self): + super(ServerNet, self).__init__() + self.conv3 = nn.Conv2d(64, 128, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(12800, 300) + self.fc2 = nn.Linear(300, 10) + + def forward(self, x): + x = self.dropout1(x) + x = self.conv3(x) + x = F.relu(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output diff --git a/edml/models/partials/mnist.py b/edml/models/partials/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4e5ad4099457f6e2cf1905f5d337dcce5f26e0 --- /dev/null +++ b/edml/models/partials/mnist.py @@ -0,0 +1,37 @@ +from torch import nn + + +class Decoder(nn.Module): + """ + decoder model + """ + + def __init__(self): + super(Decoder, self).__init__() + self.t_convx = nn.ConvTranspose2d(4, 8, 1, stride=1) + self.t_conva = nn.ConvTranspose2d(8, 16, 1, stride=1) + self.t_convb = nn.ConvTranspose2d(16, 64, 1, stride=1) + + def forward(self, x): + x = self.t_convx(x) + x = self.t_conva(x) + x = self.t_convb(x) + return x + + +class Encoder(nn.Module): + """ + encoder model + """ + + def __init__(self): + super(Encoder, self).__init__() + self.conva = nn.Conv2d(64, 16, 3, padding=1) + self.convb = nn.Conv2d(16, 8, 3, padding=1) + self.convc = nn.Conv2d(8, 4, 3, padding=1) + + def forward(self, x): + x = self.conva(x) + x = self.convb(x) + x = self.convc(x) + return x diff --git a/edml/models/partials/resnet.py b/edml/models/partials/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..af9c87cc07e8e8b892684322de3639836ee94d8f --- /dev/null +++ b/edml/models/partials/resnet.py @@ -0,0 +1,37 @@ +from torch import nn + + +class Decoder(nn.Module): + """ + decoder model + """ + + def __init__(self): + super(Decoder, self).__init__() + self.t_convx = nn.ConvTranspose2d(4, 8, 1, stride=1) + self.t_conva = nn.ConvTranspose2d(8, 16, 1, stride=1) + self.t_convb = nn.ConvTranspose2d(16, 16, 1, stride=1) + + def forward(self, x): + x = self.t_convx(x) + x = self.t_conva(x) + x = self.t_convb(x) + return x + + +class Encoder(nn.Module): + """ + encoder model + """ + + def __init__(self): + super(Encoder, self).__init__() + self.conva = nn.Conv2d(16, 16, 3, padding=1) + self.convb = nn.Conv2d(16, 8, 3, padding=1) + self.convc = nn.Conv2d(8, 4, 3, padding=1) + + def forward(self, x): + x = self.conva(x) + x = self.convb(x) + x = self.convc(x) + return x diff --git a/edml/models/provider/__init__.py b/edml/models/provider/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/models/provider/autoencoder.py b/edml/models/provider/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2d5c204af9ed5d70363dc5c0d14590ea4af0ce79 --- /dev/null +++ b/edml/models/provider/autoencoder.py @@ -0,0 +1,16 @@ +from torch import nn + +from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder +from edml.models.provider.base import ModelProvider + + +class AutoencoderModelProvider(ModelProvider): + def __init__( + self, model_provider: ModelProvider, encoder: nn.Module, decoder: nn.Module + ): + inner_client, inner_server = model_provider.models + + client_with_encoder = ClientWithAutoencoder(inner_client, encoder) + server_with_decoder = ServerWithAutoencoder(inner_server, decoder) + + super().__init__(client=client_with_encoder, server=server_with_decoder) 
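+
+
+# Usage sketch (illustrative only; the concrete model, cut layer and encoder/decoder pairing
+# below are assumptions, not prescribed by this module):
+#
+#   from edml.models import resnet_models
+#   from edml.models.partials.resnet import Encoder, Decoder
+#   from edml.models.provider.base import ModelProvider
+#
+#   base = ModelProvider(*resnet_models.resnet20(cut_layer=4, num_classes=10))
+#   provider = AutoencoderModelProvider(base, encoder=Encoder(), decoder=Decoder())
+#   client_with_encoder, server_with_decoder = provider.models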
diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e60736bae46fb8f3a2239c9ff016ddb32ce0578e --- /dev/null +++ b/edml/models/provider/base.py @@ -0,0 +1,11 @@ +from torch import nn + + +class ModelProvider: + def __init__(self, client: nn.Module, server: nn.Module): + self._client = client + self._server = server + + @property + def models(self) -> tuple[nn.Module, nn.Module]: + return self._client, self._server diff --git a/edml/models/provider/cut_layer.py b/edml/models/provider/cut_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..eda5e0915f41452a7ea7fc559ad5bb6a87828000 --- /dev/null +++ b/edml/models/provider/cut_layer.py @@ -0,0 +1,12 @@ +from torch import nn + +from edml.helpers.model_splitting import split_network_at_layer +from edml.models.provider.base import ModelProvider + + +class CutLayerModelProvider(ModelProvider): + def __init__(self, model: nn.Module, cut_layer: int): + self.cut_layer = cut_layer + + client, server = split_network_at_layer(network=model, cut_layer=cut_layer) + super().__init__(client=client, server=server) diff --git a/edml/models/provider/path.py b/edml/models/provider/path.py new file mode 100644 index 0000000000000000000000000000000000000000..6245270a9cc43cfea81eec58d03c4ed84fb746c7 --- /dev/null +++ b/edml/models/provider/path.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class SerializedModel(nn.Module): + def __init__(self, model: nn.Module, path: str): + super().__init__() + model.load_state_dict(torch.load(path)) + self.model = model + + def forward(self, x): + return self.model(x) diff --git a/edml/models/resnet_models.py b/edml/models/resnet_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4a7fef39d8d7f1b4b8374edfc2ee445c18388dbe --- /dev/null +++ b/edml/models/resnet_models.py @@ -0,0 +1,149 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +from edml.helpers.model_splitting import split_network_at_layer + + +# Implementation adapted from https://github.com/akamaster/pytorch_resnet_cifar10 +# Modified the network such that all layers are class attributes to be able to split the network at any layer and to enable FLOP estimation + + +def _weights_init(m): + classname = m.__class__.__name__ + # print(classname) + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + + +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, option="A"): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != planes: + if option == "A": + """ + For CIFAR10 ResNet paper uses option A. 
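+                    Option A is the parameter-free shortcut: the input is subsampled with stride 2 and the
+                    missing channels are zero-padded, as implemented by the LambdaLayer below.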
+ """ + self.shortcut = LambdaLayer( + lambda x: F.pad( + x[:, :, ::2, ::2], + (0, 0, 0, 0, planes // 4, planes // 4), + "constant", + 0, + ) + ) + elif option == "B": + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=100): + super(ResNet, self).__init__() + self.in_planes = 16 + + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(16) + self.relu = nn.ReLU(inplace=False) + self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) + self.avg_pool_d = nn.AvgPool2d(8) + self.flatten = nn.Flatten() + self.linear = nn.Linear(64, num_classes) + + self.apply(_weights_init) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + + return nn.Sequential(*layers) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.avg_pool_d(out) + out = self.flatten(out) + out = self.linear(out) + return out + + +def resnet20(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes), cut_layer=cut_layer + ) + + +def resnet32(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes), cut_layer=cut_layer + ) + + +def resnet44(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [7, 7, 7], num_classes=num_classes), cut_layer=cut_layer + ) + + +def resnet56(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes), cut_layer=cut_layer + ) + + +def resnet110(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes), cut_layer=cut_layer + ) + + +def resnet1202(cut_layer: int, num_classes: int): + return split_network_at_layer( + ResNet(BasicBlock, [200, 200, 200], num_classes=num_classes), + cut_layer=cut_layer, + ) diff --git a/edml/models/tcn_models.py b/edml/models/tcn_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e69bd71f21598659a9fc347fdb8b8e0e957b28f9 --- /dev/null +++ b/edml/models/tcn_models.py @@ -0,0 +1,272 @@ +import nemo +import torch +import torch.nn as nn + +""" +Taken from https://github.com/a-ayad/Split_ECG_Classification/ +""" + + +class Small_TCN_5_Client(nn.Module): + def __init__(self, classes, n_inputs): + super(Small_TCN_5_Client, self).__init__() + # Hyperparameters for TCN + Kt = 19 + pt = 0.3 + Ft = 11 + + self.pad0 = nn.ConstantPad1d(padding=(Kt - 1, 0), value=0) + self.conv0 = nn.Conv1d( + in_channels=n_inputs, out_channels=n_inputs + 1, kernel_size=19, bias=False + ) + self.act0 = nn.ReLU() + self.batchnorm0 = nn.BatchNorm1d(num_features=n_inputs + 1) + + # First block + 
dilation = 1 + self.upsample = nn.Conv1d( + in_channels=n_inputs + 1, out_channels=Ft, kernel_size=1, bias=False + ) + self.upsamplerelu = nn.ReLU() + self.upsamplebn = nn.BatchNorm1d(num_features=Ft) + self.pad1 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv1 = nn.Conv1d( + in_channels=n_inputs + 1, + out_channels=Ft, + kernel_size=Kt, + dilation=1, + bias=False, + ) + self.batchnorm1 = nn.BatchNorm1d(num_features=Ft) + self.act1 = nn.ReLU() + self.dropout1 = nn.Dropout(p=pt) + self.pad2 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv2 = nn.Conv1d( + in_channels=Ft, out_channels=Ft, kernel_size=Kt, dilation=1, bias=False + ) + self.batchnorm2 = nn.BatchNorm1d(num_features=Ft) + self.act2 = nn.ReLU() + self.dropout2 = nn.Dropout(p=pt) + self.add1 = nemo.quant.pact.PACT_IntegerAdd() + self.reluadd1 = nn.ReLU() + + # Second block + dilation = 2 + self.pad3 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv3 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm3 = nn.BatchNorm1d(num_features=Ft) + self.act3 = nn.ReLU() + self.dropout3 = nn.Dropout(p=pt) + self.pad4 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv4 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm4 = nn.BatchNorm1d(num_features=Ft) + self.act4 = nn.ReLU() + self.dropout4 = nn.Dropout(p=pt) + self.add2 = nemo.quant.pact.PACT_IntegerAdd() + self.reluadd2 = nn.ReLU() + self.double() # Convert to double precision + + def forward(self, x): + # Now we propagate through the network + x = self.pad0(x) + x = self.conv0(x) + x = self.batchnorm0(x) + x = self.act0(x) + + # TCN + # First block + res = self.pad1(x) + res = self.conv1(res) + res = self.batchnorm1(res) + res = self.act1(res) + res = self.dropout1(res) + res = self.pad2(res) + res = self.conv2(res) + res = self.batchnorm2(res) + res = self.act2(res) + res = self.dropout2(res) + + x = self.upsample(x) + x = self.upsamplebn(x) + x = self.upsamplerelu(x) + + x = self.add1(x, res) + x = self.reluadd1(x) + + # Second block + res = self.pad3(x) + # res = self.pad3(res) + res = self.conv3(res) + res = self.batchnorm3(res) + res = self.act3(res) + res = self.dropout3(res) + res = self.pad4(res) + res = self.conv4(res) + res = self.batchnorm4(res) + res = self.act4(res) + res = self.dropout4(res) + x = self.add2(x, res) + x = self.reluadd2(x) + return x + + +class Small_TCN_5_Server(nn.Module): + def __init__(self, classes, n_inputs): + super(Small_TCN_5_Server, self).__init__() + # Hyperparameters for TCN + Kt = 19 + pt = 0.3 + Ft = 11 + + # Third block + dilation = 4 + self.pad5 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv5 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm5 = nn.BatchNorm1d(num_features=Ft) + self.act5 = nn.ReLU() + self.dropout5 = nn.Dropout(p=pt) + self.pad6 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv6 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm6 = nn.BatchNorm1d(num_features=Ft) + self.act6 = nn.ReLU() + self.dropout6 = nn.Dropout(p=pt) + self.add3 = nemo.quant.pact.PACT_IntegerAdd() + self.reluadd3 = nn.ReLU() + + # fourth block + dilation = 8 + self.pad7 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), 
value=0) + self.conv7 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm7 = nn.BatchNorm1d(num_features=Ft) + self.act7 = nn.ReLU() + self.dropout7 = nn.Dropout(p=pt) + self.pad8 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv8 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm8 = nn.BatchNorm1d(num_features=Ft) + self.act8 = nn.ReLU() + self.dropout8 = nn.Dropout(p=pt) + self.add4 = nemo.quant.pact.PACT_IntegerAdd() + self.reluadd4 = nn.ReLU() + + # fifth block + dilation = 16 + self.pad9 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv9 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm9 = nn.BatchNorm1d(num_features=Ft) + self.act9 = nn.ReLU() + self.dropout9 = nn.Dropout(p=pt) + self.pad10 = nn.ConstantPad1d(padding=((Kt - 1) * dilation, 0), value=0) + self.conv10 = nn.Conv1d( + in_channels=Ft, + out_channels=Ft, + kernel_size=Kt, + dilation=dilation, + bias=False, + ) + self.batchnorm10 = nn.BatchNorm1d(num_features=Ft) + self.act10 = nn.ReLU() + self.dropout10 = nn.Dropout(p=pt) + self.add5 = nemo.quant.pact.PACT_IntegerAdd() + self.reluadd5 = nn.ReLU() + + # Last layer + self.linear = nn.Linear( + in_features=Ft * 1000, out_features=classes, bias=False + ) # Ft * 250 + self.double() # Convert to double precision + + def forward(self, x): + # Now we propagate through the network correctly + + # Third block + res = self.pad5(x) + # res = self.pad5(res) + res = self.conv5(res) + res = self.batchnorm5(res) + res = self.act5(res) + res = self.dropout5(res) + res = self.pad6(res) + res = self.conv6(res) + res = self.batchnorm6(res) + res = self.act6(res) + res = self.dropout6(res) + x = self.add3(x, res) + x = self.reluadd3(x) + + # Fourth block + res = self.pad7(x) + # res = self.pad5(res) + res = self.conv7(res) + res = self.batchnorm7(res) + res = self.act7(res) + res = self.dropout7(res) + res = self.pad8(res) + res = self.conv8(res) + res = self.batchnorm8(res) + res = self.act8(res) + res = self.dropout8(res) + x = self.add4(x, res) + x = self.reluadd4(x) + + # Fifth block + res = self.pad9(x) + # res = self.pad5(res) + res = self.conv9(res) + res = self.batchnorm9(res) + res = self.act9(res) + res = self.dropout9(res) + res = self.pad10(res) + res = self.conv10(res) + res = self.batchnorm10(res) + res = self.act10(res) + res = self.dropout10(res) + x = self.add5(x, res) + x = self.reluadd5(x) + + # Linear layer to classify + x = x.flatten(1) + o = self.linear(x) + o = torch.sigmoid(o) + return o # Return directly without softmax diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto new file mode 100644 index 0000000000000000000000000000000000000000..2f12881a1cf7e3ee6ae2e076b2f9005141cfece2 --- /dev/null +++ b/edml/proto/connection.proto @@ -0,0 +1,178 @@ +syntax = "proto3"; +import "datastructures.proto"; + +service Device { + rpc TrainGlobal (TrainGlobalRequest) returns (TrainGlobalResponse) {} + rpc SetWeights (SetWeightsRequest) returns (SetWeightsResponse) {} + rpc TrainEpoch (TrainEpochRequest) returns (TrainEpochResponse) {} + rpc TrainBatch (TrainBatchRequest) returns (TrainBatchResponse) {} + rpc EvaluateGlobal (EvalGlobalRequest) returns (EvalGlobalResponse) {} + rpc Evaluate (EvalRequest) returns (EvalResponse) {} + rpc EvaluateBatch (EvalBatchRequest) returns (EvalBatchResponse) {} + rpc 
FullModelTraining (FullModelTrainRequest) returns (FullModelTrainResponse) {} + rpc StartExperiment (StartExperimentRequest) returns (StartExperimentResponse) {} + rpc EndExperiment (EndExperimentRequest) returns (EndExperimentResponse) {} + rpc GetBatteryStatus (BatteryStatusRequest) returns (BatteryStatusResponse) {} + rpc GetDatasetModelInfo (DatasetModelInfoRequest) returns (DatasetModelInfoResponse) {} + + /// Invoked by the controller on the server device to start one round of parallel split learning. + rpc TrainGlobalParallelSplitLearning (TrainGlobalParallelSplitLearningRequest) returns (TrainGlobalParallelSplitLearningResponse) {} + rpc TrainSingleBatchOnClient (SingleBatchTrainingRequest) returns (SingleBatchTrainingResponse) {} + rpc BackwardPropagationSingleBatchOnClient(SingleBatchBackwardRequest) returns (SingleBatchBackwardResponse) {} + rpc SetGradientsAndFinalizeTrainingStep(SetGradientsRequest) returns (Empty) {} +} + +message SetGradientsRequest { + Gradients gradients = 1; +} + +message UpdateWeightsRequest { + Gradients gradients = 1; +} + +message SingleBatchBackwardRequest { + Gradients gradients = 1; +} + +message SingleBatchBackwardResponse { + Metrics metrics = 1; + optional Gradients gradients = 2; +} + +message SingleBatchTrainingRequest { + int32 batch_index = 1; +} + +message SingleBatchTrainingResponse { + optional Activations smashed_data = 1; + optional Labels labels = 2; +} + +message TrainGlobalParallelSplitLearningRequest { + optional int32 round_no = 1; + optional double adaptive_learning_threshold = 2; + optional StateDict optimizer_state = 3; +} + +message TrainGlobalParallelSplitLearningResponse { + Weights client_weights = 1; + Weights server_weights = 2; + Metrics metrics = 3; + optional StateDict optimizer_state = 4; + optional Metrics diagnostic_metrics = 5; +} + +message TrainGlobalRequest { + int32 epochs = 1; + optional int32 round_no = 2; + optional StateDict optimizer_state = 3; + +} + +message TrainGlobalResponse { + Weights client_weights = 1; + Weights server_weights = 2; + Metrics metrics = 3; + optional StateDict optimizer_state = 4; + optional Metrics diagnostic_metrics = 5; +} + +message SetWeightsRequest { + Weights weights = 1; + bool on_client = 2; +} + +message SetWeightsResponse { + optional Metrics diagnostic_metrics = 1; +} + +message TrainEpochRequest { + DeviceInfo server = 1; + optional int32 round_no = 2; +} + +message TrainEpochResponse { + Weights weights = 1; + optional Metrics diagnostic_metrics = 2; +} + +message TrainBatchRequest { + Activations smashed_data = 1; + Labels labels = 2; +} + +message TrainBatchResponse { + Gradients gradients = 1; + optional Metrics diagnostic_metrics = 2; + optional double loss = 3; +} + +message EvalGlobalRequest { + bool validation = 1; + bool federated = 2; +} + +message EvalGlobalResponse { + Metrics metrics = 1; + optional Metrics diagnostic_metrics = 2; +} + +message EvalRequest { + DeviceInfo server = 1; + bool validation = 2; +} + +message EvalResponse { + optional Metrics diagnostic_metrics = 1; +} + +message EvalBatchRequest { + Activations smashed_data = 1; + Labels labels = 2; +} + +message EvalBatchResponse { + Metrics metrics = 1; + optional Metrics diagnostic_metrics = 2; +} + +message FullModelTrainRequest { + optional int32 round_no = 1; +} + +message FullModelTrainResponse { + Weights client_weights = 1; + Weights server_weights = 2; + int32 num_samples = 3; + Metrics metrics = 4; + optional Metrics diagnostic_metrics = 5; +} + +message StartExperimentRequest {} + 
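+// Illustrative note, not part of the service definition above: with the standard
+// gRPC Python tooling (and the generated DeviceStub used elsewhere in this change),
+// a controller would drive one split-learning round roughly as follows
+// (hypothetical sketch, names taken from edml.generated):
+//
+//   channel = grpc.insecure_channel("localhost:50051")
+//   stub = connection_pb2_grpc.DeviceStub(channel)
+//   response = stub.TrainGlobal(connection_pb2.TrainGlobalRequest(epochs=1, round_no=0))
+//   client_weights, server_weights = response.client_weights, response.server_weights
+//
+// The messages below define the payloads exchanged by these RPCs.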
+message StartExperimentResponse { + optional Metrics diagnostic_metrics = 1; +} + +message EndExperimentRequest {} + +message EndExperimentResponse { + optional Metrics diagnostic_metrics = 1; +} + +message BatteryStatusRequest {} + +message BatteryStatusResponse { + BatteryStatus status = 1; + optional Metrics diagnostic_metrics = 2; +} + +message DatasetModelInfoRequest {} + +message DatasetModelInfoResponse { + int32 train_samples = 1; + int32 validation_samples = 2; + int32 client_model_flops = 3; + int32 server_model_flops = 4; + optional Metrics diagnostic_metrics = 5; +} diff --git a/edml/proto/datastructures.proto b/edml/proto/datastructures.proto new file mode 100644 index 0000000000000000000000000000000000000000..fa7d74a3b8b8796f71c7c7a76875d10aa5fb0167 --- /dev/null +++ b/edml/proto/datastructures.proto @@ -0,0 +1,45 @@ +syntax = "proto3"; + +message Tensor { + bytes serialized = 1; +} + +message StateDict { + bytes serialized = 1; +} + +message Weights { + StateDict weights = 1; +} + +message Labels { + Tensor labels = 1; +} + +message Activations { + Tensor activations = 1; +} + +message Gradients { + Tensor gradients = 1; +} + +message Metrics { + bytes metrics = 1; +} + +message Predictions { + Tensor predictions = 1; +} + +message Empty {} + +message DeviceInfo { + string device_id = 1; + string address = 2; +} + +message BatteryStatus { + double initial_battery_level = 1; + double current_battery_level = 2; +} diff --git a/edml/tests/__init__.py b/edml/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/edml/tests/controllers/base_controller_test.py b/edml/tests/controllers/base_controller_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5238ae39c2e535ac22443d8727a13b0939821e84 --- /dev/null +++ b/edml/tests/controllers/base_controller_test.py @@ -0,0 +1,78 @@ +import unittest +from unittest.mock import Mock + +from edml.controllers.base_controller import BaseController +from edml.core.device import DeviceRequestDispatcher +from edml.tests.controllers.test_helper import load_sample_config + + +class TestBaseController(BaseController): + """Subclass of BaseController for testing purposes.""" + + def _train(self): + pass + + +class BaseControllerTest(unittest.TestCase): + + def setUp(self): + self.base_controller = TestBaseController(load_sample_config()) + self.mock = Mock(spec=DeviceRequestDispatcher) + self.base_controller.request_dispatcher = self.mock + + def test_devices_empty_or_only_server_left_with_server_only(self): + self.mock.active_devices.return_value = ["d0"] + self.base_controller._refresh_active_devices() + self.assertTrue(self.base_controller._devices_empty_or_only_server_left("d0")) + + def test_devices_empty_or_only_server_left_with_no_devices(self): + self.mock.active_devices.return_value = [] + self.base_controller._refresh_active_devices() + self.assertTrue(self.base_controller._devices_empty_or_only_server_left("d0")) + + def test_devices_empty_or_only_server_left_with_two_devices(self): + self.mock.active_devices.return_value = ["d0", "d1"] + self.base_controller._refresh_active_devices() + self.assertFalse(self.base_controller._devices_empty_or_only_server_left("d0")) + self.assertFalse(self.base_controller._devices_empty_or_only_server_left("d1")) + + def test_battery_status(self): + self.mock.get_battery_status_on.return_value = 42 + self.assertEqual( + self.base_controller._get_battery_status(), {"d0": 42, "d1": 42} + ) + + def 
test_battery_status_with_empty_device(self): + self.mock.get_battery_status_on.return_value = 42 + self.assertEqual( + self.base_controller._get_battery_status(), {"d0": 42, "d1": 42} + ) + + def battery_status_side_effect(*args, **kwargs): + if args[0] == "d0": + return 42 + elif args[0] == "d1": + return False + + self.mock.get_battery_status_on.side_effect = battery_status_side_effect + self.mock.active_devices.return_value = [ + "d0", + "d1", + ] # d1 is inactive, but still in active_devices + self.assertEqual( + self.base_controller._get_battery_status(), {"d0": 42, "d1": None} + ) + + def test_update_device_battery_status(self): + self.mock.get_battery_status_on.return_value = 42 + self.base_controller._update_devices_battery_status() + self.assertEqual( + self.base_controller.device_batteries_status, {"d0": 42, "d1": 42} + ) + + self.mock.active_devices.return_value = ["d0"] + self.base_controller._refresh_active_devices() + self.base_controller._update_devices_battery_status() + self.assertEqual( + self.base_controller.device_batteries_status, {"d0": 42, "d1": None} + ) diff --git a/edml/tests/controllers/early_stopping_test.py b/edml/tests/controllers/early_stopping_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c449ff9db7573dcc7842855a0be23948fc4caa --- /dev/null +++ b/edml/tests/controllers/early_stopping_test.py @@ -0,0 +1,90 @@ +import unittest + +from edml.controllers.early_stopping import EarlyStopping +from edml.helpers.metrics import ModelMetricResultContainer, ModelMetricResult + + +def create_metric_result_container(round_results): + containers = [] + for round_result in round_results: + containers += [ + ModelMetricResultContainer( + [ModelMetricResult(*result) for result in round_result] + ) + ] + return containers + + +class EarlyStoppingTest(unittest.TestCase): + def setUp(self): + self.early_stopping = EarlyStopping(patience=2, metric="acc") + + def test_no_stop(self): + round_results = [ + [("d1", "acc", "val", 0.1, 42)], + [("d1", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.3, 42)], + ] + expected_result = [False, False, False] + + results = self._call_early_stopping(round_results) + + self.assertEqual(results, expected_result) + + def test_no_stop_for_missing_metric(self): + round_results = [ + [("d1", "acc", "val", 0.2, 42)], + [("d1", "f1", "val", 0.1, 42)], + [("d1", "acc", "train", 0.0, 42)], + [("d1", "acc", "test", 0.0, 42)], + ] + expected_result = [False, False, False, False] + + results = self._call_early_stopping(round_results) + + self.assertEqual(results, expected_result) + + def test_no_stop_with_aggregation(self): + round_results = [ + [("d1", "acc", "val", 0.2, 42), ("d2", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.21, 42), ("d2", "acc", "val", 0.2, 42)], # better + [("d1", "acc", "val", 0.21, 42), ("d2", "acc", "val", 0.2, 42)], # same + [("d1", "acc", "val", 0.3, 42), ("d2", "acc", "val", 0.13, 42)], # better + [("d1", "acc", "val", 0.29, 42), ("d2", "acc", "val", 0.13, 42)], + ] # worse + expected_result = [False, False, False, False, False] + + results = self._call_early_stopping(round_results) + + self.assertEqual(results, expected_result) + + def test_stop_with_aggregation_no_improvement(self): + round_results = [ + [("d1", "acc", "val", 0.2, 42), ("d2", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.2, 42), ("d2", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.3, 42), ("d2", "acc", "val", 0.1, 42)], + ] + expected_result = [False, False, True] + + results = 
self._call_early_stopping(round_results) + + self.assertEqual(results, expected_result) + + def test_stop_with_aggregation_getting_worse(self): + round_results = [ + [("d1", "acc", "val", 0.2, 42), ("d2", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.1, 42), ("d2", "acc", "val", 0.2, 42)], + [("d1", "acc", "val", 0.1, 42), ("d2", "acc", "val", 0.1, 42)], + ] + expected_result = [False, False, True] + + results = self._call_early_stopping(round_results) + + self.assertEqual(results, expected_result) + + def _call_early_stopping(self, round_results): + metrics = create_metric_result_container(round_results) + results = [] + for idx, call in enumerate(metrics): + results.append(self.early_stopping(call, idx)) + return results diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a44090eefc7b1fb3b791cc749c516835e6d8cc67 --- /dev/null +++ b/edml/tests/controllers/fed_controller_test.py @@ -0,0 +1,238 @@ +import math +import unittest +from unittest.mock import Mock, patch, call + +from omegaconf import DictConfig +from torch import Tensor + +from edml.controllers.fed_controller import FedController, fed_average +from edml.core.device import DeviceRequestDispatcher +from edml.generated.connection_pb2 import FullModelTrainResponse +from edml.helpers.metrics import ( + ModelMetricResultContainer, + ModelMetricResult, + DiagnosticMetricResultContainer, +) +from edml.helpers.proto_helpers import weights_to_proto, metrics_to_proto +from edml.tests.controllers.test_helper import load_sample_config, get_side_effect + + +class FederatedLearningTest(unittest.TestCase): + + def test_fed_average_unweighted(self): + avg = fed_average([{"weights": Tensor([42])}, {"weights": Tensor([43])}]) + self.assertEqual(avg, {"weights": Tensor([42.5])}) + + def test_fed_average_weighted_by_examples(self): + avg = fed_average( + [{"weights": Tensor([42])}, {"weights": Tensor([43])}], [10, 30] + ) + self.assertEqual(avg, {"weights": Tensor([42.75])}) + + def test_fed_average_weighted_by_percentage(self): + avg = fed_average( + [{"weights": Tensor([42])}, {"weights": Tensor([43])}], [0.25, 0.75] + ) + self.assertEqual(avg, {"weights": Tensor([42.75])}) + + +class FedControllerTest(unittest.TestCase): + def setUp(self) -> None: + self.fed_controller = FedController(load_sample_config()) + self.mock = Mock(spec=DeviceRequestDispatcher) + self.mock.active_devices.return_value = ["d0", "d1"] + self.fed_controller.request_dispatcher = self.mock + + def test_fed_train_round(self): + def fed_train_side_effect(*args, **kwargs): + if args[0] == "d0": + return ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + 100, + ModelMetricResultContainer( + [ModelMetricResult("d0", "acc", "train", Tensor([42]), 100)] + ), + DiagnosticMetricResultContainer(), + ) + elif args[0] == "d1": + return ( + {"weights": Tensor([44])}, + {"weights": Tensor([45])}, + 100, + ModelMetricResultContainer( + [ModelMetricResult("d1", "acc", "train", Tensor([44]), 100)] + ), + DiagnosticMetricResultContainer(), + ) + + self.mock.federated_train_on.side_effect = fed_train_side_effect + + client_weights, server_weights, metrics = self.fed_controller._fed_train_round() + + self.assertEqual(client_weights, {"weights": Tensor([43])}) + self.assertEqual(server_weights, {"weights": Tensor([44])}) + self.assertEqual( + metrics, + ModelMetricResultContainer( + [ + ModelMetricResult("d0", "acc", "train", Tensor([42]), 100), + 
ModelMetricResult("d1", "acc", "train", Tensor([44]), 100), + ] + ), + ) + + def test_train(self): + self.mock.federated_train_on.return_value = ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + 100, + ModelMetricResultContainer( + [ModelMetricResult("d0", "acc", "train", Tensor([42]), 100)] + ), + DiagnosticMetricResultContainer(), + ) + + self.fed_controller._train() + + self.mock.federated_train_on.assert_has_calls( + [ + (("d0", 0), {}), + (("d1", 0), {}), + ] + ) + self.mock.active_devices.assert_called() + self.mock.set_weights_on.assert_has_calls( + [ + call("d0", {"weights": Tensor([42])}, True, wait_for_ready=False), + call("d1", {"weights": Tensor([42])}, True, wait_for_ready=False), + call("d0", {"weights": Tensor([43])}, False, wait_for_ready=False), + call("d1", {"weights": Tensor([43])}, False, wait_for_ready=False), + ] + ) + + def test_train_with_empty_batteries_after_train_round(self): + def side_effect(*args, **kwargs): + if self.mock.federated_train_on.call_count > 1: + self.mock.active_devices.return_value = [] + return ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + 100, + ModelMetricResultContainer( + [ModelMetricResult("d0", "acc", "train", Tensor([42]), 42)] + ), + DiagnosticMetricResultContainer(), + ) + + self.mock.federated_train_on.side_effect = side_effect + + self.fed_controller._train() + + self.mock.federated_train_on.assert_has_calls( + [ + (("d0", 0), {}), + (("d1", 0), {}), + ] + ) + self.mock.active_devices.assert_called() + self.mock.set_weights_on.assert_not_called() + + +class FedTrainRoundThreadingTest(unittest.TestCase): + + def setUp(self): + self.fed_controller = FedController( + DictConfig( + { + "topology": {"devices": []}, + "num_devices": 0, + "wandb": {"enabled": False}, + "experiment": {"early_stopping": False}, + } + ) + ) + self.request_dispatcher = DeviceRequestDispatcher([]) + + def test_fed_train_round(self): + """Test _fed_train_round with an increasing number of threads. Mocking the behavior of the FedControllers + request_dispatcher to let every second request fail. + Afterward, the assertions make sure that the correct weights and metrics are returned. 
+ Furthermore, the calls of fed_average are checked to be as expected.""" + for n in range(2, 20): + with self.subTest(method_name=f"Test fed train with {n} threads"): + with patch( + "edml.controllers.fed_controller.fed_average" + ) as fed_average_mock: + print(f"Test with {n} threads") + self.request_dispatcher.connections = { + f"d{i}": Mock(spec=["FullModelTraining"]) for i in range(n) + } + self.response = FullModelTrainResponse( + client_weights=weights_to_proto({"weights": Tensor([42])}), + server_weights=weights_to_proto({"weights": Tensor([43])}), + num_samples=100, + metrics=metrics_to_proto( + ModelMetricResultContainer( + [ + ModelMetricResult( + "d0", "acc", "train", Tensor([42]), 100 + ) + ] + ) + ), + diagnostic_metrics=metrics_to_proto( + DiagnosticMetricResultContainer() + ), + ) + for i, mock in enumerate( + self.request_dispatcher.connections.values() + ): + mock.FullModelTraining.side_effect = get_side_effect( + i, self.response + ) # let every second fail + self.fed_controller.request_dispatcher = self.request_dispatcher + self.fed_controller._refresh_active_devices() + + def mock_with_actual_behavior(*args, **kwargs): + """Mock with actual behavior to check if fed_average is called.""" + return fed_average(*args, **kwargs) + + fed_average_mock.side_effect = mock_with_actual_behavior + + client_weights, server_weights, metrics = ( + self.fed_controller._fed_train_round() + ) + + expected_valid_responses = math.ceil(n / 2) + self.assertEqual(client_weights, {"weights": Tensor([42])}) + self.assertEqual(server_weights, {"weights": Tensor([43])}) + self.assertEqual( + metrics.get_aggregated_metrics(), + ModelMetricResultContainer( + [ + ModelMetricResult( + "d0", + "acc", + "train", + Tensor([42]), + 100 * expected_valid_responses, + ) + ] + ).get_aggregated_metrics(), + ) + # Check that fed_average is called with the correct number of arguments + fed_average_mock.assert_has_calls( + [ + call( + model_weights=[{"weights": Tensor([42])}] + * expected_valid_responses, + weighting_scheme=[100] * expected_valid_responses, + ), + call( + model_weights=[{"weights": Tensor([43])}] + * expected_valid_responses, + weighting_scheme=[100] * expected_valid_responses, + ), + ] + ) diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8cc390c7a18e0d2b5c8e43822930973f32f953 --- /dev/null +++ b/edml/tests/controllers/optimization_test.py @@ -0,0 +1,452 @@ +import unittest + +from edml.controllers.strategy_optimization import ( + DeviceParams, + GlobalParams, + ServerChoiceOptimizer, + EnergySimulator, +) + + +class StrategyOptimizationTest(unittest.TestCase): + + def setUp(self): + self.device_params_list = [ + DeviceParams( + device_id="d0", + initial_battery=3000, + current_battery=10000, + train_samples=2, + validation_samples=1, + comp_latency_factor=1, + ), + DeviceParams( + device_id="d1", + initial_battery=2000, + current_battery=10000, + train_samples=3, + validation_samples=1, + comp_latency_factor=2, + ), + DeviceParams( + device_id="d2", + initial_battery=1000, + current_battery=10000, + train_samples=4, + validation_samples=1, + comp_latency_factor=0.5, + ), + ] + self.global_params = GlobalParams( + cost_per_sec=1, + cost_per_byte_sent=1, + cost_per_byte_received=1, + cost_per_flop=1, + client_model_flops=10, + server_model_flops=20, + smashed_data_size=50, + label_size=5, + gradient_size=50, + batch_size=1, + client_norm_fw_time=1, + client_norm_bw_time=2, 
+ server_norm_fw_time=2, + server_norm_bw_time=4, + client_weights_size=10, + server_weights_size=10, + optimizer_state_size=10, + ) + self.optimizer = ServerChoiceOptimizer( + self.device_params_list, self.global_params + ) + + def test_num_devices(self): + self.assertEqual(self.optimizer._num_devices(), 3) + + def test_total_battery(self): + self.assertEqual(self.optimizer._total_battery(), 30000) + + def test_train_dataset_size(self): + self.assertEqual(self.optimizer._total_train_dataset_size(), 9) + + def test_validation_dataset_size(self): + self.assertEqual(self.optimizer._total_validation_dataset_size(), 3) + + def test_get_device_params(self): + self.assertEqual(self.optimizer._get_device_params("d0").device_id, "d0") + self.assertEqual(self.optimizer._get_device_params("d1").device_id, "d1") + self.assertEqual(self.optimizer._get_device_params("d2").device_id, "d2") + + def test_transmission_latency_no_values(self): + self.assertEqual(self.optimizer._transmission_latency(), 0) + + def test_transmission_latency_with_values(self): + self.global_params.train_global_time = 100 + self.global_params.last_server_device_id = "d0" + self.assertEqual(self.optimizer._transmission_latency(), 6.5) + + def test_round_runtime_with_server_no_latency(self): + self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 93.5) + self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 153.5) + self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 63.5) + + def test_round_runtime_with_server_with_latency(self): + self.global_params.train_global_time = 100 + self.global_params.last_server_device_id = "d0" + self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 100) + self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 160) + self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 70) + + def test_num_flops_per_round_on_device(self): + self.assertEqual(self.optimizer.num_flops_per_round_on_device("d0", "d0"), 670) + self.assertEqual(self.optimizer.num_flops_per_round_on_device("d1", "d0"), 100) + self.assertEqual(self.optimizer.num_flops_per_round_on_device("d2", "d0"), 130) + + def test_num_bytes_sent_per_round_on_device(self): + self.assertEqual( + self.optimizer.num_bytes_sent_per_round_on_device("d0", "d0"), 390 + ) + self.assertEqual( + self.optimizer.num_bytes_sent_per_round_on_device("d1", "d0"), 230 + ) + self.assertEqual( + self.optimizer.num_bytes_sent_per_round_on_device("d2", "d0"), 285 + ) + + def test_num_bytes_received_per_round_on_device(self): + self.assertEqual( + self.optimizer.num_bytes_received_per_round_on_device("d0", "d0"), 535 + ) + self.assertEqual( + self.optimizer.num_bytes_received_per_round_on_device("d1", "d0"), 160 + ) + self.assertEqual( + self.optimizer.num_bytes_received_per_round_on_device("d2", "d0"), 210 + ) + + def test_energy_per_round_on_device(self): + self.assertEqual(self.optimizer.energy_per_round_on_device("d0", "d0"), 1688.5) + self.assertEqual(self.optimizer.energy_per_round_on_device("d1", "d0"), 583.5) + self.assertEqual(self.optimizer.energy_per_round_on_device("d2", "d0"), 718.5) + + def test_optimize(self): + solution, status = self.optimizer.optimize() + self.assertEqual(solution, {"d0": 4.0, "d1": 3.0, "d2": 3.0}) + self.assertEqual(status, 0) + + +class EnergySimulatorTest(unittest.TestCase): + def setUp(self): + self.device_params_list = [ + DeviceParams( + device_id="d0", + initial_battery=3000, + current_battery=10000, + train_samples=2, + validation_samples=1, + comp_latency_factor=1, + 
), + DeviceParams( + device_id="d1", + initial_battery=2000, + current_battery=10000, + train_samples=3, + validation_samples=1, + comp_latency_factor=2, + ), + DeviceParams( + device_id="d2", + initial_battery=1000, + current_battery=10000, + train_samples=4, + validation_samples=1, + comp_latency_factor=0.5, + ), + ] + self.global_params = GlobalParams( + cost_per_sec=1, + cost_per_byte_sent=1, + cost_per_byte_received=1, + cost_per_flop=1, + client_model_flops=10, + server_model_flops=20, + smashed_data_size=50, + label_size=5, + gradient_size=50, + batch_size=1, + client_norm_fw_time=1, + client_norm_bw_time=2, + server_norm_fw_time=2, + server_norm_bw_time=4, + client_weights_size=10, + server_weights_size=10, + optimizer_state_size=10, + ) + self.simulator = EnergySimulator(self.device_params_list, self.global_params) + + def test_simulate_greedy_selection(self): + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + self.assertEqual(num_rounds, 10) + self.assertEqual( + schedule, ["d0", "d1", "d2", "d0", "d1", "d2", "d0", "d1", "d0", "d2"] + ) + self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0]) + + def test_simulate_smart_selection(self): + num_rounds, solution, remaining_batteries = ( + self.simulator.simulate_smart_selection() + ) + self.assertEqual(num_rounds, 10) + self.assertEqual(solution, {"d0": 4.0, "d1": 3.0, "d2": 3.0}) + self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0]) + + def test_fl_round_time(self): + self.assertEqual(self.simulator._fl_round_time(), 60.0) + + def test_fl_flops_on_device(self): + self.assertEqual(self.simulator._fl_flops_on_device("d0"), 210) + + def test_fl_data_sent_per_device(self): + self.assertEqual(self.simulator._fl_data_sent_per_device(), 20) + + def test_fl_data_received_per_device(self): + self.assertEqual(self.simulator._fl_data_received_per_device(), 20) + + def test_fl_energy_per_round_on_device(self): + self.assertEqual(self.simulator._fl_energy_per_round_on_device("d0"), 310) + + def test_simulate_federated_learning(self): + num_rounds, remaining_batteries = self.simulator.simulate_federated_learning() + self.assertEqual(num_rounds, 20) + self.assertEqual(remaining_batteries, [3800, 2000, 200]) + + +class TestWithRealData(unittest.TestCase): + """Case studies for estimating the number of rounds of each strategy with given energy constratins.""" + + def setUp(self): + self.train_samples = [9600, 9600, 9600, 9600, 9600] + self.validation_samples = [2400, 2400, 2400, 2400, 2400] + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + self.comp_latency = [1, 1, 1, 1, 1] + self.cost_per_sec = 1 + self.cost_per_mbyte_sent = 1 + self.cost_per_mbyte_received = 1 + self.cost_per_mflop = 1 + + def _init_params_and_optimizer(self): + self.device_params_list = [ + DeviceParams( + device_id="d0", + initial_battery=3750, + current_battery=self.current_batteries[0], + train_samples=self.train_samples[0], + validation_samples=self.validation_samples[0], + comp_latency_factor=self.comp_latency[0], + ), + DeviceParams( + device_id="d1", + initial_battery=3750, + current_battery=self.current_batteries[1], + train_samples=self.train_samples[1], + validation_samples=self.validation_samples[1], + comp_latency_factor=self.comp_latency[1], + ), + DeviceParams( + device_id="d2", + initial_battery=3750, + current_battery=self.current_batteries[2], + train_samples=self.train_samples[2], + validation_samples=self.validation_samples[2], + comp_latency_factor=self.comp_latency[2], + ), + 
DeviceParams( + device_id="d3", + initial_battery=3750, + current_battery=self.current_batteries[3], + train_samples=self.train_samples[3], + validation_samples=self.validation_samples[3], + comp_latency_factor=self.comp_latency[3], + ), + DeviceParams( + device_id="d4", + initial_battery=3750, + current_battery=self.current_batteries[4], + train_samples=self.train_samples[4], + validation_samples=self.validation_samples[4], + comp_latency_factor=self.comp_latency[4], + ), + ] + self.global_params = GlobalParams( + cost_per_sec=self.cost_per_sec, + cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000, + cost_per_byte_received=self.cost_per_mbyte_received / 1000000, + cost_per_flop=self.cost_per_mflop / 1000000, + client_model_flops=5405760, + server_model_flops=11215800, + smashed_data_size=36871, + label_size=14, + gradient_size=36871, + batch_size=64, + client_norm_fw_time=0.0001318198063394479, + client_norm_bw_time=1.503657614975644e-05, + server_norm_fw_time=2.1353501238321005e-05, + server_norm_bw_time=3.1509113154913254e-05, + client_weights_size=71678, + # train global response size 15878758 + server_weights_size=15807080, + optimizer_state_size=0, # was not recorded then + ) + self.optimizer = ServerChoiceOptimizer( + self.device_params_list, self.global_params + ) + self.simulator = EnergySimulator(self.device_params_list, self.global_params) + + def test_with_actual_data(self): + self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294] + self.cost_per_sec = 1 # 0.1 + self.cost_per_mbyte_sent = 0.05 # 0.002 + self.cost_per_mbyte_received = 0.05 # 0.0005 + self.cost_per_mflop = 0.00025 # 0.0005 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + self.assertEqual( + solution, {"d0": 4.0, "d1": 4.0, "d2": 4.0, "d3": 3.0, "d4": 0.0} + ) + + def test_simulate_greedy_selection_with_actual_data(self): + self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294] + self.cost_per_sec = 1 # 0.1 + self.cost_per_mbyte_sent = 0.05 # 0.002 + self.cost_per_mbyte_received = 0.05 # 0.0005 + self.cost_per_mflop = 0.00025 # 0.0005 + self._init_params_and_optimizer() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + self.assertEqual(num_rounds, 15) + self.assertEqual( + schedule, + [ + "d0", + "d1", + "d2", + "d3", + "d4", + "d0", + "d1", + "d2", + "d3", + "d4", + "d0", + "d1", + "d2", + "d3", + "d4", + ], + ) + + def test_unequal_split(self): + self.current_batteries = [5750, 4750, 3750, 2750, 1750] + self.train_samples = [4800, 9600, 19200, 7200, 7200] + self.validation_samples = [1200, 2400, 4800, 1800, 1800] + self.cost_per_sec = 1 + self.cost_per_mbyte_sent = 0.05 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.00025 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + self.assertEqual( + solution, {"d0": 8.0, "d1": 5.0, "d2": 1.0, "d3": 2.0, "d4": 0.0} + ) + + def test_unequal_split_and_batteries_with_high_communication_cost(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.1, 0.1, 0.6, 0.1, 0.1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 0.1 + self.cost_per_mbyte_sent = 0.1 + self.cost_per_mbyte_received = 0.1 + self.cost_per_mflop = 0.00001 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, 
remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost(self): + self.current_batteries = [5750, 4750, 3750, 2750, 1750] + data_partitions = [0.2, 0.1, 0.1, 0.1, 0.5] + self.comp_latency = [5, 5, 5, 5, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 10 + self.cost_per_mbyte_sent = 0.0 # 1 + self.cost_per_mbyte_received = 0.0 # 1 + self.cost_per_mflop = 0.0000 # 5 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost2(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.05, 0.1, 0.1, 0.1, 0.65] + self.comp_latency = [10, 10, 10, 10, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 3 + self.cost_per_mbyte_sent = 0.05 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.000005 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost3(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.01, 0.01, 0.01, 0.01, 0.96] + self.comp_latency = [100, 100, 100, 100, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 3 + self.cost_per_mbyte_sent = 0.05 # 375 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.000005 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost4(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.1, 0.1, 0.1, 0.1, 0.6] + self.comp_latency = [10, 10, 10, 10, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 4 + self.cost_per_mbyte_sent = 0.01 + self.cost_per_mbyte_received = 0.01 + self.cost_per_mflop = 0.00001 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) diff --git a/edml/tests/controllers/sample_config.yaml 
b/edml/tests/controllers/sample_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..505bdae73da5d1fd3901b53d94ae2316cb6ae560 --- /dev/null +++ b/edml/tests/controllers/sample_config.yaml @@ -0,0 +1,26 @@ +topology: + devices: [ + { + device_id: "d0", + address: "localhost:50051" + }, + { + device_id: "d1", + address: "localhost:50052" + } + ] +num_devices: 2 +experiment: + max_rounds: 1 + save_weights: False + early_stopping: True + early_stopping_patience: 2 + early_stopping_metric: "accuracy" + batch_size: 1 +wandb: + enabled: False +battery: + deduction_per_second: 1 + deduction_per_mbyte_sent: 1 + deduction_per_mbyte_received: 1 + deduction_per_mflop: 1 diff --git a/edml/tests/controllers/scheduler/max_battery_test.py b/edml/tests/controllers/scheduler/max_battery_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ed808d5d1a033839c5fed9dfad1a0b33f541831f --- /dev/null +++ b/edml/tests/controllers/scheduler/max_battery_test.py @@ -0,0 +1,41 @@ +import unittest + +from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler +from edml.helpers.types import DeviceBatteryStatus + + +class MaxBatteryServerDeviceSelectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.active_devices = ["d0", "d1", "d2"] + self.scheduler = MaxBatteryNextServerScheduler() + + def test_select_first_device_for_equal_batteries(self): + self.scheduler._update_batteries_cb = lambda: { + "d0": DeviceBatteryStatus.from_tuple((100, 50)), + "d1": DeviceBatteryStatus.from_tuple((100, 50)), + "d2": DeviceBatteryStatus.from_tuple((100, 50)), + } + + server_device = self.scheduler.next_server(self.active_devices) + self.assertEqual(server_device, "d0") + + def test_select_device_with_max_battery(self): + self.scheduler._update_batteries_cb = lambda: { + "d0": None, + "d1": DeviceBatteryStatus.from_tuple((100, 50)), + "d2": DeviceBatteryStatus.from_tuple((100, 100)), + } + + server_device = self.scheduler.next_server(self.active_devices) + self.assertEqual(server_device, "d2") + + def test_no_server_for_no_active_device(self): + self.scheduler._update_batteries_cb = lambda: { + "d0": None, + "d1": None, + "d2": None, + } + + server_device = self.scheduler.next_server(self.active_devices) + self.assertEqual(server_device, None) diff --git a/edml/tests/controllers/scheduler/sequential_test.py b/edml/tests/controllers/scheduler/sequential_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0029292068599179121532791082e34ab70d5f02 --- /dev/null +++ b/edml/tests/controllers/scheduler/sequential_test.py @@ -0,0 +1,63 @@ +import unittest + +from edml.controllers.scheduler.sequential import SequentialNextServerScheduler + + +class SequentialServerDeviceSelectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.active_devices = ["d0", "d1", "d2"] + self.scheduler = SequentialNextServerScheduler() + self.scheduler.devices = self.active_devices + + def test_select_server_device_for_only_active_devices(self): + server_device = self.scheduler.next_server(self.active_devices) + self.assertEqual(server_device, "d0") + print("=1="), + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d1") + print("=2=") + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d2") + print("=3=") + + server_device = self.scheduler.next_server( + 
self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d0") + print("=4=") + + def test_select_server_device_with_last_server_device_inactive(self): + self.active_devices = ["d1", "d2"] + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id="d0" + ) + self.assertEqual(server_device, "d1") + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d2") + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d1") + + def test_select_same_server_device_if_all_other_devices_inactive(self): + self.active_devices = ["d1"] + + server_device = self.scheduler.next_server(self.active_devices) + self.assertEqual(server_device, "d1") + + server_device = self.scheduler.next_server( + self.active_devices, last_server_device_id=server_device + ) + self.assertEqual(server_device, "d1") diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb7e758073a82ff965eaf414fc37946164d2541 --- /dev/null +++ b/edml/tests/controllers/scheduler/smart_test.py @@ -0,0 +1,73 @@ +import unittest +from unittest.mock import Mock, patch + +from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler +from edml.controllers.scheduler.smart import SmartNextServerScheduler +from edml.helpers.metrics import DiagnosticMetricResultContainer +from edml.helpers.types import DeviceBatteryStatus +from edml.tests.controllers.test_helper import load_sample_config + + +class SmartServerDeviceSelectionTest(unittest.TestCase): + + def setUp(self): + self.diagnostic_metric_container = Mock(spec=DiagnosticMetricResultContainer) + self.metrics = { + "gradient_size": 100000, + "label_size": 2000, + "smashed_data_size": 100000, + "client_weight_size": 300000, + "server_weight_size": 300000, + "optimizer_state_size": 300000, + "client_norm_fw_time": 3, + "client_norm_bw_time": 3, + "server_norm_fw_time": 3, + "server_norm_bw_time": 3, + "comp_latency_factor": {"d0": 1, "d1": 1.01}, + } + self.scheduler = SmartNextServerScheduler( + fallback_scheduler=Mock(spec=MaxBatteryNextServerScheduler), + ) + self.scheduler.cfg = load_sample_config() + self.scheduler._update_batteries_cb = lambda: { + "d0": DeviceBatteryStatus.from_tuple((500, 500)), + "d1": DeviceBatteryStatus.from_tuple((500, 470)), + } + + self.scheduler._data_model_cb = lambda: [ + { + "d0": (2, 1), + "d1": (4, 1), + }, + {"client": 1000000, "server": 1000000}, + ] + + def test_select_server_device_smart_first_round(self): + # should select according to max battery + self.scheduler.next_server( + [""], diagnostic_metric_container=None, last_server_device_id=None + ) + self.scheduler.fallback_scheduler.next_server.assert_called_once() + + def test_select_server_device_smart_second_round(self): + # should select build a schedule and select the first element + with patch( + "edml.controllers.scheduler.smart.compute_metrics_for_optimization", + return_value=self.metrics, + ): + server_device = self.scheduler.next_server( + ["d0", "d1"], + diagnostic_metric_container=self.diagnostic_metric_container, + last_server_device_id=None, + ) + self.assertEqual(server_device, "d0") + + def test_get_selection_schedule(self): + with patch( + "edml.controllers.scheduler.smart.compute_metrics_for_optimization", + 
return_value=self.metrics, + ): + schedule = self.scheduler._get_selection_schedule( + self.diagnostic_metric_container, + ) + self.assertEqual(schedule, ["d0", "d1", "d1", "d1"]) diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8531a2d1fe0aa146494d1ffb45317739a1c01dc1 --- /dev/null +++ b/edml/tests/controllers/split_controller_test.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import Mock + +from edml.controllers.split_controller import SplitController +from edml.core.device import DeviceRequestDispatcher +from edml.helpers.metrics import ( + ModelMetricResultContainer, + DiagnosticMetricResultContainer, +) +from edml.tests.controllers.test_helper import load_sample_config + + +class SplitControllerTest(unittest.TestCase): + def setUp(self) -> None: + self.cfg = load_sample_config() + self.split_controller = SplitController(self.cfg) + self.mock = Mock(spec=DeviceRequestDispatcher) + self.mock.active_devices.return_value = ["d0", "d1"] + self.split_controller.request_dispatcher = self.mock + + def test_train(self): + self.mock.train_global_on.return_value = ( + {"weights": 42}, + {"weights": 43}, + ModelMetricResultContainer(), + {"optimizer_state": 44}, + DiagnosticMetricResultContainer(), + ) + + self.split_controller._train() + + self.mock.set_weights_on.assert_called_once_with( + device_id="d0", state_dict=None, on_client=True + ) + self.mock.train_global_on.assert_called_once_with( + self.cfg.topology.devices[0].device_id, epochs=1, round_no=0 + ) + + def test_train_with_inactive_server_device(self): + self.mock.train_global_on.return_value = False + self.mock.active_devices.return_value = ["d1"] + + self.split_controller._train() + + # check that first device in active devices is called with weights + self.mock.set_weights_on.assert_called_once_with( + device_id="d1", state_dict=None, on_client=True + ) + self.mock.train_global_on.assert_called_once_with( + self.cfg.topology.devices[0].device_id, epochs=1, round_no=0 + ) diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c025f4b78a2a648b6055751cf5875b4b7224494 --- /dev/null +++ b/edml/tests/controllers/swarm_controller_test.py @@ -0,0 +1,105 @@ +import unittest +from unittest.mock import Mock, call + +from omegaconf import ListConfig + +from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler +from edml.controllers.scheduler.random import RandomNextServerScheduler +from edml.controllers.scheduler.sequential import SequentialNextServerScheduler +from edml.controllers.scheduler.smart import SmartNextServerScheduler +from edml.controllers.swarm_controller import SwarmController +from edml.core.device import DeviceRequestDispatcher +from edml.helpers.metrics import ( + ModelMetricResultContainer, + DiagnosticMetricResultContainer, +) +from edml.tests.controllers.test_helper import load_sample_config + + +class SwarmControllerTest(unittest.TestCase): + + def setUp(self) -> None: + self.cfg = load_sample_config() + self.swarm_controller = SwarmController( + self.cfg, SequentialNextServerScheduler() + ) + self.mock = Mock(spec=DeviceRequestDispatcher) + self.mock.active_devices.return_value = ["d0", "d1"] + self.swarm_controller.request_dispatcher = self.mock + + def test_split_train_round(self): + self.mock.train_global_on.return_value = ( + 
{"weights": 42}, + {"weights": 43}, + ModelMetricResultContainer(), + {"optimizer_state": 42}, + DiagnosticMetricResultContainer(), + ) + + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.swarm_controller._swarm_train_round(None, None, "d1", 0) + ) + + self.assertEqual(client_weights, {"weights": 42}) + self.assertEqual(server_weights, {"weights": 43}) + self.assertEqual(metrics, ModelMetricResultContainer()) + self.assertEqual(diagnostic_metrics, DiagnosticMetricResultContainer()) + self.assertEqual(optimizer_state, {"optimizer_state": 42}) + self.mock.set_weights_on.assert_has_calls( + [ + call(device_id="d0", state_dict=None, on_client=True), + call(device_id="d1", state_dict=None, on_client=False), + ] + ) + self.mock.train_global_on.assert_called_once_with( + "d1", epochs=1, round_no=0, optimizer_state=None + ) + + def test_split_train_round_with_inactive_server_device(self): + self.mock.train_global_on.return_value = False + + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.swarm_controller._swarm_train_round(None, None, "d1", 0) + ) + + self.assertEqual(client_weights, None) + self.assertEqual(server_weights, None) + self.assertEqual(metrics, None) + self.assertEqual(diagnostic_metrics, None) + self.assertEqual(optimizer_state, None) + self.mock.set_weights_on.assert_has_calls( + [ + call(device_id="d0", state_dict=None, on_client=True), + call(device_id="d1", state_dict=None, on_client=False), + ] + ) + self.mock.train_global_on.assert_called_once_with( + "d1", epochs=1, round_no=0, optimizer_state=None + ) + + +class ServerDeviceSelectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.swarm_controller = SwarmController( + load_sample_config(), Mock(SequentialNextServerScheduler) + ) + self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher) + self.swarm_controller.devices = ListConfig( + [{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}] + ) # omitted address etc. 
+ + def test_select_no_server_device_if_no_active_devices(self): + self.swarm_controller.request_dispatcher.active_devices.return_value = [] + self.swarm_controller._refresh_active_devices() + + server_device = self.swarm_controller._select_server_device(None) + self.assertEqual(server_device, None) + + server_device = self.swarm_controller._select_server_device("d1") + self.assertEqual(server_device, None) + + def test_sequential_selection_default(self): + self.swarm_controller._select_server_device() + + self.swarm_controller._next_server_scheduler.next_server.assert_called() diff --git a/edml/tests/controllers/test_controller_test.py b/edml/tests/controllers/test_controller_test.py new file mode 100644 index 0000000000000000000000000000000000000000..848071a068d6418f6ac44303f83ada770755bbe9 --- /dev/null +++ b/edml/tests/controllers/test_controller_test.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import patch, Mock, call + +from edml.controllers.test_controller import TestController +from edml.core.device import DeviceRequestDispatcher +from edml.tests.controllers.test_helper import load_sample_config + + +class TestControllerTest(unittest.TestCase): + + def test_run_with_test_data(self): + cfg = load_sample_config() + cfg.best_round = 42 + test_controller = TestController(cfg) + mock = Mock(spec=DeviceRequestDispatcher) + mock.active_devices.return_value = ["d0", "d1"] + test_controller.request_dispatcher = mock + + with patch( + "edml.controllers.base_controller.BaseController._load_weights" + ) as mock_load_weights: + mock_load_weights.return_value = {"weights": 42}, {"weights": 43} + + test_controller._train() + + mock_load_weights.assert_called_once_with(42) + mock.evaluate_global_on.assert_has_calls( + [call("d0", False, True), call("d1", False, True)] + ) diff --git a/edml/tests/controllers/test_helper.py b/edml/tests/controllers/test_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..af922e79c4391310fbe97e2251a84b0a1d6ad138 --- /dev/null +++ b/edml/tests/controllers/test_helper.py @@ -0,0 +1,84 @@ +import os.path +import time +from unittest.mock import Mock + +import grpc +from omegaconf import OmegaConf +from torch import Tensor + +from edml.generated import connection_pb2 +from edml.generated.connection_pb2_grpc import DeviceStub +from edml.helpers import proto_helpers + + +def mock_establish_device_connections(): + """Mocks device connections for testing. + Sets return values for some methods offered by the DeviceStub class. 
+ """ + connections = { + "d0": Mock(spec=DeviceStub(Mock(grpc.Channel))), + "d1": Mock(spec=DeviceStub(Mock(grpc.Channel))), + } + client0_weights = proto_helpers.weights_to_proto({"weights": Tensor([42])}) + server0_weights = proto_helpers.weights_to_proto({"weights": Tensor([43])}) + client1_weights = proto_helpers.weights_to_proto({"weights": Tensor([44])}) + server1_weights = proto_helpers.weights_to_proto({"weights": Tensor([45])}) + + connections["d0"].FullModelTraining.return_value = ( + connection_pb2.FullModelTrainResponse( + client_weights=client0_weights, + server_weights=server0_weights, + ) + ) + connections["d1"].FullModelTraining.return_value = ( + connection_pb2.FullModelTrainResponse( + client_weights=client1_weights, + server_weights=server1_weights, + ) + ) + + connections["d0"].Evaluate.return_value = connection_pb2.EvalResponse() + connections["d1"].Evaluate.return_value = connection_pb2.EvalResponse() + + connections["d0"].TrainGlobal.return_value = connection_pb2.TrainGlobalResponse( + client_weights=client0_weights, + server_weights=server0_weights, + ) + connections["d1"].TrainGlobal.return_value = connection_pb2.TrainGlobalResponse( + client_weights=client1_weights, + server_weights=server1_weights, + ) + + connections["d0"].SetWeights.return_value = connection_pb2.SetWeightsResponse() + connections["d1"].SetWeights.return_value = connection_pb2.SetWeightsResponse() + + return connections + + +def load_sample_config(): + return OmegaConf.load(os.path.join(os.path.dirname(__file__), "sample_config.yaml")) + + +def assert_experiment_started_and_ended(test_case): + test_case.connections["d0"].StartExperiment.assert_called_once() + test_case.connections["d1"].StartExperiment.assert_called_once() + test_case.connections["d0"].EndExperiment.assert_called_once() + test_case.connections["d1"].EndExperiment.assert_called_once() + + +def get_side_effect(i, valid_response): + """Constructs a side effect based on integer input. 
+ If even, the side effect returns a valid response, else raises a grpc error""" + if i % 2 == 0: + + def side_effect(*args, **kwargs): + time.sleep(0.2) + return valid_response + + else: + + def side_effect(*args, **kwargs): + time.sleep(0.2) + raise grpc.RpcError() + + return side_effect diff --git a/edml/tests/core/battery_test.py b/edml/tests/core/battery_test.py new file mode 100644 index 0000000000000000000000000000000000000000..64c2be164bf205bb45d238dd9e12ce4f8001eb9e --- /dev/null +++ b/edml/tests/core/battery_test.py @@ -0,0 +1,128 @@ +import threading +import time +import unittest +from unittest.mock import patch + +from edml.core.battery import Battery + + +class BatteryTest(unittest.TestCase): + def setUp(self): + self.battery = Battery(1000, 1, 0.1, 0.2, 0.3) + self.battery.start_experiment() + + def test_update_flops(self): + self.battery.update_flops(420000000) + self.assertEqual(self.battery.remaining_capacity(), 958) + + def test_update_time_after_one_second(self): + self.battery.__last_update__ = 0.0 + with patch("time.time", side_effect=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]): + for i in range(0, 6): + self.battery.update_time() + self.assertEqual(self.battery.remaining_capacity(), 995) + self.assertEqual(self.battery.initial_capacity(), 1000) + + def test_update_time_immediately(self): + self.battery.update_time() + self.assertAlmostEqual(self.battery.remaining_capacity(), 1000, delta=0.1) + + def test_update_communication_received(self): + self.battery.update_communication_received(10000000) + self.assertEqual(self.battery.remaining_capacity(), 998) + + def test_update_communication_sent(self): + self.battery.update_communication_sent(10000000) + self.assertEqual(self.battery.remaining_capacity(), 997) + + def test_exception_when_empty_after_flop_update(self): + self.battery.__capacity__ = 0.00000001 + self.assertFalse(self.battery.is_empty()) + with self.assertRaises(Exception): + self.battery.update_flops(1000) + self.assertTrue(self.battery.is_empty()) + + def test_exception_when_empty_after_time_update(self): + self.battery.__capacity__ = 0.00000001 + self.assertFalse(self.battery.is_empty()) + with self.assertRaises(Exception): + self.battery.update_time() + self.assertTrue(self.battery.is_empty()) + + +class BatteryTestNotUpdatingWhenExperimentNotStarted(unittest.TestCase): + def setUp(self): + self.battery = Battery(1000, 1, 0.1, 0.2, 0.3) + + def test_not_updating_flops(self): + self.battery.update_flops(420000000) + self.assertEqual(self.battery.remaining_capacity(), 1000) + + def test_not_updating_time(self): + self.battery.update_time() + self.assertEqual(self.battery.remaining_capacity(), 1000) + + def test_not_updating_communication_received(self): + self.battery.update_communication_received(10000000) + self.assertEqual(self.battery.remaining_capacity(), 1000) + + def test_not_updating_communication_sent(self): + self.battery.update_communication_sent(10000000) + self.assertEqual(self.battery.remaining_capacity(), 1000) + + +class BatteryMultiThreadedTest(unittest.TestCase): + def setUp(self): + self.battery = Battery(1000, 1, 0.1, 0.1, 0.1) + self.battery.start_experiment() + + def bulk_update_flops(): + """with 0.1 deduction per MFLOP it should decrease capacity by 499.995""" + for i in range(100000): + self.battery.update_flops(i) + + def bulk_update_communication_received(): + """with 0.1 deduction per mbyte it should decrease capacity by 499.995""" + for i in range(100000): + self.battery.update_communication_received(i) + + def 
bulk_update_communication_sent(): + """with 0.1 deduction per mbyte it should decrease capacity by 499.995""" + for i in range(100000): + self.battery.update_communication_sent(i) + + self.bulk_update_flops = bulk_update_flops + self.bulk_update_communication_received = bulk_update_communication_received + self.bulk_update_communication_sent = bulk_update_communication_sent + + def test_multi_threaded_flops(self): + t1 = threading.Thread(target=self.bulk_update_flops) + t2 = threading.Thread(target=self.bulk_update_flops) + t1.start() + t2.start() + t1.join() + t2.join() + self.assertAlmostEqual( + self.battery.remaining_capacity(), 0.01, delta=1e-12 + ) # delta due to FP precision + + def test_multi_threaded_time_and_flops(self): + start_time = time.time() + self.battery.__last_update__ = start_time + with patch("time.time", return_value=start_time + 42): + t1 = threading.Thread(target=self.bulk_update_flops) + t2 = threading.Thread(target=self.battery.update_time) + t1.start() + t2.start() + t1.join() + t2.join() + self.assertAlmostEqual(self.battery.remaining_capacity(), 458.005, delta=1e-12) + + def test_multi_threaded_communication(self): + t1 = threading.Thread(target=self.bulk_update_communication_received) + t2 = threading.Thread(target=self.bulk_update_communication_sent) + t1.start() + t2.start() + t1.join() + t2.join() + self.assertAlmostEqual(self.battery.remaining_capacity(), 0.01, delta=1e-12) diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3827df9c727254bb13482e22bef6899bb2ae2c47 --- /dev/null +++ b/edml/tests/core/device_test.py @@ -0,0 +1,937 @@ +import concurrent.futures +import math +import threading +import unittest +from inspect import signature +from unittest.mock import Mock, patch + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time +from omegaconf import DictConfig +from torch import Tensor + +from edml.core.battery import Battery, BatteryEmptyException +from edml.core.client import DeviceClient +from edml.core.device import RPCDeviceServicer, NetworkDevice, DeviceRequestDispatcher +from edml.core.server import DeviceServer +from edml.generated import connection_pb2, datastructures_pb2 +from edml.generated.connection_pb2 import ( + EvalResponse, + EvalBatchResponse, + SetWeightsResponse, + TrainEpochResponse, +) +from edml.generated.datastructures_pb2 import ( + Activations, + Labels, + Weights, + DeviceInfo, + BatteryStatus, +) +from edml.helpers.metrics import ( + ModelMetricResult, + ModelMetricResultContainer, + DiagnosticMetricResultContainer, + DiagnosticMetricResult, +) +from edml.helpers.proto_helpers import ( + tensor_to_proto, + proto_to_tensor, + state_dict_to_proto, + proto_to_state_dict, + proto_to_weights, + weights_to_proto, + gradients_to_proto, + metrics_to_proto, + proto_to_metrics, +) +from edml.helpers.types import DeviceBatteryStatus +from edml.tests.controllers.test_helper import get_side_effect + + +class NetworkDeviceTest(unittest.TestCase): + + def setUp(self) -> None: + self.device = NetworkDevice("0", None, Mock(Battery)) + self.device.set_client(Mock(DeviceClient)) + self.device.set_server(Mock(DeviceServer)) + + def test_get_device_ids(self): + self.device.devices = [ + DictConfig({"device_id": "0", "address": "42"}), + DictConfig({"device_id": "1", "address": "43"}), + ] + self.assertEqual(self.device.__get_device_ids__(), ["0", "1"]) + + def test_train_epoch_on_same_device(self): + 
self.device.train_epoch_on("0", "0", 0) + self.device.client.train_epoch.assert_called_once_with("0", round_no=0) + + def test_train_epoch_on_different_device(self): + self.device.train_epoch_on("0", "1") + self.device.server.train.assert_not_called() + + def test_rpc_call_for_train_epoch_on_different_device(self): + self.device.request_dispatcher.connections["1"] = Mock(spec=["TrainEpoch"]) + self.device.request_dispatcher.connections["1"].TrainEpoch.return_value = ( + TrainEpochResponse( + weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})) + ) + ) + + weights, diagnostic_metrics = self.device.train_epoch_on("1", "0", 0) + + self.assertEqual(weights, {"weights": Tensor([42])}) + self.assertIsNotNone(diagnostic_metrics) + self.device.request_dispatcher.connections[ + "1" + ].TrainEpoch.assert_called_once_with( + connection_pb2.TrainEpochRequest( + server=DeviceInfo(device_id="0", address=""), round_no=0 + ) + ) + + +class RPCDeviceServicerTest(unittest.TestCase): + + def setUp(self) -> None: + # instantiating NetworkDevice to mock Device(ABC)'s properties + # store mock_device reference for later function call assertions + self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery))) + self.mock_device.device_id = 42 + + my_servicer = RPCDeviceServicer(device=self.mock_device) + servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} + self.test_server = server_from_dictionary(servicers, strict_real_time()) + + self.metrics = ModelMetricResultContainer( + [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] + ) + self.diagnostic_metrics = DiagnosticMetricResultContainer( + [DiagnosticMetricResult("d1", "comp_time", "train", 42)] + ) + + def make_call(self, method_name, request): + method = self.test_server.invoke_unary_unary( + method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ + "Device" + ].methods_by_name[method_name], + invocation_metadata={}, + request=request, + timeout=None, + ) + return method.termination() + + def test_train_global(self): + self.mock_device.train_global.return_value = ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + self.metrics, + {"optimizer_state": 44}, + self.diagnostic_metrics, + ) + request = connection_pb2.TrainGlobalRequest(epochs=42) + + response, metadata, code, details = self.make_call("TrainGlobal", request) + + self.assertEqual( + ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + ), + ({"weights": Tensor([42])}, {"weights": Tensor([43])}), + ) + self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual( + proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44} + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_global.assert_called_once_with(42, 0) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_set_weights(self): + request = connection_pb2.SetWeightsRequest( + weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})), + on_client=True, + ) + + response, metadata, code, details = self.make_call("SetWeights", request) + + self.assertEqual(type(response), type(SetWeightsResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.set_weights.assert_called_once_with( + {"weights": Tensor([42])}, True + ) + + def test_train_epoch(self): + self.mock_device.train_epoch.return_value = { + "weights": Tensor([42]) + }, self.diagnostic_metrics + request = connection_pb2.TrainEpochRequest( + 
server=datastructures_pb2.DeviceInfo(device_id="42", address="") + ) + + response, metadata, code, details = self.make_call("TrainEpoch", request) + + self.assertEqual( + proto_to_state_dict(response.weights.weights), {"weights": Tensor([42])} + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_epoch.assert_called_once_with("42", 0) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_train_batch(self): + self.mock_device.train_batch.return_value = ( + Tensor([42]), + 42.0, + self.diagnostic_metrics, + ) + request = connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + + response, metadata, code, details = self.make_call("TrainBatch", request) + + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1])) + self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42])) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate_global(self): + self.mock_device.evaluate_global.return_value = ( + self.metrics, + self.diagnostic_metrics, + ) + request = connection_pb2.EvalGlobalRequest(validation=False, federated=False) + + response, metadata, code, details = self.make_call("EvaluateGlobal", request) + + self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate_global.assert_called_once_with(False, False) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate(self): + self.mock_device.evaluate.return_value = self.diagnostic_metrics + request = connection_pb2.EvalRequest( + server=DeviceInfo(device_id="42", address=""), validation=True + ) + + response, metadata, code, details = self.make_call("Evaluate", request) + + self.assertEqual(type(response), type(EvalResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate.assert_called_once_with("42", True) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate_batch(self): + self.mock_device.evaluate_batch.return_value = self.diagnostic_metrics + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + + response, metadata, code, details = self.make_call("EvaluateBatch", request) + + self.assertEqual(type(response), type(EvalBatchResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate_batch.assert_called_once_with( + Tensor([1.0]), Tensor([1]) + ) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_full_model_training(self): + self.mock_device.federated_train.return_value = ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + 44, + self.metrics, + self.diagnostic_metrics, + ) + request = connection_pb2.FullModelTrainRequest() + + response, metadata, code, details = self.make_call("FullModelTraining", request) + + self.assertEqual( + proto_to_state_dict(response.client_weights.weights), + {"weights": Tensor([42])}, + ) + self.assertEqual( + proto_to_state_dict(response.server_weights.weights), + {"weights": Tensor([43])}, + ) + self.assertEqual(response.num_samples, 44) + 
self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual(code, StatusCode.OK) + self.mock_device.federated_train.assert_called_once() + + def test_start_experiment(self): + self.mock_device.start_experiment.return_value = None + request = connection_pb2.StartExperimentRequest() + + response, metadata, code, details = self.make_call("StartExperiment", request) + + self.assertEqual(type(response), type(connection_pb2.StartExperimentResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.start_experiment.assert_called_once() + + def test_end_experiment(self): + self.mock_device.end_experiment.return_value = None + request = connection_pb2.EndExperimentRequest() + + response, metadata, code, details = self.make_call("EndExperiment", request) + + self.assertEqual(type(response), type(connection_pb2.EndExperimentResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.end_experiment.assert_called_once() + + def test_battery_status(self): + self.mock_device.get_battery_status.return_value = (42, 21) + request = connection_pb2.BatteryStatusRequest() + + response, metadata, code, details = self.make_call("GetBatteryStatus", request) + + self.assertEqual( + response.status, + BatteryStatus(initial_battery_level=42, current_battery_level=21), + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.get_battery_status.assert_called_once() + + def test_dataset_model_info(self): + self.mock_device.client._train_data.dataset = [1] + self.mock_device.client._val_data.dataset = [2] + self.mock_device.client._model_flops = 3 + self.mock_device.server._model_flops = 4 + request = connection_pb2.DatasetModelInfoRequest() + + response, metadata, code, details = self.make_call( + "GetDatasetModelInfo", request + ) + + self.assertEqual(code, StatusCode.OK) + self.assertEqual(response.train_samples, 1) + self.assertEqual(response.validation_samples, 1) + self.assertEqual(response.client_model_flops, 3) + self.assertEqual(response.server_model_flops, 4) + + +class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase): + + def setUp(self) -> None: + # instantiating NetworkDevice to mock Device(ABC)'s properties + # store mock_device reference for later function call assertions + self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery))) + my_servicer = RPCDeviceServicer(device=self.mock_device) + servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} + self.test_server = server_from_dictionary(servicers, strict_real_time()) + + def make_call(self, method_name, request): + method = self.test_server.invoke_unary_unary( + method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ + "Device" + ].methods_by_name[method_name], + invocation_metadata={}, + request=request, + timeout=None, + ) + return method.termination() + + def test_stop_at_train_global(self): + self.mock_device.train_global.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.TrainGlobalRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal") + + def test_stop_at_set_weights(self): + self.mock_device.set_weights.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.SetWeightsRequest(weights=weights_to_proto({})) + self._test_device_out_of_battery_lets_rpc_fail(request, "SetWeights") + + def test_stop_at_train_epoch(self): + self.mock_device.train_epoch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = 
connection_pb2.TrainEpochRequest(server=None) + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainEpoch") + + def test_stop_at_train_batch(self): + self.mock_device.train_batch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainBatch") + + def test_stop_at_evaluate_global(self): + self.mock_device.evaluate_global.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EvalGlobalRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateGlobal") + + def test_stop_at_evaluate(self): + self.mock_device.evaluate.side_effect = BatteryEmptyException("Battery empty") + request = connection_pb2.EvalRequest(server=None) + self._test_device_out_of_battery_lets_rpc_fail(request, "Evaluate") + + def test_stop_at_evaluate_batch(self): + self.mock_device.evaluate_batch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateBatch") + + def test_stop_at_full_model_training(self): + self.mock_device.federated_train.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.FullModelTrainRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "FullModelTraining") + + def test_stop_at_start_experiment(self): + self.mock_device.start_experiment.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.StartExperimentRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "StartExperiment") + + def test_stop_at_end_experiment(self): + self.mock_device.end_experiment.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EndExperimentRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "EndExperiment") + + def test_stop_at_get_battery_status(self): + self.mock_device.get_battery_status.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.BatteryStatusRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "GetBatteryStatus") + + def _test_device_out_of_battery_lets_rpc_fail(self, request, servicer_method_name): + response, metadata, code, details = self.make_call( + servicer_method_name, request + ) + self.assertIsNone(response) + self.assertEqual(code, StatusCode.UNKNOWN) + self.assertEqual(details, "Exception calling application: Battery empty") + + +class BatteryUpdateTest(unittest.TestCase): + + def setUp(self): + self.battery = Mock(Battery(1000, 1, 0.1)) + + def test_battery_update_on_train_epoch(self): + device = NetworkDevice("0", None, battery=self.battery) + device.set_client(Mock(DeviceClient)) + device.train_epoch(None) + self.battery.update_time.assert_called() + + +class RequestDispatcherTest(unittest.TestCase): + + def setUp(self) -> None: + self.dispatcher = DeviceRequestDispatcher( + [] + ) # pass no devices to avoid grpc calls + # mock connections instead + self.dispatcher.connections["1"] = Mock( + spec=[ + "TrainGlobal", + "SetWeights", + "TrainEpoch", + "TrainBatch", + "EvaluateGlobal", + "Evaluate", + "EvaluateBatch", + "FullModelTraining", + "StartExperiment", + "EndExperiment", + 
"GetBatteryStatus", + "GetDatasetModelInfo", + ] + ) + self.dispatcher.devices = [ + DictConfig({"device_id": "0", "address": "42"}), # inactive device + DictConfig({"device_id": "1", "address": "43"}), + ] + self.mock_stub = self.dispatcher.connections["1"] # for convenience + self.weights = {"weights": Tensor([42])} + self.gradients = Tensor([42]) + self.activations = Tensor([1]) + self.labels = Tensor([1]) + self.metrics = ModelMetricResultContainer( + [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] + ) + self.diagnostic_metrics = DiagnosticMetricResultContainer( + [DiagnosticMetricResult("d1", "comp_time", "train", 42)] + ) + + def test_train_global_on_without_error(self): + self.mock_stub.TrainGlobal.return_value = connection_pb2.TrainGlobalResponse( + client_weights=weights_to_proto(self.weights), + server_weights=weights_to_proto(self.weights), + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + optimizer_state=state_dict_to_proto({"optimizer_state": 42}), + ) + + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.dispatcher.train_global_on("1", 42, 43) + ) + + self.assertEqual(client_weights, self.weights) + self.assertEqual(server_weights, self.weights) + self.assertEqual(metrics, self.metrics) + self.assertEqual(optimizer_state, {"optimizer_state": 42}) + + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainGlobal.assert_called_once_with( + connection_pb2.TrainGlobalRequest( + epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + ) + ) + + def test_train_global_on_with_error(self): + self.mock_stub.TrainGlobal.side_effect = grpc.RpcError() + + response = self.dispatcher.train_global_on("1", 42, round_no=43) + + self.assertEqual(response, False) + self.mock_stub.TrainGlobal.assert_called_once_with( + connection_pb2.TrainGlobalRequest( + epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + ) + ) + + def test_set_weights_on_without_error(self): + self.mock_stub.SetWeights.return_value = SetWeightsResponse() + + self.dispatcher.set_weights_on("1", self.weights, True) + + self.mock_stub.SetWeights.assert_called_once_with( + connection_pb2.SetWeightsRequest( + weights=weights_to_proto(self.weights), on_client=True + ), + wait_for_ready=False, + ) + + def test_set_weights_on_with_error(self): + self.mock_stub.SetWeights.side_effect = grpc.RpcError() + + response = self.dispatcher.set_weights_on("1", self.weights, True) + + self.mock_stub.SetWeights.assert_called_once_with( + connection_pb2.SetWeightsRequest( + weights=weights_to_proto(self.weights), on_client=True + ), + wait_for_ready=False, + ) + self.assertEqual(response, False) + + def test_train_epoch_on_without_error(self): + self.mock_stub.TrainEpoch.return_value = TrainEpochResponse( + weights=weights_to_proto(self.weights), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + + weights, diagnostic_metrics = self.dispatcher.train_epoch_on("1", "0", 42) + + self.assertEqual(weights, self.weights) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainEpoch.assert_called_once_with( + connection_pb2.TrainEpochRequest( + server=DeviceInfo(device_id="0", address="42"), round_no=42 + ) + ) + + def test_train_epoch_on_with_error(self): + self.mock_stub.TrainEpoch.side_effect = grpc.RpcError() + + response = self.dispatcher.train_epoch_on("1", "0", 42) + + self.assertEqual(response, False) + 
self.mock_stub.TrainEpoch.assert_called_once_with( + connection_pb2.TrainEpochRequest( + server=DeviceInfo(device_id="0", address="42"), round_no=42 + ) + ) + + def test_train_batch_on_without_error(self): + self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse( + gradients=gradients_to_proto(self.gradients), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + loss=42.0, + ) + + gradients, loss, diagnostic_metrics = self.dispatcher.train_batch_on( + "1", self.activations, self.labels + ) + + self.assertEqual(gradients, self.gradients) + self.assertEqual(loss, 42.0) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainBatch.assert_called_once_with( + connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_train_batch_on_with_error(self): + self.mock_stub.TrainBatch.side_effect = grpc.RpcError() + + response = self.dispatcher.train_batch_on("1", self.activations, self.labels) + + self.assertEqual(response, False) + self.mock_stub.TrainBatch.assert_called_once_with( + connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_evaluate_global_on_without_error(self): + self.mock_stub.EvaluateGlobal.return_value = connection_pb2.EvalGlobalResponse( + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + + metrics, diagnostic_metrics = self.dispatcher.evaluate_global_on( + "1", True, True + ) + + self.assertEqual(metrics, self.metrics) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.EvaluateGlobal.assert_called_once_with( + connection_pb2.EvalGlobalRequest(validation=True, federated=True) + ) + + def test_evaluate_global_on_with_error(self): + self.mock_stub.EvaluateGlobal.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_global_on("1", True, True) + + self.assertEqual(response, False) + self.mock_stub.EvaluateGlobal.assert_called_once_with( + connection_pb2.EvalGlobalRequest(validation=True, federated=True) + ) + + def test_evaluate_on_without_error(self): + self.mock_stub.Evaluate.return_value = EvalResponse( + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) + ) + + diagnostic_metrics = self.dispatcher.evaluate_on("1", "0", True) + + self.assertEqual(diagnostic_metrics, self.diagnostic_metrics) + self.mock_stub.Evaluate.assert_called_once_with( + connection_pb2.EvalRequest( + server=DeviceInfo(device_id="0", address="42"), validation=True + ) + ) + + def test_evaluate_on_with_error(self): + self.mock_stub.Evaluate.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_on("1", "0", True) + + self.assertEqual(response, False) + self.mock_stub.Evaluate.assert_called_once_with( + connection_pb2.EvalRequest( + server=DeviceInfo(device_id="0", address="42"), validation=True + ) + ) + + def test_evaluate_batch_on_without_error(self): + self.mock_stub.EvaluateBatch.return_value = EvalBatchResponse( + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) + ) + + diagnostic_metrics = self.dispatcher.evaluate_batch_on( + "1", self.activations, self.labels + ) + + self._assert_field_size_added_to_diagnostic_metrics( + diagnostic_metrics + ) # metric field present in response, but not used in practice. 
Thus, a field size is added + self.mock_stub.EvaluateBatch.assert_called_once_with( + connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_evaluate_batch_on_with_error(self): + self.mock_stub.EvaluateBatch.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_batch_on("1", self.activations, self.labels) + + self.assertEqual(response, False) + self.mock_stub.EvaluateBatch.assert_called_once_with( + connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_full_model_training_without_error(self): + self.mock_stub.FullModelTraining.return_value = ( + connection_pb2.FullModelTrainResponse( + client_weights=weights_to_proto(self.weights), + server_weights=weights_to_proto(self.weights), + num_samples=42, + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + ) + + client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( + self.dispatcher.federated_train_on("1", 42) + ) + + self.assertEqual(client_weights, self.weights) + self.assertEqual(server_weights, self.weights) + self.assertEqual(num_samples, 42) + self.assertEqual(metrics, self.metrics) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.FullModelTraining.assert_called_once_with( + connection_pb2.FullModelTrainRequest(round_no=42) + ) + + def test_full_model_training_with_error(self): + self.mock_stub.FullModelTraining.side_effect = grpc.RpcError() + + response = self.dispatcher.federated_train_on("1", 42) + + self.assertEqual(response, False) + self.mock_stub.FullModelTraining.assert_called_once_with( + connection_pb2.FullModelTrainRequest(round_no=42) + ) + + def test_start_experiment_on_without_error(self): + self.mock_stub.StartExperiment.return_value = ( + connection_pb2.StartExperimentResponse() + ) + + response = self.dispatcher.start_experiment_on("1", True) + + self.assertEqual(response, True) + self.mock_stub.StartExperiment.assert_called_once_with( + connection_pb2.StartExperimentRequest(), wait_for_ready=True + ) + + def test_start_experiment_on_with_error(self): + self.mock_stub.StartExperiment.side_effect = grpc.RpcError() + + response = self.dispatcher.start_experiment_on("1", True) + + self.assertEqual(response, False) + self.mock_stub.StartExperiment.assert_called_once_with( + connection_pb2.StartExperimentRequest(), wait_for_ready=True + ) + + def test_end_experiment_on_without_error(self): + self.mock_stub.EndExperiment.return_value = ( + connection_pb2.EndExperimentResponse() + ) + + response = self.dispatcher.end_experiment_on("1") + + self.assertEqual(response, True) + self.mock_stub.EndExperiment.assert_called_once_with( + connection_pb2.EndExperimentRequest() + ) + + def test_end_experiment_on_with_error(self): + self.mock_stub.EndExperiment.side_effect = grpc.RpcError() + + response = self.dispatcher.end_experiment_on("1") + + self.assertEqual(response, False) + self.mock_stub.EndExperiment.assert_called_once_with( + connection_pb2.EndExperimentRequest() + ) + + def test_get_battery_status_without_error(self): + self.mock_stub.GetBatteryStatus.return_value = ( + connection_pb2.BatteryStatusResponse( + status=BatteryStatus(initial_battery_level=42, current_battery_level=21) + ) + ) + + response = 
self.dispatcher.get_battery_status_on("1") + + self.assertEqual( + response, DeviceBatteryStatus(initial_capacity=42, current_capacity=21) + ) + self.mock_stub.GetBatteryStatus.assert_called_once_with( + connection_pb2.BatteryStatusRequest() + ) + + def test_get_battery_status_with_error(self): + self.mock_stub.GetBatteryStatus.side_effect = grpc.RpcError() + + response = self.dispatcher.get_battery_status_on("1") + + self.assertEqual(response, False) + self.mock_stub.GetBatteryStatus.assert_called_once_with( + connection_pb2.BatteryStatusRequest() + ) + + def test_get_dataset_model_info_without_error(self): + self.mock_stub.GetDatasetModelInfo.return_value = ( + connection_pb2.DatasetModelInfoResponse( + train_samples=42, + validation_samples=21, + client_model_flops=42, + server_model_flops=21, + ) + ) + + response = self.dispatcher.get_dataset_model_info_on("1") + + self.assertEqual(response, (42, 21, 42, 21)) + self.mock_stub.GetDatasetModelInfo.assert_called_once_with( + connection_pb2.DatasetModelInfoRequest() + ) + + def test_get_dataset_model_info_with_error(self): + self.mock_stub.GetDatasetModelInfo.side_effect = grpc.RpcError() + + response = self.dispatcher.get_dataset_model_info_on("1") + + self.assertEqual(response, False) + self.mock_stub.GetDatasetModelInfo.assert_called_once_with( + connection_pb2.DatasetModelInfoRequest() + ) + + def test_handle_calls_to_inactive_device(self): + """Test each method of the dispatcher that does RPC calls to handle calls to inactive devices. + One test where the device is known, but inactive and one where the device is unknown. + """ + methods_names = [ + "train_global_on", + "set_weights_on", + "train_epoch_on", + "train_batch_on", + "evaluate_global_on", + "evaluate_on", + "evaluate_batch_on", + "federated_train_on", + "start_experiment_on", + "end_experiment_on", + "get_battery_status_on", + "get_dataset_model_info_on", + ] + for device in [("0", True), ("2", False)]: # 0 is inactive, 2 is unknown + for method_name in methods_names: + with self.subTest(method_name=f"{method_name}_device{device[0]}"): + with patch("builtins.print") as print_patch: + method = getattr(self.dispatcher, method_name) + params = list(signature(method).parameters) + + # make method call with device id and all other params as None + response = method( + device[0], + *[None for _ in params if _ not in ["self", "device_id"]], + ) + + self.assertEqual(response, False) + self.mock_stub.assert_not_called() + if device[1]: + print_patch.assert_called_once_with( + f"Device {device[0]} not active" + ) + else: + print_patch.assert_called_once_with( + f"Unknown Device ID {device[0]}" + ) + + def _assert_field_size_added_to_diagnostic_metrics(self, diagnostic_metrics): + """Not to use when response only includes diagnostic metrics, as these are ignored for the field size""" + self.assertEqual( + diagnostic_metrics.get_as_list()[0], + self.diagnostic_metrics.get_as_list()[0], + ) # check that previous diagnostic metrics still exist + self.assertEqual( + diagnostic_metrics.get_as_list()[1].name, "size" + ) # check that at least one new diagnostic metric for size was added + + +class RequestDispatcherThreadingTest(unittest.TestCase): + + def test_handle_rpc_error_thread_safety(self): + + # run the procedure with 2 to 20 threads + for n in range(2, 21): + with self.subTest(method_name=f"Test with {n} threads"): + print(f"Test with {n} threads") + self.request_dispatcher = DeviceRequestDispatcher([]) + + # mock n connections with half of them returning errors and the other half 
valid responses + self.request_dispatcher.connections = { + f"d{i}": Mock(spec=["GetBatteryStatus"]) for i in range(n) + } + response = connection_pb2.BatteryStatusResponse( + status=BatteryStatus( + initial_battery_level=42, current_battery_level=21 + ) + ) + for i, mock in enumerate(self.request_dispatcher.connections.values()): + mock.GetBatteryStatus.side_effect = get_side_effect(i, response) + + # make n concurrent calls to the dispatcher + responses = [] + response_lock = threading.Lock() + with concurrent.futures.ThreadPoolExecutor(max_workers=n) as executor: + futures = [ + executor.submit( + self.request_dispatcher.get_battery_status_on, device_id + ) + for device_id in self.request_dispatcher.active_devices() + ] + for future in concurrent.futures.as_completed(futures): + with response_lock: + responses.append(future.result()) + + # check that only error connections were removed + self.assertEqual( + list(self.request_dispatcher.connections.keys()), + [f"d{i}" for i in range(n) if i % 2 == 0], + ) + + response_battery_level = DeviceBatteryStatus( + initial_capacity=42, current_capacity=21 + ) + + # check that there are n responses in total, half of them valid and half of them None + self.assertEqual(responses.count(False), math.floor(n / 2)) + self.assertEqual( + responses.count(response_battery_level), math.ceil(n / 2) + ) diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1331f8aa7bad195eb02044bcf25d381fe2ee2a13 --- /dev/null +++ b/edml/tests/core/server_test.py @@ -0,0 +1,192 @@ +import unittest +from collections import OrderedDict +from copy import copy +from typing import Tuple, Any +from unittest.mock import Mock + +import torch.utils.data +from omegaconf import DictConfig +from torch import nn, tensor +from torch.utils.data import DataLoader, Dataset, TensorDataset + +from edml.core.battery import Battery +from edml.core.client import DeviceClient +from edml.core.device import Device +from edml.core.server import DeviceServer + + +class ClientModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(1, 1) + self.output = nn.ReLU() + + def forward(self, x): + return self.output(self.layer(x)) + + +class ServerModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(1, 2) + # self.layer.weight = tensor([[-0.6857]]) + self.output = nn.Softmax(dim=1) + + def forward(self, x): + return self.output(self.layer(x)) + + +class ToyDataset(Dataset): + def __init__(self, data: list, labels: list): + self.length = len(data) + self.data = torch.Tensor(data) + self.labels = torch.Tensor(labels, dtype=int) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + return self.data[index], self.labels[index] + + def __len__(self): + return self.length + + +class PSLTest(unittest.TestCase): + def setUp(self): + cfg = DictConfig( + { + "optimizer": { + "_target_": "torch.optim.SGD", + "lr": 1, + "momentum": 0, + "weight_decay": 0, + }, + "experiment": {"metrics": ["accuracy"]}, + "dataset": {"num_classes": 2, "average_setting": "micro"}, + "topology": { + "devices": [ + { + "device_id": "d0", + "address": "localhost:50051", + "torch_device": "cuda:0", + }, + { + "device_id": "d1", + "address": "localhost:50052", + "torch_device": "cuda:0", + }, + ] + }, + "own_device_id": "d0", + } + ) + # init models with fixed weights for repeatability + server_state_dict = OrderedDict( + [ + ("layer.weight", tensor([[-0.5], [-1.0]])), + ("layer.bias", 
tensor([-0.5, 0.25])), + ] + ) + client_state_dict = OrderedDict( + [("layer.weight", tensor([[-1.0]])), ("layer.bias", tensor([0.5]))] + ) + server_model = ServerModel() + server_model.load_state_dict(server_state_dict) + client_model1 = ClientModel() + client_model1.load_state_dict(client_state_dict) + client_model2 = ClientModel() + client_model2.load_state_dict(client_state_dict) + self.server = DeviceServer( + model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg + ) + self.client1 = DeviceClient( + model=client_model1, + cfg=cfg, + train_dl=DataLoader(TensorDataset(tensor([[0.9]]), tensor([[0.0, 1.0]]))), + val_dl=DataLoader(TensorDataset(tensor([[0.75]]), tensor([[0.0, 1.0]]))), + test_dl=None, + ) + client2_cfg = cfg.copy() + client2_cfg["own_device_id"] = "d1" + self.client2 = DeviceClient( + model=client_model2, + cfg=client2_cfg, + train_dl=DataLoader(TensorDataset(tensor([[0.1]]), tensor([[1.0, 0.0]]))), + val_dl=DataLoader(TensorDataset(tensor([[0.25]]), tensor([[1.0, 0.0]]))), + test_dl=None, + ) + + def get_client_side_effect(fn): + """ + Creates side effects for methods of the form METHOD_on(device_id) skipping the request dispatcher. + """ + + def side_effect(*args, **kwargs): + # get the device id which is either a positional or keyword arg + if len(args) > 0: + device_id = args[0] + elif "client_id" in kwargs: + device_id = kwargs.pop("client_id") + elif "device_id" in kwargs: + device_id = kwargs.pop("device_id") + else: + return KeyError( + f"Could not find device_id in args or kwargs for function {fn}" + ) + # delegate to correct client then using the given method name + if device_id == "d0": + return self.client1.__class__.__dict__[fn]( + self.client1, *args, **kwargs + ) + elif device_id == "d1": + return self.client2.__class__.__dict__[fn]( + self.client2, *args, **kwargs + ) + else: + return KeyError(f"Unknown device_id {device_id}") + + return side_effect + + def get_server_side_effect(fn): + def side_effect(*args, **kwargs): + return self.server.__class__.__dict__[fn]( + self.server, *args, **kwargs + ) + ( + 1, + ) # Add (1,) as placeholder for DiagnosticMetricsContainer + + return side_effect + + node_device = Mock(Device) + node_device.battery = Mock(Battery) + node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect( + "train_single_batch" + ) + node_device.backpropagation_on_client_only_on.side_effect = ( + get_client_side_effect("backward_single_batch") + ) + node_device.set_gradient_and_finalize_training_on_client_only_on.side_effect = ( + get_client_side_effect("set_gradient_and_finalize_training") + ) + node_device.train_batch.side_effect = get_server_side_effect("train_batch") + node_device.evaluate_on.side_effect = get_client_side_effect("evaluate") + node_device.evaluate_batch.side_effect = get_server_side_effect( + "evaluate_batch" + ) + + self.node_device1 = copy(node_device) + self.node_device2 = copy(node_device) + self.node_device1.client = self.client1 + self.node_device1.device_id = "d0" + self.node_device2.device_id = "d1" + self.client1.set_device(self.node_device1) + self.client2.set_device(self.node_device2) + self.server.set_device(self.node_device1) + + def test_train_parallel_sl(self): + ( + client_weights, + server_weights, + model_metrics, + optimizer_state, + diagnostic_metrics, + ) = self.server.train_parallel_split_learning(["d0", "d1"], round_no=0) + self.assertDictEqual(self.client1.get_weights(), self.client2.get_weights()) diff --git a/edml/tests/core/start_device_test.py 
b/edml/tests/core/start_device_test.py new file mode 100644 index 0000000000000000000000000000000000000000..abb5d9d61d68c791657d168f9e266abec366a74c --- /dev/null +++ b/edml/tests/core/start_device_test.py @@ -0,0 +1,108 @@ +import os +import unittest +from copy import deepcopy + +import torch +from omegaconf import OmegaConf +from torch.autograd import Variable + +from edml.core.start_device import _get_models +from edml.helpers.model_splitting import Part +from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder + + +class GetModelsTest(unittest.TestCase): + def setUp(self): + os.chdir(os.path.join(os.path.dirname(__file__), "../../../")) + self.cfg = OmegaConf.create({"some_key": "some_value"}) + self.cfg.seed = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), + "../../config/seed/default.yaml", + ) + ) + + def _get_model_from_model_provider_config(self, config_name): + self.cfg.model_provider = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), + f"../../config/model_provider/{config_name}.yaml", + ) + ) + return _get_models(self.cfg) + + def test_load_resnet20(self): + client, server = self._get_model_from_model_provider_config("resnet20") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet20_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet20-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + optimizer = torch.optim.Adam(server.parameters()) + smashed_data = client(torch.zeros(1, 3, 32, 32)) + server_smashed_data = Variable(smashed_data, requires_grad=True) + output_train = server(server_smashed_data) + loss_train = torch.nn.functional.cross_entropy( + output_train, torch.zeros((1, 100)) + ) + loss_train.backward() + optimizer.step() + smashed_data.backward(server_smashed_data.grad) + optimizer.step() + + def test_training_resnet20_with_ae_as_non_trainable_layers(self): + client_encoder, server_decoder = self._get_model_from_model_provider_config( + "resnet20-with-autoencoder" + ) + client_params = deepcopy(str(client_encoder.model.state_dict())) + encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict())) + server_params = deepcopy(str(server_decoder.model.state_dict())) + decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict())) + + # Training loop + client_optimizer = torch.optim.Adam(client_encoder.parameters()) + server_optimizer = torch.optim.Adam(server_decoder.parameters()) + smashed_data = client_encoder(torch.zeros(1, 3, 32, 32)) + server_smashed_data = Variable(smashed_data, requires_grad=True) + output_train = server_decoder(server_smashed_data) + loss_train = torch.nn.functional.cross_entropy( + output_train, torch.rand((1, 100)) + ) + loss_train.backward() + server_optimizer.step() + smashed_data.backward(server_smashed_data.grad) + client_optimizer.step() + + # check that AE hasn't changed, but client and server have + self.assertEqual(encoder_params, str(client_encoder.autoencoder.state_dict())) + self.assertEqual(decoder_params, str(server_decoder.autoencoder.state_dict())) + 
self.assertNotEqual(client_params, str(client_encoder.model.state_dict())) + self.assertNotEqual(server_params, str(server_decoder.model.state_dict())) + + def test_load_resnet110(self): + client, server = self._get_model_from_model_provider_config("resnet110") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet110_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet110-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) diff --git a/edml/tests/helpers/config_helpers_test.py b/edml/tests/helpers/config_helpers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9279aeb714b84efdb27fd9895fca65393fa0b650 --- /dev/null +++ b/edml/tests/helpers/config_helpers_test.py @@ -0,0 +1,145 @@ +import os +import unittest +from unittest.mock import patch + +from omegaconf import DictConfig, OmegaConf + +from edml.controllers.parallel_split_controller import ParallelSplitController +from edml.helpers.config_helpers import ( + get_device_id_by_index, + get_device_address_by_id, + preprocess_config, + get_device_index_by_id, + instantiate_controller, + __drop_irrelevant_keys__, + get_torch_device_id, +) + + +class ConfigHelpersTest(unittest.TestCase): + + def setUp(self) -> None: + self.cfg = DictConfig( + { + "own_device_id": "d0", + "topology": { + "devices": [ + {"device_id": "d0", "address": "localhost:50051"}, + {"device_id": "d1", "address": "localhost:50052"}, + ] + }, + "num_devices": "${len:${topology.devices}}", + "controller": { + "name": "swarm", + "scheduler": {"name": "max_battery"}, + "adaptive_threshold_fn": {"name": "static"}, + }, + "group_by": { + "controller": [ + "name", + {"scheduler": "name", "adaptive_threshold_fn": "name"}, + ], + }, + "group": "${group_name:${group_by}}", + } + ) + + def test_get_device_id_by_index(self): + device_id = get_device_id_by_index(self.cfg, 0) + self.assertEqual(self.cfg.own_device_id, device_id) + + def test_get_device_index_by_id(self): + device_idx = get_device_index_by_id(self.cfg, "d1") + self.assertEqual(1, device_idx) + + def test_get_device_address_by_id(self): + device_address = get_device_address_by_id(device_id="d1", cfg=self.cfg) + self.assertEqual("localhost:50052", device_address) + + def test_preprocess_config_resolving_num_devices(self): + preprocess_config(self.cfg) + self.assertEqual(2, self.cfg.num_devices) + + def test_preprocess_config_set_device_id_by_index(self): + self.cfg.own_device_id = 1 + preprocess_config(self.cfg) + self.assertEqual("d1", self.cfg.own_device_id) + + def test_get_default_torch_device_if_cuda_available(self): + with patch("torch.cuda.is_available", return_value=True): + self.assertEqual(get_torch_device_id(self.cfg), "cuda:0") + + def test_get_default_torch_device_if_cuda_not_available(self): + with patch("torch.cuda.is_available", return_value=False): + self.assertEqual(get_torch_device_id(self.cfg), "cpu") + + def test_preprocess_config_group_name(self): + preprocess_config(self.cfg) + self.assertEqual(self.cfg.group, "swarm_max_battery_static") + + +class 
ControllerInstantiationTest(unittest.TestCase): + def setUp(self) -> None: + self.cfg = OmegaConf.create({"some_key": "some_value"}) + self.cfg.controller = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), + "../../../edml/config/controller/parallel_swarm.yaml", + ) + ) + self.cfg.controller.scheduler = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), + "../../../edml/config/controller/scheduler/max_battery.yaml", + ) + ) + + def test_parallel_split_controller_with_max_battery_instantiation(self): + with patch( + "edml.controllers.base_controller.BaseController.__init__" + ): # Avoid initializing the base_controller for brevity + with patch( + "edml.controllers.base_controller.BaseController._get_device_ids" + ) as _get_device_ids: # needed by scheduler + _get_device_ids.return_value = ["d0"] + controller = instantiate_controller(self.cfg) + self.assertIsInstance(controller, ParallelSplitController) + + def test_drop_irrelevant_keys(self): + self.cfg.controller["irrelevant_key"] = "some value" + reduced_cfg = __drop_irrelevant_keys__(self.cfg.controller) + self.assertListEqual( + list(reduced_cfg.keys()), ["_target_", "_partial_", "scheduler"] + ) + + +class GetTorchDeviceIdTest(unittest.TestCase): + + def setUp(self) -> None: + self.cfg = DictConfig( + { + "own_device_id": "d0", + "topology": { + "devices": [ + { + "device_id": "d0", + "address": "localhost:50051", + "torch_device": "my_torch_device1", + }, + { + "device_id": "d1", + "address": "localhost:50052", + "torch_device": "my_torch_device2", + }, + ] + }, + "num_devices": "${len:${topology.devices}}", + } + ) + + def test_get_torch_device1(self): + self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device1") + + def test_get_torch_device2(self): + self.cfg.own_device_id = "d1" + self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device2") diff --git a/edml/tests/helpers/data_partitioning_test.py b/edml/tests/helpers/data_partitioning_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0fde3277fb4ee7e8825605167111ba17480c032c --- /dev/null +++ b/edml/tests/helpers/data_partitioning_test.py @@ -0,0 +1,134 @@ +import unittest + +import torch.utils.data +from torch import Tensor +from torch.utils.data import TensorDataset + +from edml.helpers.data_partitioning import ( + __get_partitioned_data_for_device__, + DataPartitioner, +) + + +class DataPartitioningTest(unittest.TestCase): + + def setUp(self) -> None: + self.data = TensorDataset( + Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + Tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]), + ) + + def test_deterministic_partitioning(self): + partition = __get_partitioned_data_for_device__(self.data, 0, 2, seed=42) + self.assertTrue( + torch.equal(partition[0:5][0], Tensor([3.0, 7.0, 2.0, 9.0, 5.0])) + ) + self.assertTrue( + torch.equal(partition[0:5][1], Tensor([1.0, 1.0, 0.0, 1.0, 1.0])) + ) + + def test_partition_distinct_data(self): + partition0_indices = __get_partitioned_data_for_device__( + self.data, 0, 2 + ).indices + partition1_indices = __get_partitioned_data_for_device__( + self.data, 1, 2 + ).indices + for i in partition0_indices: + self.assertFalse(i in partition1_indices) + self.assertEqual( + len(partition0_indices) + len(partition1_indices), len(self.data) + ) + + def test_partitioned_data_unequal_size(self): + partition0 = __get_partitioned_data_for_device__(self.data, 0, 3) + self.assertEqual(len(partition0), 4) + + partition1 = __get_partitioned_data_for_device__(self.data, 1, 3) + self.assertEqual(len(partition1), 3) + + 
partition2 = __get_partitioned_data_for_device__(self.data, 2, 3) + self.assertEqual(len(partition2), 3) + + def test_partition_with_lengths(self): + partition0 = __get_partitioned_data_for_device__( + self.data, 0, 3, fractions=[0.55, 0.27, 0.18] + ) + partition1 = __get_partitioned_data_for_device__( + self.data, 1, 3, fractions=[0.55, 0.27, 0.18] + ) + partition2 = __get_partitioned_data_for_device__( + self.data, 2, 3, fractions=[0.55, 0.27, 0.18] + ) + self.assertEqual(len(partition0), 6) + self.assertEqual(len(partition1), 3) + self.assertEqual(len(partition2), 1) + + def test_subset_only_with_lengths(self): + partition0 = __get_partitioned_data_for_device__( + self.data, 0, 3, fractions=[0.1, 0.2, 0.3] + ) + partition1 = __get_partitioned_data_for_device__( + self.data, 1, 3, fractions=[0.1, 0.2, 0.3] + ) + partition2 = __get_partitioned_data_for_device__( + self.data, 2, 3, fractions=[0.1, 0.2, 0.3] + ) + self.assertEqual(len(partition0), 1) + self.assertEqual(len(partition1), 2) + self.assertEqual(len(partition2), 3) + + def test_non_iid(self): + data = TensorDataset( + Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unsqueeze(1), + Tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5]).unsqueeze(1), + ) + partition0 = __get_partitioned_data_for_device__( + data, 0, 3, fractions=[0.4, 0.3, 0.3], distribution="non-iid" + ) + partition1 = __get_partitioned_data_for_device__( + data, 1, 3, fractions=[0.4, 0.3, 0.3], distribution="non-iid" + ) + partition2 = __get_partitioned_data_for_device__( + data, 2, 3, fractions=[0.4, 0.3, 0.3], distribution="non-iid" + ) + self.assertEqual(len(partition0), 4) + self.assertEqual(len(partition1), 3) + self.assertEqual(partition2[0], (Tensor([8.0]), Tensor([4.0]))) + self.assertEqual(partition2[1], (Tensor([9.0]), Tensor([5.0]))) + self.assertEqual(partition2[2], (Tensor([10.0]), Tensor([5.0]))) + + def test_floating_point_arithmetic(self): + data = TensorDataset( + Tensor(range(45000)).unsqueeze(1), Tensor(range(45000)).unsqueeze(1) + ) + for num_devices in range(1, 100): + partition = __get_partitioned_data_for_device__( + data, num_devices - 1, num_devices + ) + self.assertEqual(len(partition), int(45000 / num_devices)) + + +class TestDataPartitioner(unittest.TestCase): + + def setUp(self) -> None: + self.data = TensorDataset( + Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + Tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]), + ) + + def test_partition(self): + partitioner = DataPartitioner(0, 2, seed=42) + partition = partitioner.partition(self.data) + self.assertTrue( + torch.equal(partition[0:5][0], Tensor([3.0, 7.0, 2.0, 9.0, 5.0])) + ) + self.assertTrue( + torch.equal(partition[0:5][1], Tensor([1.0, 1.0, 0.0, 1.0, 1.0])) + ) + + def test_subset(self): + partitioner = DataPartitioner(0, 2, seed=42, fractions=[0.2, 0.2]) + partition = partitioner.partition(self.data) + self.assertTrue(torch.equal(partition[0:2][0], Tensor([3.0, 7.0]))) + self.assertTrue(torch.equal(partition[0:2][1], Tensor([1.0, 1.0]))) diff --git a/edml/tests/helpers/decorators_test.py b/edml/tests/helpers/decorators_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dd648e0a92040f876e3c9564a93826dc70e48580 --- /dev/null +++ b/edml/tests/helpers/decorators_test.py @@ -0,0 +1,293 @@ +import time +import unittest +from unittest.mock import Mock, patch + +import torch + +from edml.core.battery import Battery +from edml.core.client import DeviceClient +from edml.core.device import NetworkDevice +from edml.helpers.decorators import ( + check_device_set, + log_execution_time, + 
update_battery, + battery_updater, + LatencySimulator, + simulate_latency_decorator, + add_time_to_diagnostic_metrics, +) +from edml.helpers.logging import SimpleLogger +from edml.helpers.metrics import DiagnosticMetricResult, DiagnosticMetricResultContainer + + +class CheckDeviceDecoratorTest(unittest.TestCase): + + def test_client_with_device_set(self): + # check that decorator raises no exception if client has device set + client = Mock(spec=DeviceClient) + client.node_device = Mock(spec=NetworkDevice) + decorator = check_device_set() + decorated_function = decorator(lambda x: x) + exception_raised = False + try: + decorated_function(client) + except ValueError: + exception_raised = True + self.assertFalse(exception_raised) + + def test_client_without_device_set(self): + # check that decorator raises exception if client has no device set + client = Mock(spec=DeviceClient) + client.node_device = None + decorator = check_device_set() + decorated_function = decorator(lambda x: x) + + self.assertRaises(ValueError, decorated_function, client) + + +class LogExecutionTimeDecoratorTest(unittest.TestCase): + + @log_execution_time("logger", "sleep_time") + def function_to_decorate(self, sleep_time): + time.sleep(sleep_time) + + def test_log_execution_time(self): + # check that decorator returns execution time + # slight probability that this test fails since time.sleep is not exact + self.logger = Mock(spec=SimpleLogger) + + self.function_to_decorate(1) + + self.logger.log.assert_called_once() + execution_time = self.logger.log.call_args[0][0]["sleep_time"] + self.assertEqual(["start", "end", "duration"], list(execution_time.keys())) + self.assertAlmostEqual(execution_time["duration"], 1, delta=0.1) + + +class UpdateBatteryMethodDecoratorTest(unittest.TestCase): + + @update_battery + def function_to_decorate(self, sleep_time): + time.sleep(sleep_time) + + def test_update_battery(self): + self.battery = Battery(1000, deduction_per_second=1) + self.battery.start_experiment() + + self.function_to_decorate(1) + + self.assertAlmostEqual(self.battery.remaining_capacity(), 999, delta=0.1) + + def test_update_battery_without_time(self): + self.battery = Battery(1000, deduction_per_second=0) + self.battery.start_experiment() + + self.function_to_decorate(1) + + self.assertEqual(self.battery.remaining_capacity(), 1000) + + +class BatteryUpdaterClassDecoratorTest(unittest.TestCase): + # Deprecated + @battery_updater + class ClassToDecorate(object): + def __init__(self, battery): + self.battery = battery + + def function_to_decorate(self, sleep_time): + time.sleep(sleep_time) + + def test_update_battery(self): + battery = Battery(1000, deduction_per_second=1) + battery.start_experiment() + decorated_class = self.ClassToDecorate(battery) + decorated_class.function_to_decorate(1) + + self.assertAlmostEqual(battery.remaining_capacity(), 999, delta=0.1) + + def test_update_battery_without_time(self): + battery = Battery(1000, deduction_per_second=0) + battery.start_experiment() + + decorated_class = self.ClassToDecorate(battery) + decorated_class.function_to_decorate(1) + + self.assertEqual(battery.remaining_capacity(), 1000) + + +class LatencySimulatorTest(unittest.TestCase): + + def test_no_latency(self): + self._test_latency(0, 0, 0) + self._test_latency(0.5, 0, 0.5) + self._test_latency(1, 0, 1) + + def test_latency(self): + self._test_latency(0, 1, 0) + self._test_latency(0.5, 1, 1) + self._test_latency(1, 2, 3) + + def _test_latency(self, computation_time, latency_factor, resulting_time): + start_time = 
time.time() + with LatencySimulator(latency_factor=latency_factor) as l: + time.sleep(computation_time) + end_time = time.time() + self.assertAlmostEqual(end_time - start_time, resulting_time, delta=0.1) + + +class SimulateLatencyDecoratorTest(unittest.TestCase): + + def test_no_latency(self): + self._test_latency(0, 0, 0) + self._test_latency(0.5, 0, 0.5) + self._test_latency(1, 0, 1) + + def test_latency(self): + self._test_latency(0, 1, 0) + self._test_latency(0.5, 1, 1) + self._test_latency(1, 2, 3) + + @simulate_latency_decorator(latency_factor_attr="latency_factor") + def method_to_slow(self, sleep_time): + time.sleep(sleep_time) + + def _test_latency(self, computation_time, latency_factor, resulting_time): + start_time = time.time() + self.latency_factor = latency_factor + + self.method_to_slow(computation_time) + + end_time = time.time() + self.assertAlmostEqual(end_time - start_time, resulting_time, delta=0.1) + + +class AddTimeToDiagnosticMetricsTest(unittest.TestCase): + + def setUp(self): + self.device_id = "d0" # assume class to be device + + @add_time_to_diagnostic_metrics("echo_method_to_decorate") + def echo_method_to_decorate(self, input): + return input + + def test_append_diagnostic_metrics_to_no_output(self): + with patch("time.time", side_effect=[0, 42]): + response = self.echo_method_to_decorate(None) + self.assertEqual( + response.get_as_list(), + [ + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="echo_method_to_decorate", + ) + ], + ) + + def test_append_diagnostic_metrics_to_single_output(self): + with patch("time.time", side_effect=[0, 42]): + weights = {"weights": torch.tensor(42)} + response = self.echo_method_to_decorate(weights) + self.assertEqual( + response, + ( + weights, + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="echo_method_to_decorate", + ) + ] + ), + ), + ) + + def test_append_diagnostic_metrics_to_multiple_outputs(self): + with patch("time.time", side_effect=[0, 42]): + weights = {"weights": torch.tensor(42)} + metrics = {"acc": torch.tensor(42)} + response = self.echo_method_to_decorate((weights, metrics)) + self.assertEqual( + response, + ( + weights, + metrics, + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="echo_method_to_decorate", + ) + ] + ), + ), + ) + + def test_append_to_existing_diagnostic_metrics(self): + with patch("time.time", side_effect=[0, 42]): + previous_diagnostic_metrics = DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", name="comp_time", value=3, method="TrainBatch" + ) + ] + ) + response = self.echo_method_to_decorate(previous_diagnostic_metrics) + self.assertEqual( + response, + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=3, + method="TrainBatch", + ), + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="echo_method_to_decorate", + ), + ] + ), + ) + + def test_multiple_outputs_append_to_existing_diagnostic_metrics(self): + with patch("time.time", side_effect=[0, 42]): + previous_diagnostic_metrics = DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", name="comp_time", value=3, method="TrainBatch" + ) + ] + ) + response = self.echo_method_to_decorate((42, previous_diagnostic_metrics)) + self.assertEqual( + response, + ( + 42, + DiagnosticMetricResultContainer( + [ + 
DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=3, + method="TrainBatch", + ), + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="echo_method_to_decorate", + ), + ] + ), + ), + ) diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7f0d66bbfe6c74ba37e517969bb596842277b4 --- /dev/null +++ b/edml/tests/helpers/flops_test.py @@ -0,0 +1,73 @@ +import unittest + +import torch +import torch.nn as nn + +from edml.helpers.flops import estimate_model_flops +from edml.models.mnist_models import ClientNet, ServerNet + + +class FullTestModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(in_features=1000, out_features=10) + self.fc2 = nn.Linear(in_features=10, out_features=10) + self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1) + self.act = nn.ReLU() + + def forward(self, x): + return self.fc2(self.act(self.fc1(self.act(self.conv(x)).flatten(1)))) + + +class ClientTestModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1) + self.act = nn.ReLU() + + def forward(self, x): + return self.act(self.conv(x)).flatten(1) + + +class ServerTestModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(in_features=1000, out_features=10) + self.fc2 = nn.Linear(in_features=10, out_features=10) + self.act = nn.ReLU() + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class FlopsTest(unittest.TestCase): + + def test_split_count(self): + """Check that a split model yields the same flop count as if the model was not split.""" + client_model = ClientTestModel() + server_model = ServerTestModel() + full_model = FullTestModel() + + inputs = (torch.randn((10, 3, 10, 10)),) + server_inputs = client_model(inputs[0]) + + full_flops = estimate_model_flops(full_model, inputs) + client_flops = estimate_model_flops(client_model, inputs) + server_flops = estimate_model_flops(server_model, server_inputs) + + self.assertEqual(client_flops, 30000) + self.assertEqual(server_flops, 101000) + self.assertEqual(full_flops, client_flops + server_flops) + + def test_mnist_split_count(self): + client_model = ClientNet() + server_model = ServerNet() + + inputs = torch.randn((1, 1, 28, 28)) + server_inputs = client_model(inputs) + + client_flops = estimate_model_flops(client_model, inputs) + server_flops = estimate_model_flops(server_model, server_inputs) + + self.assertEqual(client_flops, 5405760) + self.assertEqual(server_flops, 11215800) diff --git a/edml/tests/helpers/interceptors_test.py b/edml/tests/helpers/interceptors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..508bcf6e482a2d10105aff0e71c36b058188d729 --- /dev/null +++ b/edml/tests/helpers/interceptors_test.py @@ -0,0 +1,167 @@ +import threading +import unittest +from unittest.mock import Mock, call + +from grpc_interceptor.testing import dummy_client, DummyRequest +from torch import Tensor + +from edml.core.battery import Battery, BatteryEmptyException +from edml.generated.connection_pb2 import TrainGlobalRequest, TrainGlobalResponse +from edml.helpers.interceptors import ( + DeviceServerInterceptor, + _proto_serialized_size, + DeviceClientInterceptor, +) +from edml.helpers.logging import SimpleLogger +from edml.helpers.proto_helpers import weights_to_proto + + +class ProtoByteSizeTest(unittest.TestCase): + def setUp(self): + 
self.context = Mock() + self.context.invocation_metadata.return_value = [] + self.context.trailing_metadata.return_value = [] + + def test_TrainGlobalRequest_byte_size(self): + proto_object = TrainGlobalRequest(epochs=42) + self.assertEqual(_proto_serialized_size(proto_object), 7) + + def test_TrainGlobalResponse_byte_size(self): + proto_object = TrainGlobalResponse( + server_weights=weights_to_proto({"weights": Tensor([42])}), + client_weights=weights_to_proto({"weights": Tensor([42])}), + ) + self.assertEqual(_proto_serialized_size(proto_object), 845) + + +class DeviceServerInterceptorTest(unittest.TestCase): + + def setUp(self): + self.logger = Mock(spec=SimpleLogger) + self.battery = Mock(spec=Battery) + + def test_measure_byte_size(self): + interceptors = [ + DeviceServerInterceptor( + logger=self.logger, stop_event=None, battery=self.battery + ) + ] + with dummy_client(special_cases={}, interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + + self.assertTrue(client.Execute(request).output == request_input) + + self.logger.log.assert_has_calls( + [ + call({"/DummyService/Execute_request_size": 14}), + call({"/DummyService/Execute_response_size": 14}), + ] + ) + self.battery.update_communication_received.assert_called_with(14) + self.battery.update_communication_sent.assert_called_with(14) + + def test_without_battery(self): + interceptors = [DeviceServerInterceptor(logger=self.logger, stop_event=None)] + with dummy_client(special_cases={}, interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + self.assertTrue(client.Execute(request).output == request_input) + self.logger.log.assert_has_calls( + [ + call({"/DummyService/Execute_request_size": 14}), + call({"/DummyService/Execute_response_size": 14}), + ] + ) + self.battery.update_communication_received.assert_not_called() + self.battery.update_communication_sent.assert_not_called() + + def test_with_empty_battery_at_receiving(self): + stop_mock = Mock(threading.Event) + interceptors = [ + DeviceServerInterceptor( + logger=self.logger, stop_event=stop_mock, battery=self.battery + ) + ] + with dummy_client(special_cases={}, interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + self.battery.is_empty.return_value = False + self.battery.update_communication_received.side_effect = ( + BatteryEmptyException + ) + + with self.assertRaises(Exception): + client.Execute(request) + stop_mock.set.assert_called() + + def test_with_empty_battery_at_responding(self): + stop_mock = Mock(threading.Event) + interceptors = [ + DeviceServerInterceptor( + logger=self.logger, stop_event=stop_mock, battery=self.battery + ) + ] + with dummy_client(special_cases={}, interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + self.battery.is_empty.return_value = False + self.battery.update_communication_sent.side_effect = BatteryEmptyException + + with self.assertRaises(Exception): + client.Execute(request) + stop_mock.set.assert_called() + + +class DeviceClientInterceptorTest(unittest.TestCase): + def setUp(self): + self.logger = Mock(spec=SimpleLogger) + self.battery = Mock(spec=Battery) + self.stop_event = Mock(threading.Event) + + def test_measure_byte_size(self): + interceptors = [ + DeviceClientInterceptor( + logger=self.logger, stop_event=self.stop_event, battery=self.battery + ) + ] + with dummy_client(special_cases={}, 
client_interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + + self.assertTrue(client.Execute(request).output == request_input) + + self.battery.update_communication_sent.assert_called_with(14) + self.battery.update_communication_received.assert_called_with(14) + + def test_with_empty_battery_at_sending(self): + interceptors = [ + DeviceClientInterceptor( + logger=self.logger, stop_event=self.stop_event, battery=self.battery + ) + ] + with dummy_client(special_cases={}, client_interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + self.battery.update_communication_sent.side_effect = BatteryEmptyException + + with self.assertRaises(BatteryEmptyException): + client.Execute(request) + self.stop_event.set.assert_called() + + def test_with_empty_battery_at_receiving(self): + interceptors = [ + DeviceClientInterceptor( + logger=self.logger, stop_event=self.stop_event, battery=self.battery + ) + ] + with dummy_client(special_cases={}, client_interceptors=interceptors) as client: + request_input = "request" + request = DummyRequest(input=request_input) + self.battery.update_communication_received.side_effect = ( + BatteryEmptyException + ) + + with self.assertRaises(BatteryEmptyException): + client.Execute(request) + self.stop_event.set.assert_called() diff --git a/edml/tests/helpers/logging_test.py b/edml/tests/helpers/logging_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac00bb292110d454ea43fa1df01cde2b72820954 --- /dev/null +++ b/edml/tests/helpers/logging_test.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import patch + +import edml.helpers.logging as logging + + +class WandbLoggerTest(unittest.TestCase): + + def setUp(self) -> None: + self.logger = logging.WandbLogger(None, 42) + self.logger.wandb_enabled = ( + True # skip logger.start_experiment to avoid unnecessary patching + ) + + @patch("wandb.finish", return_value=None) + def test_end_experiment(self, wandb_finish_mock): + self.logger.end_experiment() + self.assertFalse(self.logger.wandb_enabled) + wandb_finish_mock.assert_called_once() + + @patch("wandb.log") + def test_log_string(self, wandb_log_mock): + self.logger.log("test") + wandb_log_mock.log.assert_not_called() + + @patch("wandb.log") + def test_log_metrics(self, wandb_log_mock): + self.logger.log({"test": 42}) + wandb_log_mock.assert_called_once_with({"test": 42}) diff --git a/edml/tests/helpers/metrics_test.py b/edml/tests/helpers/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1d13f2e895a2dfa00d7ee91998812442a80182f0 --- /dev/null +++ b/edml/tests/helpers/metrics_test.py @@ -0,0 +1,317 @@ +import unittest + +import torch.utils.data +import torchmetrics + +from edml.helpers import metrics +from edml.helpers.metrics import ( + DiagnosticMetricResult, + compute_metrics_for_optimization, + DiagnosticMetricResultContainer, +) + + +def _all_metrics_num_samples_equals( + container: metrics.ModelMetricContainer, num_samples: int +): + """Helper method to test if all metric objects have the given empty flag.""" + for metric in container.metrics.values(): + if metric.num_samples != num_samples: + return False + return True + + +class CreateMetricsTest(unittest.TestCase): + def setUp(self): + self.cfg = { + "average_setting": "micro", + "classes": 5, + "metrics": ["accuracy", "f1", "auc"], + } + self.metric_container = metrics.create_metrics( + self.cfg["metrics"], self.cfg["classes"], 
self.cfg["average_setting"] + ) + + def test_create_metrics(self): + self.assertEqual(len(self.metric_container.metrics), 3) + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 0)) + + +class MetricContainerTest(unittest.TestCase): + def setUp(self) -> None: + self.cfg = { + "average_setting": "micro", + "classes": 5, + "metrics": ["accuracy", "f1", "auc"], + } + self.metric_container = metrics.create_metrics( + self.cfg["metrics"], self.cfg["classes"], self.cfg["average_setting"] + ) + + def test_add_metric(self): + self.assertEqual(len(self.metric_container.metrics.values()), 3) + self.metric_container.add_metric(torchmetrics.classification.Recall(), "recall") + self.assertEqual(len(self.metric_container.metrics.values()), 4) + + def test_metric_on_batch(self): + predictions = torch.Tensor([[0.3351, 0.0562, 0.7052, 0.4191, 0.2356]]) + labels = torch.Tensor([[0, 0, 0, 1, 0]]).int() + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 0)) + + res = self.metric_container.metrics_on_batch(predictions, labels) + + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 1)) + self.assertEqual( + res, [torch.Tensor([0.6000]), torch.Tensor([0.0]), torch.Tensor([0.7500])] + ) + + def test_compute_metrics(self): + predictions_1 = torch.Tensor([[0.3351, 0.0562, 0.7052, 0.4191, 0.2356]]) + labels_1 = torch.Tensor([[0, 0, 0, 1, 0]]).int() + predictions_2 = torch.Tensor([[0.7075, 0.4229, 0.9530, 0.1434, 0.4220]]) + labels_2 = torch.Tensor([[0, 0, 0, 1, 0]]).int() + self.metric_container.metrics_on_batch(predictions_1, labels_1) + self.metric_container.metrics_on_batch(predictions_2, labels_2) + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 2)) + + res = self.metric_container.compute_metrics("test", "d1") + self.assertEqual(type(res), list) + self.assertEqual(res[0].value, torch.Tensor([0.5000])) # acc + self.assertEqual(res[1].value, torch.Tensor([0.0])) # f1 + self.assertEqual(res[2].value, torch.Tensor([0.2500])) # auc + for metric_result in res: + self.assertEqual(metric_result.num_samples, 2) + + def test_with_empty_metrics(self): + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 0)) + + res = self.metric_container.compute_metrics("test", "d1") + + self.assertEqual(res, []) + + def test_reset_metrics(self): + predictions = torch.Tensor([[0.3351, 0.0562, 0.7052, 0.4191, 0.2356]]) + labels = torch.Tensor([[0, 0, 0, 1, 0]]).int() + self.metric_container.metrics_on_batch(predictions, labels) + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 1)) + + self.metric_container.reset_metrics() + + self.assertTrue(_all_metrics_num_samples_equals(self.metric_container, 0)) + + +class MetricResultContainerTest(unittest.TestCase): + def setUp(self): + self.metric_result_container = metrics.ModelMetricResultContainer() + + def test_add_result(self): + self.assertEqual(len(self.metric_result_container.results), 0) + result = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([42]), 42) + self.metric_result_container.add_result(result) + self.assertDictEqual( + self.metric_result_container.get_raw_metrics(), {("acc", "test"): [result]} + ) + + def test_add_results_with_different_phases(self): + self.assertEqual(len(self.metric_result_container.results), 0) + result1 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([42]), 42) + result2 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([43]), 42) + self.metric_result_container.add_results([result1, result2]) + 
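+ # Results are keyed by (metric name, phase), so the "test" and "val" results land under separate keys, which the assertion below checks.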
self.assertDictEqual( + self.metric_result_container.get_raw_metrics(), + {("acc", "test"): [result1], ("acc", "val"): [result2]}, + ) + + def test_added_result_with_same_phase(self): + self.assertEqual(len(self.metric_result_container.results), 0) + result1 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([42]), 42) + result2 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([43]), 42) + self.metric_result_container.add_result(result1) + self.metric_result_container.add_result(result2) + self.assertDictEqual( + self.metric_result_container.get_raw_metrics(), + {("acc", "test"): [result1, result2]}, + ) + + def test_merge_results_into_container(self): + self.assertEqual(len(self.metric_result_container.results), 0) + result1 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([42]), 42) + result2 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([43]), 42) + result3 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([44]), 42) + result4 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([45]), 42) + self.metric_result_container.add_results([result1, result2]) + other = metrics.ModelMetricResultContainer([result3, result4]) + + self.metric_result_container.merge(other) + + self.assertDictEqual( + self.metric_result_container.get_raw_metrics(), + {("acc", "test"): [result1, result2], ("acc", "val"): [result3, result4]}, + ) + + def test_get_aggregated_metrics(self): + result1 = metrics.ModelMetricResult("d0", "acc", "test", torch.Tensor([42]), 1) + result2 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([43]), 3) + result3 = metrics.ModelMetricResult("d0", "acc", "val", torch.Tensor([44]), 1) + result4 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([45]), 3) + self.metric_result_container.add_results([result1, result2, result3, result4]) + + aggregated_results_container = ( + self.metric_result_container.get_aggregated_metrics() + ) + + self.assertDictEqual( + aggregated_results_container.get_raw_metrics(), + { + ("acc", "test"): [ + metrics.ModelMetricResult( + "aggregated", "acc", "test", torch.Tensor([42.75]), 4 + ) + ], + ("acc", "val"): [ + metrics.ModelMetricResult( + "aggregated", "acc", "val", torch.Tensor([44.75]), 4 + ) + ], + }, + ) + + def test_get_as_list(self): + result1 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([42]), 42) + result2 = metrics.ModelMetricResult("d1", "acc", "test", torch.Tensor([43]), 42) + result3 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([44]), 42) + result4 = metrics.ModelMetricResult("d1", "acc", "val", torch.Tensor([45]), 42) + self.metric_result_container.add_results([result1, result2, result3, result4]) + + result_list = self.metric_result_container.get_as_list() + + self.assertEqual(result_list, [result1, result2, result3, result4]) + + +class DiagnosticMetricResultContainerTest(unittest.TestCase): + + def setUp(self): + self.diagnostic_metric_result_container = ( + metrics.DiagnosticMetricResultContainer() + ) + + def test_add_result(self): + self.assertEqual(len(self.diagnostic_metric_result_container.results), 0) + result = metrics.DiagnosticMetricResult("d1", "acc", "test", 0.42) + self.diagnostic_metric_result_container.add_result(result) + self.assertDictEqual( + self.diagnostic_metric_result_container.get_raw_metrics(), + {("acc", "test"): [result]}, + ) + + +class MetricsForOptimizationTest(unittest.TestCase): + def test_get_metrics_for_optimization(self): + results = [ + # comp_time + # 2 + 4 samples & batch_size 
= 1 + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_batch", name="comp_time", value=0.6 + ), + DiagnosticMetricResult( + device_id="d0", method="train_global", name="comp_time", value=10 + ), + # train epoch time = train batch time + client model train time + DiagnosticMetricResult( + device_id="d0", + method="client_train_epoch_time", + name="comp_time", + value=0.6, + ), + DiagnosticMetricResult( + device_id="d1", + method="client_train_epoch_time", + name="comp_time", + value=2.4, + ), + # one eval batch each + DiagnosticMetricResult( + device_id="d0", method="evaluate_batch", name="comp_time", value=0.1 + ), + DiagnosticMetricResult( + device_id="d0", method="evaluate_batch", name="comp_time", value=0.1 + ), + DiagnosticMetricResult( + device_id="d0", + method="client_eval_epoch_time", + name="comp_time", + value=0.1, + ), + DiagnosticMetricResult( + device_id="d1", + method="client_eval_epoch_time", + name="comp_time", + value=0.2, + ), + # size + DiagnosticMetricResult( + device_id="d0", method="gradients", name="size", value=50 + ), + DiagnosticMetricResult( + device_id="d1", method="gradients", name="size", value=50 + ), + DiagnosticMetricResult( + device_id="d0", method="labels", name="size", value=6 + ), + DiagnosticMetricResult( + device_id="d1", method="labels", name="size", value=6 + ), + DiagnosticMetricResult( + device_id="d0", method="smashed_data", name="size", value=42 + ), + DiagnosticMetricResult( + device_id="d1", method="smashed_data", name="size", value=42 + ), + DiagnosticMetricResult( + device_id="d0", method="client_weights", name="size", value=500 + ), + DiagnosticMetricResult( + device_id="d1", method="server_weights", name="size", value=500 + ), + DiagnosticMetricResult( + device_id="d1", method="optimizer_state", name="size", value=500 + ), + ] + diagnostic_metric_result_container = DiagnosticMetricResultContainer(results) + batch_size = 1 + + num_samples_per_device = {"d0": (2, 1), "d1": (4, 1)} # (train, eval) + + normalized_metrics = compute_metrics_for_optimization( + diagnostic_metric_result_container, num_samples_per_device, batch_size + ) + + self.assertEqual(normalized_metrics["smashed_data_size"], 42) + self.assertEqual(normalized_metrics["label_size"], 6) + self.assertEqual(normalized_metrics["gradient_size"], 50) + self.assertEqual(normalized_metrics["client_weight_size"], 500) + self.assertEqual(normalized_metrics["server_weight_size"], 500) + self.assertEqual(normalized_metrics["optimizer_state_size"], 500) + self.assertEqual(normalized_metrics["train_global_time"], 10) + self.assertAlmostEqual(normalized_metrics["client_norm_fw_time"], 0.1) + self.assertAlmostEqual(normalized_metrics["client_norm_bw_time"], 0.2) + self.assertAlmostEqual(normalized_metrics["server_norm_fw_time"], 0.1) + self.assertAlmostEqual(normalized_metrics["server_norm_bw_time"], 0.5) + self.assertDictEqual( + normalized_metrics["comp_latency_factor"], {"d0": 1.0, "d1": 2.0} + ) diff --git a/edml/tests/helpers/model_splitting_test.py b/edml/tests/helpers/model_splitting_test.py 
new file mode 100644 index 0000000000000000000000000000000000000000..88876bd124af4648645f617dfa0ed2968f6151ef --- /dev/null +++ b/edml/tests/helpers/model_splitting_test.py @@ -0,0 +1,18 @@ +import unittest + +import torch + +from edml.helpers.model_splitting import split_network_at_layer +from edml.models.resnet_models import BasicBlock, ResNet + + +class ModelSplittingTest(unittest.TestCase): + def test_split_network_at_layer_with_resnet(self): + # test that the original model and the concatenation of client and server model produce the same output + model = ResNet(BasicBlock, [3, 3, 3], num_classes=100) + client, server = split_network_at_layer(model, 3) + random_input = torch.rand(1, 3, 32, 32) + model_output = model(random_input) + client_output = client(random_input) + server_output = server(client_output) + self.assertTrue(torch.equal(model_output, server_output)) diff --git a/edml/tests/helpers/proto_helpers_test.py b/edml/tests/helpers/proto_helpers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd65f98b72570ecd06ea9b1d1f6490f1e486b71a --- /dev/null +++ b/edml/tests/helpers/proto_helpers_test.py @@ -0,0 +1,167 @@ +import unittest + +from torch import Tensor + +import edml.generated.datastructures_pb2 +import edml.helpers.proto_helpers as proto_helpers +from edml.generated.connection_pb2 import FullModelTrainResponse +from edml.helpers.metrics import ( + ModelMetricResultContainer, + ModelMetricResult, + DiagnosticMetricResultContainer, + DiagnosticMetricResult, +) + + +class ProtoHelpersTest(unittest.TestCase): + + def setUp(self) -> None: + self.tensor = Tensor([42]) + self.state_dict = {"weights": self.tensor} + self.metrics = ModelMetricResultContainer() + self.metrics.add_result(ModelMetricResult("d1", "acc", "test", self.tensor, 42)) + + def test_tensor_to_proto(self): + proto = proto_helpers.tensor_to_proto(self.tensor) + self.assertEqual(proto_helpers.proto_to_tensor(proto), self.tensor) + + def test_state_dict_to_proto(self): + proto = proto_helpers.state_dict_to_proto(self.state_dict) + self.assertEqual(proto_helpers.proto_to_state_dict(proto), self.state_dict) + + def test_proto_to_tensor(self): + proto = edml.generated.datastructures_pb2.Tensor( + serialized=proto_helpers._tensor_to_bytes(self.tensor) + ) + self.assertEqual(proto_helpers.proto_to_tensor(proto), self.tensor) + + def test_proto_to_state_dict(self): + proto = edml.generated.datastructures_pb2.StateDict( + serialized=proto_helpers._state_dict_to_bytes(self.state_dict) + ) + self.assertEqual(proto_helpers.proto_to_state_dict(proto), self.state_dict) + + def test_weights_to_proto(self): + proto = proto_helpers.weights_to_proto(self.state_dict) + self.assertEqual(proto_helpers.proto_to_weights(proto), self.state_dict) + + def test_none_weights_to_proto(self): + proto = proto_helpers.weights_to_proto(None) + self.assertEqual(proto_helpers.proto_to_weights(proto), None) + + def test_proto_to_weights(self): + proto = edml.generated.datastructures_pb2.Weights( + weights=proto_helpers.state_dict_to_proto(self.state_dict) + ) + self.assertEqual(proto_helpers.proto_to_weights(proto), self.state_dict) + + def test_activations_to_proto(self): + proto = proto_helpers.activations_to_proto(self.tensor) + self.assertEqual(proto_helpers.proto_to_activations(proto), self.tensor) + + def test_proto_to_activations(self): + proto = edml.generated.datastructures_pb2.Activations( + activations=proto_helpers.tensor_to_proto(self.tensor) + ) + self.assertEqual(proto_helpers.proto_to_activations(proto), 
self.tensor) + + def test_labels_to_proto(self): + proto = proto_helpers.labels_to_proto(self.tensor) + self.assertEqual(proto_helpers.proto_to_labels(proto), self.tensor) + + def test_proto_to_labels(self): + proto = edml.generated.datastructures_pb2.Labels( + labels=proto_helpers.tensor_to_proto(self.tensor) + ) + self.assertEqual(proto_helpers.proto_to_labels(proto), self.tensor) + + def test_proto_to_device_info(self): + proto = edml.generated.datastructures_pb2.DeviceInfo( + device_id="42", address="localhost" + ) + self.assertEqual(proto_helpers.proto_to_device_info(proto), ("42", "localhost")) + + def test_device_info_to_proto(self): + proto = proto_helpers.device_info_to_proto("42", "localhost") + self.assertEqual(proto_helpers.proto_to_device_info(proto), ("42", "localhost")) + + def test_proto_to_gradients(self): + proto = edml.generated.datastructures_pb2.Gradients( + gradients=proto_helpers.tensor_to_proto(self.tensor) + ) + self.assertEqual(proto_helpers.proto_to_gradients(proto), self.tensor) + + def test_gradients_to_proto(self): + proto = proto_helpers.gradients_to_proto(self.tensor) + self.assertEqual(proto_helpers.proto_to_gradients(proto), self.tensor) + + def test_proto_to_metrics(self): + proto = edml.generated.datastructures_pb2.Metrics( + metrics=proto_helpers._metrics_to_bytes(self.metrics) + ) + self.assertEqual(proto_helpers.proto_to_metrics(proto), self.metrics) + + def test_metrics_to_proto(self): + proto = proto_helpers.metrics_to_proto(self.metrics) + self.assertEqual(proto_helpers.proto_to_metrics(proto), self.metrics) + + def test_proto_size_per_attribute(self): + proto_object = FullModelTrainResponse( + client_weights=proto_helpers.weights_to_proto(self.state_dict), + server_weights=proto_helpers.weights_to_proto(self.state_dict), + num_samples=42, # to be ignored by the size computation + metrics=proto_helpers.metrics_to_proto(self.metrics), + ) + result = proto_helpers._proto_size_per_field(proto_object, "d0") + self.assertEqual( + result, + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", name="size", value=417, method="client_weights" + ), + DiagnosticMetricResult( + device_id="d0", name="size", value=417, method="server_weights" + ), + DiagnosticMetricResult( + device_id="d0", name="size", value=585, method="metrics" + ), + ] + ), + ) + + def test_proto_size_per_attribute_ignore_diagnostic_metrics(self): + proto_object = FullModelTrainResponse( + client_weights=proto_helpers.weights_to_proto(self.state_dict), + server_weights=proto_helpers.weights_to_proto(self.state_dict), + num_samples=42, # to be ignored by the size computation + metrics=proto_helpers.metrics_to_proto(self.metrics), + diagnostic_metrics=proto_helpers.metrics_to_proto( + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", + name="comp_time", + value=42, + method="TrainEpoch", + ) + ] + ) + ), + ) + self.assertEqual( + proto_helpers._proto_size_per_field(proto_object, "d0"), + DiagnosticMetricResultContainer( + [ + DiagnosticMetricResult( + device_id="d0", name="size", value=417, method="client_weights" + ), + DiagnosticMetricResult( + device_id="d0", name="size", value=417, method="server_weights" + ), + DiagnosticMetricResult( + device_id="d0", name="size", value=585, method="metrics" + ), + ] + ), + ) diff --git a/edml/tests/integration/rpc_server_test.py b/edml/tests/integration/rpc_server_test.py new file mode 100644 index 0000000000000000000000000000000000000000..faafd7b9b94300c9509d777453694381dfd1529c --- /dev/null 
+++ b/edml/tests/integration/rpc_server_test.py @@ -0,0 +1,199 @@ +import threading +import unittest +from concurrent import futures +from unittest.mock import Mock, call + +import grpc +from omegaconf import DictConfig + +from edml.core.battery import Battery, BatteryEmptyException +from edml.core.device import NetworkDevice, RPCDeviceServicer, DeviceRequestDispatcher +from edml.generated import connection_pb2_grpc +from edml.helpers.interceptors import DeviceServerInterceptor +from edml.helpers.logging import SimpleLogger +from edml.helpers.metrics import DiagnosticMetricResultContainer + + +class RPCServerTest(unittest.TestCase): + + def setUp(self): + # set up logger, battery, stop event and device as mocks to assert calls and set behavior + self.server_logger = Mock(spec=SimpleLogger) + self.server_battery = Mock(spec=Battery) + self.server_stop_mock = Mock(threading.Event) + + self.client_battery = Mock(spec=Battery) + self.client_logger = Mock(spec=SimpleLogger) + self.client_stop_mock = Mock(threading.Event) + self.device = Mock( + NetworkDevice( + device_id="d0", logger=self.server_logger, battery=self.server_battery + ) + ) + self.device.device_id = "d0" + + devices = [DictConfig({"device_id": "d0", "address": "localhost:50061"})] + + # start server + self.grpc_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=2), + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + interceptors=[ + DeviceServerInterceptor( + logger=self.server_logger, + stop_event=self.server_stop_mock, + battery=self.server_battery, + ), + ], + ) + connection_pb2_grpc.add_DeviceServicer_to_server( + RPCDeviceServicer(device=self.device), self.grpc_server + ) + self.grpc_server.add_insecure_port("localhost:50061") + self.grpc_server.start() + + self.dispatcher = DeviceRequestDispatcher( + devices, + logger=self.client_logger, + battery=self.client_battery, + stop_event=self.client_stop_mock, + ) + + def tearDown(self): + self.grpc_server.stop(grace=None) + + def test_request_success(self): + self.device.train_epoch.return_value = { + "weights": 42 + }, DiagnosticMetricResultContainer() + + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + + self.dispatcher.train_epoch_on("d0", "d0") + + self.device.train_epoch.assert_called() + self.server_stop_mock.set.assert_not_called() + self.client_stop_mock.set.assert_not_called() + self.assertEqual(self.dispatcher.active_devices(), ["d0"]) + + def test_empty_battery_at_receiving_request(self): + self.server_battery.update_communication_received.side_effect = ( + BatteryEmptyException + ) + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + + self.dispatcher.train_epoch_on("d0", "d0") + + self.server_logger.log.assert_has_calls( + [ + call({"/Device/TrainEpoch_request_size": 39}), + call("Battery empty while receiving request"), + ] + ) + self.server_stop_mock.set.assert_called() + self.server_battery.update_communication_received.assert_called_with(39) + self.server_battery.update_communication_sent.assert_not_called() + self.client_logger.log.assert_called_with( + "RPC server battery empty during request" + ) + self.client_battery.update_communication_sent.assert_called_with(39) + self.client_battery.update_communication_received.assert_not_called() + self.client_stop_mock.set.assert_not_called() + self.assertEqual(self.dispatcher.active_devices(), []) + + def test_empty_battery_while_processing_request(self): + 
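+ # Scenario: the battery runs out while the device is processing the request,
+ # i.e. the mocked train_epoch itself raises BatteryEmptyException; the server should
+ # then set its stop event and the device should no longer appear in active_devices.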
self.device.train_epoch.side_effect = BatteryEmptyException + + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + + self.dispatcher.train_epoch_on("d0", "d0") + + self.device.train_epoch.assert_called() + self.server_stop_mock.set.assert_called() + self.server_battery.update_communication_received.assert_called_with(39) + self.server_battery.update_communication_sent.assert_not_called() + self.client_logger.log.assert_called_with( + "RPC server battery empty during request" + ) + self.client_battery.update_communication_sent.assert_called_with(39) + self.client_battery.update_communication_received.assert_not_called() + self.client_stop_mock.set.assert_not_called() + self.assertEqual(self.dispatcher.active_devices(), []) + + def test_empty_battery_at_sending_response(self): + self.server_battery.update_communication_sent.side_effect = ( + BatteryEmptyException + ) + self.device.train_epoch.return_value = { + "weights": 42 + }, DiagnosticMetricResultContainer() + + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + + self.dispatcher.train_epoch_on("d0", "d0") + + self.server_logger.log.assert_has_calls( + [ + call({"/Device/TrainEpoch_request_size": 39}), + call({"/Device/TrainEpoch_response_size": 132}), + call("Battery empty while handling request or sending response"), + ] + ) + self.server_battery.update_communication_sent.assert_called_with(132) + self.server_battery.update_communication_received.assert_called_with(39) + self.server_stop_mock.set.assert_called() + self.client_logger.log.assert_called_with( + "RPC server battery empty during request" + ) + self.client_battery.update_communication_sent.assert_called_with(39) + self.client_battery.update_communication_received.assert_not_called() + self.client_stop_mock.set.assert_not_called() + self.assertEqual(self.dispatcher.active_devices(), []) + + def test_empty_client_battery_while_sending_request(self): + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + self.client_battery.update_communication_sent.side_effect = ( + BatteryEmptyException + ) + + with self.assertRaises(BatteryEmptyException): + # this should raise an exception because the client battery is empty + # hence dispatcher is not needed anymore + self.dispatcher.train_epoch_on("d0", "d0") + + self.device.train_epoch.assert_not_called() + self.server_stop_mock.set.assert_not_called() + self.client_stop_mock.set.assert_called() + self.client_logger.log.assert_called_with( + "RPC client battery empty during request" + ) + + def test_empty_client_battery_while_receiving_response(self): + active_devices = self.dispatcher.active_devices() + self.assertEqual(active_devices, ["d0"]) + self.client_battery.update_communication_received.side_effect = ( + BatteryEmptyException + ) + self.device.train_epoch.return_value = { + "weights": 42 + }, DiagnosticMetricResultContainer() + + with self.assertRaises(BatteryEmptyException): + # this should raise an exception because the client battery is empty + # hence dispatcher is not needed anymore + self.dispatcher.train_epoch_on("d0", "d0") + + self.device.train_epoch.assert_called() + self.server_stop_mock.set.assert_not_called() + self.client_stop_mock.set.assert_called() + self.client_logger.log.assert_called_with( + "RPC client battery empty during request" + ) diff --git a/edml/tests/models/resnet_models_test.py b/edml/tests/models/resnet_models_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..a9872b4e4ab50c934a7265a213456142677d916b --- /dev/null +++ b/edml/tests/models/resnet_models_test.py @@ -0,0 +1,11 @@ +import unittest + +from edml.models import resnet_models + + +class ResnetModelsTest(unittest.TestCase): + + def test_get_resnet(self): + client, server = resnet_models.resnet20(4, 100) + self.assertIsNotNone(client.state_dict()) + self.assertIsNotNone(server.state_dict()) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..3d8fa013b52754dcd0fafd41bc4c0d303690cf44 --- /dev/null +++ b/environment.yml @@ -0,0 +1,44 @@ +name: slenv +# For a faster installation run the following before conda env create: +# conda install -n base conda-libmamba-solver +# conda config --set solver libmamba +channels: + - conda-forge + - nvidia + - pytorch + - defaults +dependencies: + - asttokens=2.2.1 + - cudatoolkit=10.2.89 + - jupyter=1.0.0 + - matplotlib=3.8.0 + - mkl=2024.0.0 # Not used explicitly, but added to avoid compatibility error. Can be removed if https://github.com/pytorch/pytorch/issues/123097 is resolved + - numpy=1.24.3 + - pandas=1.4.2 + - pip=21.2.4 + - pre-commit=3.7.0 + - protobuf=3.19.1 + - pydantic + - pytest=7.4.2 + - python=3.9.7 + - pytorch=2.0.1 + - pytorch-cuda=11.7 + - scikit-learn=1.0.2 + - scipy=1.10.1 + - torchmetrics=0.9.2 + - torchvision=0.15.2 + - tqdm=4.64.0 + - typing_extensions=4.6.3 + - wandb=0.15.11 + - wfdb=3.4.1 + - pip: + - fvcore==0.1.5.post20221221 + - hydra-core==1.3.2 + - pytorch-nemo==0.0.8 + - grpc-interceptor==0.15.3 + - grpcio==1.58.0 + - grpcio-tools==1.58.0 + - grpcio-testing==1.58.0 + - ortools==9.8.3296 + - tabulate==0.9.0 +prefix: /home/tim/anaconda3/envs/slenv diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d8fabc2bc41e4226e56ff5cffbd5ebe2642aaeb8 --- /dev/null +++ b/main.py @@ -0,0 +1,95 @@ +import random + +import hydra +import numpy.random +import torch +from omegaconf import DictConfig + +from edml.controllers.test_controller import TestController +from edml.core.start_device import launch_device +from edml.helpers.config_helpers import preprocess_config, instantiate_controller +from multiprocessing import Process + + +@hydra.main(version_base=None, config_name="default", config_path="config") +def main(cfg): + """Starts either a device with a gRPC server or a controller depending on the config.""" + preprocess_config(cfg) + + # Fix the seed for reproducibility. Every client also has to do the same, since multiprocessing does + # not share the same random state for all libraries we use. + _make_deterministic(cfg.seed) + + client_processes = [] + + # First, we start all the devices. + for device in cfg.topology.devices: + p = Process( + target=_start_device, args=(cfg, device.device_id), name=device.device_id + ) + p.start() + client_processes.append(p) + + # Then we start the controller. + _start_controller(cfg) + + # Once we reach this point, the controller has shutdown and the experiment is over. We terminate all clients. + for p in client_processes: + p.terminate() + p.join() + + _run_test_evaluation(cfg) + + +def _start_controller(cfg: DictConfig): + controller = instantiate_controller(cfg) + controller.train() + + +def _start_device(cfg: DictConfig, device_id: str): + # Update the device ID to be unique. + cfg.own_device_id = device_id + + # Fix the seed for reproducibility. + _make_deterministic(cfg.seed) + + # Start the device. 
+ launch_device(cfg) + + +def _run_test_evaluation(cfg): + cfg.experiment.job = "test" + cfg.experiment.partition = "False" + cfg.experiment.latency = None + cfg.num_devices = 1 + device_id = cfg.topology.devices[0].device_id + p = Process(target=_start_device, args=(cfg, device_id), name=device_id) + p.start() + + controller = TestController(cfg) + controller.train() + + p.terminate() + p.join() + + +def _make_deterministic(seed_cfg: DictConfig): + seed = seed_cfg.value + + # Set the seed for all libraries that we use. + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # This is actually not enough to make PyTorch fully deterministic. This only configures CUDA to use a deterministic + # convolution algorithm. It does not make the selection of the algorithm itself deterministic. However, the projects + # that I encountered online did mostly only set the following flag. + # + # See https://pytorch.org/docs/stable/notes/randomness.html for more information. + if seed_cfg.torch_deterministic: + torch.backends.cudnn.deterministic = True + + +if __name__ == "__main__": + main() diff --git a/quick_pass.sh b/quick_pass.sh new file mode 100644 index 0000000000000000000000000000000000000000..94c5f9780f30fa3212fdfb2c59cf6fc585180298 --- /dev/null +++ b/quick_pass.sh @@ -0,0 +1,27 @@ +# shell script to start 2 devices with only a single batch of data each and run all controllers sequentially + +conda init bash +conda activate slenv + +# use trap to kill all processes in the subshell on ctrl-c +(trap 'kill 0' SIGINT; + # testrun + echo "Starting a quick test run with 2 devices" + device_pids=() + python3 main.py own_device_id=0 num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 experiment.load_single_batch_for_debugging=True wandb=False & + device_pids+=($!) + python3 main.py own_device_id=1 num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 experiment.load_single_batch_for_debugging=True wandb=False & + device_pids+=($!) 
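+ # Note: "$!" expands to the PID of the most recently started background process,
+ # so device_pids collects both device PIDs for the cleanup loop below.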
+ # run all controllers sequentially + python3 main.py +method='fed' num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 wandb=False + python3 main.py +method='split' num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 wandb=False + python3 main.py +method='swarm_seq' num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 wandb=False + python3 main.py +method='swarm_rand' num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 wandb=False + python3 main.py +method='swarm_max' num_devices=2 experiment.max_epochs=1 experiment.max_rounds=1 wandb=False + + # kill device processes after controller has finished + for pid in "${device_pids[@]}"; do + kill "$pid" + done +) +exit 1; diff --git a/results/.gitignore b/results/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4c785dbd31c6c1ad85636a39fd124b691c0e23f5 --- /dev/null +++ b/results/.gitignore @@ -0,0 +1,3 @@ +dataframes +metrics +plots diff --git a/results/baseline_evaluation.ipynb b/results/baseline_evaluation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..45cf7ae99ade471d207bf881be20d27ecea2acfe --- /dev/null +++ b/results/baseline_evaluation.ipynb @@ -0,0 +1,1989 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-11-23T14:26:55.110525300Z", + "start_time": "2023-11-23T14:26:53.495389800Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import wandb\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001B[34m\u001B[1mwandb\u001B[0m: Currently logged in as: \u001B[33mtim-ba\u001B[0m (\u001B[33mswarmsl\u001B[0m). 
Use \u001B[1m`wandb login --relogin`\u001B[0m to force relogin\n" + ] + } + ], + "source": [ + "wandb.login()\n", + "api = wandb.Api(timeout=29)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:26:57.045785900Z", + "start_time": "2023-11-23T14:26:55.130645600Z" + } + }, + "id": "2f1b50906b1ff28e" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "strategies=[\"fed\", \"split\", \"swarm_seq\", \"swarm_rand\", \"swarm_max\"]\n", + "batteries=[\"equal_batteries_only_flops\", \"unequal_batteries_only_flops\", \"equal_batteries_unlimited\"]" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:26:57.045785900Z", + "start_time": "2023-11-23T14:26:57.045785900Z" + } + }, + "id": "21e617b9f706a0b5" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [ + "run_groups = {}\n", + "for strategy in strategies:\n", + " for battery in batteries:\n", + " group = api.runs(\"tim-ba/baseline\", filters={\"group\": f\"{strategy}_{battery}\"})\n", + " run_groups[f\"{strategy}_{battery}\"] = group" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:26:57.265350600Z", + "start_time": "2023-11-23T14:26:57.045785900Z" + } + }, + "id": "363de94bb69a4d05" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Device/FullModelTraining_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "_step\n", + "_timestamp\n", + "/Device/FullModelTraining_request_size\n", + "battery\n", + "/Device/SetWeights_response_size\n", + "fed_train_time\n", + "train_accuracy\n", + "_runtime\n", + "val_accuracy\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "val_accuracy\n", + "_step\n", + "client_train_epoch_time\n", + "_runtime\n", + "fed_train_time\n", + "client_evaluate_time\n", + "/Device/SetWeights_request_size\n", + "/Device/FullModelTraining_response_size\n", + "_timestamp\n", + "train_accuracy\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "/Device/FullModelTraining_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/SetWeights_response_size\n", + "client_train_epoch_time\n", + "/Device/FullModelTraining_request_size\n", + "_step\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/StartExperiment_response_size\n", + "val_accuracy\n", + "battery\n", + "_runtime\n", + "client_evaluate_time\n", + "fed_train_time\n", + "/Device/FullModelTraining_response_size\n", + "_timestamp\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "/Device/SetWeights_response_size\n", + "/Device/FullModelTraining_request_size\n", + "_runtime\n", + "train_accuracy\n", + "/Device/FullModelTraining_response_size\n", + "client_evaluate_time\n", + "battery\n", + "val_accuracy\n", + "_timestamp\n", + "client_train_epoch_time\n", + "fed_train_time\n", + "/Device/SetWeights_request_size\n", + "battery\n", + "_step\n", + "client_train_epoch_time\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/FullModelTraining_request_size\n", + "fed_train_time\n", + "_timestamp\n", + "val_accuracy\n", + "client_evaluate_time\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "/Device/SetWeights_response_size\n", + "/Device/FullModelTraining_response_size\n", + "_runtime\n", + "/Device/SetWeights_response_size\n", + "_step\n", + 
"val_accuracy\n", + "_timestamp\n", + "fed_train_time\n", + "train_accuracy\n", + "/Device/FullModelTraining_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "/Device/FullModelTraining_request_size\n", + "client_evaluate_time\n", + "battery\n", + "val_accuracy\n", + "/Device/SetWeights_response_size\n", + "/Device/StartExperiment_response_size\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "client_train_epoch_time\n", + "/Device/FullModelTraining_request_size\n", + "_runtime\n", + "fed_train_time\n", + "_step\n", + "_timestamp\n", + "battery\n", + "client_evaluate_time\n", + "/Device/FullModelTraining_response_size\n", + "/Device/SetWeights_request_size\n", + "fed_train_time\n", + "_runtime\n", + "val_accuracy\n", + "battery\n", + "_step\n", + "/Device/FullModelTraining_response_size\n", + "train_accuracy\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/FullModelTraining_request_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "/Device/SetWeights_response_size\n", + "/Device/FullModelTraining_request_size\n", + "/Device/SetWeights_request_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "val_accuracy\n", + "_step\n", + "client_train_epoch_time\n", + "battery\n", + "/Device/FullModelTraining_response_size\n", + "fed_train_time\n", + "_timestamp\n", + "_runtime\n", + "/Device/StartExperiment_response_size\n", + "train_accuracy\n", + "/Device/FullModelTraining_request_size\n", + "/Device/FullModelTraining_response_size\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "fed_train_time\n", + "_timestamp\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_request_size\n", + "_runtime\n", + "val_accuracy\n", + "_step\n", + "/Device/SetWeights_response_size\n", + "battery\n", + "client_train_epoch_time\n", + "/Device/EndExperiment_request_size\n", + "/Device/FullModelTraining_request_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "val_accuracy\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "_timestamp\n", + "/Device/SetWeights_response_size\n", + "_step\n", + "_runtime\n", + "/Device/StartExperiment_response_size\n", + "/Device/FullModelTraining_response_size\n", + "fed_train_time\n", + "val_accuracy\n", + "fed_train_time\n", + "/Device/FullModelTraining_response_size\n", + "/Device/FullModelTraining_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/EndExperiment_request_size\n", + "battery\n", + "_runtime\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "train_accuracy\n", + "client_train_epoch_time\n", + "/Device/SetWeights_response_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "_timestamp\n", + "/Device/EndExperiment_request_size\n", + "_step\n", + "_runtime\n", + "fed_train_time\n", + "/Device/FullModelTraining_response_size\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "/Device/FullModelTraining_request_size\n", + "val_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "train_accuracy\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "val_accuracy\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/SetWeights_response_size\n", + "fed_train_time\n", + 
"/Device/FullModelTraining_response_size\n", + "_runtime\n", + "/Device/EndExperiment_request_size\n", + "train_accuracy\n", + "_step\n", + "/Device/FullModelTraining_request_size\n", + "battery\n", + "_timestamp\n", + "/Device/StartExperiment_response_size\n", + "client_train_epoch_time\n", + "_step\n", + "_runtime\n", + "/Device/EndExperiment_request_size\n", + "val_accuracy\n", + "client_evaluate_time\n", + "battery\n", + "train_accuracy\n", + "fed_train_time\n", + "/Device/SetWeights_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/FullModelTraining_request_size\n", + "/Device/FullModelTraining_response_size\n", + "/Device/TrainEpoch_request_size\n", + "client_evaluate_time\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_response_size\n", + "_runtime\n", + "_timestamp\n", + "/Device/GetBatteryStatus_request_size\n", + "_step\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/EndExperiment_request_size\n", + "_timestamp\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_request_size\n", + "_runtime\n", + "client_evaluate_time\n", + "/Device/SetWeights_response_size\n", + "_step\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "client_evaluate_time\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainEpoch_request_size\n", + "_runtime\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "/Device/EndExperiment_request_size\n", + "/Device/Evaluate_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainEpoch_request_size\n", + "_step\n", + "client_evaluate_time\n", + "_runtime\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/SetWeights_response_size\n", + "val_accuracy\n", + "/Device/SetWeights_request_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_response_size\n", + "_step\n", + "train_accuracy\n", + "/Device/EvaluateBatch_response_size\n", + "battery\n", + "/Device/TrainBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "_timestamp\n", + "/Device/TrainBatch_request_size\n", + "client_train_epoch_time\n", + "_runtime\n", + "train_global_time\n", + "/Device/GetBatteryStatus_request_size\n", + "_timestamp\n", + "/Device/TrainEpoch_response_size\n", + 
"/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "client_train_epoch_time\n", + "_step\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/SetWeights_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_runtime\n", + "battery\n", + "/Device/SetWeights_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_response_size\n", + "client_train_epoch_time\n", + "client_evaluate_time\n", + "/Device/Evaluate_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "_step\n", + "battery\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "_step\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainEpoch_request_size\n", + "_runtime\n", + "/Device/EndExperiment_request_size\n", + "client_train_epoch_time\n", + "_runtime\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_request_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "_step\n", + "/Device/Evaluate_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_response_size\n", + "train_global_time\n", + "/Device/SetWeights_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/SetWeights_response_size\n", + "client_evaluate_time\n", + "/Device/EvaluateBatch_response_size\n", + "_step\n", + "_runtime\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "train_accuracy\n", + "/Device/StartExperiment_response_size\n", + "val_accuracy\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "battery\n", + "/Device/EvaluateBatch_request_size\n", + "_runtime\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_request_size\n", + "_timestamp\n", + "/Device/EndExperiment_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "client_train_epoch_time\n", + "client_evaluate_time\n", + "/Device/Evaluate_response_size\n", + "_step\n", + "/Device/TrainEpoch_request_size\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "_step\n", + "/Device/SetWeights_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "battery\n", + "_timestamp\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainEpoch_response_size\n", + 
"/Device/SetWeights_response_size\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "_runtime\n", + "/Device/EndExperiment_request_size\n", + "client_evaluate_time\n", + "_step\n", + "_timestamp\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_request_size\n", + "_runtime\n", + "/Device/TrainEpoch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "battery\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_response_size\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/SetWeights_request_size\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "_step\n", + "battery\n", + "/Device/Evaluate_response_size\n", + "_timestamp\n", + "/Device/EndExperiment_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "client_train_epoch_time\n", + "_timestamp\n", + "/Device/GetBatteryStatus_response_size\n", + "train_global_time\n", + "_step\n", + "/Device/GetBatteryStatus_request_size\n", + "client_evaluate_time\n", + "/Device/EndExperiment_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainGlobal_request_size\n", + "val_accuracy\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainGlobal_response_size\n", + "battery\n", + "_runtime\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainBatch_request_size\n", + "train_accuracy\n", + "/Device/StartExperiment_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "val_accuracy\n", + "battery\n", + "_runtime\n", + "_step\n", + "train_global_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "battery\n", + "_runtime\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/SetWeights_response_size\n", + "_timestamp\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "/Device/EvaluateBatch_response_size\n", + "train_global_time\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/Evaluate_request_size\n", + "_step\n", + "/Device/SetWeights_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "val_accuracy\n", + "train_global_time\n", + 
"/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_request_size\n", + "val_accuracy\n", + "_step\n", + "/Device/Evaluate_request_size\n", + "train_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "client_train_epoch_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainBatch_request_size\n", + "battery\n", + "_runtime\n", + "client_evaluate_time\n", + "/Device/TrainBatch_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "_timestamp\n", + "/Device/TrainEpoch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "_runtime\n", + "train_global_time\n", + "client_train_epoch_time\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainGlobal_response_size\n", + "_step\n", + "train_accuracy\n", + "/Device/Evaluate_response_size\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "_timestamp\n", + "val_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainBatch_request_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "client_train_epoch_time\n", + "_runtime\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "_timestamp\n", + "val_accuracy\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "train_accuracy\n", + "battery\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/StartExperiment_response_size\n", + "train_global_time\n", + "client_evaluate_time\n", + "_step\n", + "train_accuracy\n", + "/Device/TrainEpoch_response_size\n", + "_step\n", + "val_accuracy\n", + "client_train_epoch_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_evaluate_time\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainBatch_request_size\n", + "battery\n", + "_timestamp\n", + "_runtime\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_evaluate_time\n", + "battery\n", + "_runtime\n", + "train_global_time\n", + "/Device/GetBatteryStatus_request_size\n", + "train_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/SetWeights_request_size\n", + 
"/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "val_accuracy\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "/Device/Evaluate_response_size\n", + "client_train_epoch_time\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "train_accuracy\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainGlobal_request_size\n", + "battery\n", + "_runtime\n", + "client_train_epoch_time\n", + "/Device/StartExperiment_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainBatch_response_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_request_size\n", + "_step\n", + "val_accuracy\n", + "train_global_time\n", + "/Device/EvaluateBatch_response_size\n", + "_runtime\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/SetWeights_response_size\n", + "val_accuracy\n", + "train_accuracy\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainGlobal_request_size\n", + "_step\n", + "battery\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_response_size\n", + "_timestamp\n", + "/Device/TrainBatch_response_size\n", + "train_global_time\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "client_train_epoch_time\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "_timestamp\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "battery\n", + "train_global_time\n", + "_step\n", + "/Device/EndExperiment_request_size\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "client_evaluate_time\n", + "/Device/Evaluate_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "val_accuracy\n", + "train_accuracy\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "/Device/EvaluateBatch_response_size\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainBatch_request_size\n", + "client_evaluate_time\n", + "/Device/EvaluateBatch_request_size\n", + "_timestamp\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "_runtime\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainEpoch_request_size\n", + "train_global_time\n", + "/Device/TrainGlobal_request_size\n", + "train_accuracy\n", + "_step\n", + "/Device/EndExperiment_request_size\n", + "/Device/Evaluate_response_size\n", + 
"/Device/TrainGlobal_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/Evaluate_request_size\n", + "val_accuracy\n", + "/Device/EndExperiment_request_size\n", + "client_evaluate_time\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "client_train_epoch_time\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "_step\n", + "train_accuracy\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "/Device/SetWeights_request_size\n", + "train_global_time\n", + "/Device/SetWeights_response_size\n", + "_runtime\n", + "/Device/TrainEpoch_request_size\n", + "battery\n", + "_timestamp\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "_step\n", + "/Device/EndExperiment_request_size\n", + "val_accuracy\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainBatch_request_size\n", + "_runtime\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "client_train_epoch_time\n", + "train_global_time\n", + "train_accuracy\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainBatch_response_size\n", + "battery\n", + "battery\n", + "/Device/TrainEpoch_response_size\n", + "train_global_time\n", + "/Device/TrainGlobal_request_size\n", + "_runtime\n", + "/Device/GetBatteryStatus_request_size\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_request_size\n", + "_step\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/Evaluate_request_size\n", + "_timestamp\n", + "/Device/TrainBatch_response_size\n", + "client_train_epoch_time\n", + "val_accuracy\n", + "/Device/TrainBatch_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "val_accuracy\n", + "battery\n", + "/Device/EndExperiment_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "_runtime\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainGlobal_request_size\n", + "_step\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_response_size\n", + "train_global_time\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + 
"/Device/SetWeights_request_size\n", + "_step\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "train_accuracy\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/SetWeights_response_size\n", + "train_global_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "_runtime\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/Evaluate_response_size\n", + "_timestamp\n", + "/Device/TrainGlobal_request_size\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainBatch_request_size\n", + "_runtime\n", + "_timestamp\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_request_size\n", + "val_accuracy\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/Evaluate_request_size\n", + "train_global_time\n", + "battery\n", + "/Device/TrainBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/SetWeights_request_size\n", + "train_accuracy\n", + "client_train_epoch_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "/Device/Evaluate_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/StartExperiment_response_size\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "_runtime\n", + "_step\n", + "/Device/TrainBatch_request_size\n", + "train_global_time\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainBatch_response_size\n", + "battery\n", + "train_accuracy\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "train_accuracy\n", + "_step\n", + "_runtime\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_response_size\n", + "train_global_time\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_response_size\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "client_train_epoch_time\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainGlobal_request_size\n", + "battery\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainGlobal_response_size\n", + "train_global_time\n", + "client_evaluate_time\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "train_accuracy\n", + 
"/Device/TrainBatch_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_runtime\n", + "_step\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/SetWeights_request_size\n", + "val_accuracy\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainGlobal_request_size\n", + "_step\n", + "/Device/StartExperiment_response_size\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_request_size\n", + "battery\n", + "/Device/TrainBatch_response_size\n", + "client_evaluate_time\n", + "_runtime\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "client_train_epoch_time\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainGlobal_response_size\n", + "_timestamp\n", + "train_global_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_response_size\n", + "client_evaluate_time\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_request_size\n", + "_timestamp\n", + "train_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "battery\n", + "_runtime\n", + "/Device/EvaluateBatch_response_size\n", + "_step\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "val_accuracy\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "train_accuracy\n", + "_runtime\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/SetWeights_response_size\n", + "val_accuracy\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_response_size\n", + "client_evaluate_time\n", + "_timestamp\n", + "train_global_time\n", + "/Device/TrainGlobal_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "battery\n", + "/Device/TrainBatch_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainBatch_request_size\n", + "_step\n", + "client_evaluate_time\n", + "_runtime\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/Evaluate_response_size\n", + "train_accuracy\n", + "battery\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "train_global_time\n", + "_step\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "_timestamp\n", + "/Device/TrainBatch_response_size\n", + "/Device/Evaluate_request_size\n", + 
"/Device/EvaluateBatch_request_size\n", + "train_accuracy\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/SetWeights_request_size\n", + "train_global_time\n", + "_step\n", + "/Device/EndExperiment_request_size\n", + "battery\n", + "/Device/EvaluateBatch_request_size\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "_runtime\n", + "/Device/GetBatteryStatus_response_size\n", + "client_evaluate_time\n", + "/Device/Evaluate_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "val_accuracy\n", + "client_train_epoch_time\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/Evaluate_response_size\n", + "client_train_epoch_time\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainBatch_response_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainEpoch_response_size\n", + "train_accuracy\n", + "_runtime\n", + "/Device/TrainGlobal_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "battery\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EndExperiment_request_size\n", + "train_global_time\n", + "_timestamp\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_response_size\n", + "_step\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_response_size\n", + "train_global_time\n", + "/Device/SetWeights_response_size\n", + "_step\n", + "client_evaluate_time\n", + "/Device/Evaluate_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainGlobal_request_size\n", + "val_accuracy\n", + "/Device/TrainBatch_request_size\n", + "battery\n", + "/Device/Evaluate_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "/Device/SetWeights_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "client_train_epoch_time\n", + "train_accuracy\n", + "_runtime\n", + "_runtime\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/StartExperiment_response_size\n", + "train_global_time\n", + "client_train_epoch_time\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/SetWeights_response_size\n", + "_step\n", + "/Device/SetWeights_request_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "battery\n", + "/Device/EndExperiment_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "val_accuracy\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_request_size\n", + "_timestamp\n", + "/Device/TrainBatch_response_size\n", + "client_train_epoch_time\n", + "/Device/EndExperiment_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + 
"client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "_step\n", + "_runtime\n", + "battery\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "_timestamp\n", + "/Device/Evaluate_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "train_accuracy\n", + "val_accuracy\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "train_global_time\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainGlobal_request_size\n", + "_runtime\n", + "/Device/SetWeights_response_size\n", + "_step\n", + "train_accuracy\n", + "train_global_time\n", + "_timestamp\n", + "/Device/TrainEpoch_request_size\n", + "val_accuracy\n", + "battery\n", + "client_evaluate_time\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/EndExperiment_request_size\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainBatch_response_size\n", + "val_accuracy\n", + "client_evaluate_time\n", + "/Device/EvaluateBatch_response_size\n", + "battery\n", + "_timestamp\n", + "/Device/SetWeights_request_size\n", + "/Device/EndExperiment_request_size\n", + "train_accuracy\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "train_global_time\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainGlobal_response_size\n", + "_step\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_response_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "train_accuracy\n", + "/Device/TrainEpoch_request_size\n", + "train_global_time\n", + "battery\n", + "_timestamp\n", + "/Device/TrainBatch_request_size\n", + "_step\n", + "_runtime\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "val_accuracy\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_response_size\n", + "_step\n", + "_runtime\n", + "/Device/Evaluate_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "train_global_time\n", + 
"client_evaluate_time\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/SetWeights_response_size\n", + "train_accuracy\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "battery\n", + "val_accuracy\n", + "/Device/TrainGlobal_request_size\n", + "_timestamp\n", + "battery\n", + "train_accuracy\n", + "/Device/SetWeights_request_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "client_train_epoch_time\n", + "/Device/TrainEpoch_request_size\n", + "val_accuracy\n", + "train_global_time\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainGlobal_response_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "_runtime\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainGlobal_request_size\n", + "_step\n", + "/Device/TrainEpoch_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/GetBatteryStatus_request_size\n", + "_step\n", + "/Device/GetBatteryStatus_response_size\n", + "train_global_time\n", + "/Device/TrainBatch_response_size\n", + "train_accuracy\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "_runtime\n", + "/Device/EvaluateBatch_response_size\n", + "val_accuracy\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/EvaluateBatch_request_size\n", + "battery\n", + "client_train_epoch_time\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "_step\n", + "/Device/StartExperiment_response_size\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "client_evaluate_time\n", + "/Device/TrainEpoch_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "_runtime\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "_timestamp\n", + "battery\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "_step\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "client_evaluate_time\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "_step\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "_runtime\n", + "/Device/Evaluate_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/TrainEpoch_request_size\n", + "_timestamp\n", + "/Device/SetWeights_response_size\n", + "battery\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_request_size\n", + 
"client_evaluate_time\n", + "val_accuracy\n", + "/Device/SetWeights_response_size\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainGlobal_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/StartExperiment_response_size\n", + "_step\n", + "train_global_time\n", + "client_train_epoch_time\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainBatch_request_size\n", + "_runtime\n", + "/Device/SetWeights_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainGlobal_response_size\n", + "battery\n", + "_timestamp\n", + "train_accuracy\n", + "/Device/Evaluate_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainGlobal_request_size\n", + "val_accuracy\n", + "/Device/GetBatteryStatus_request_size\n", + "_step\n", + "_timestamp\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_runtime\n", + "/Device/SetWeights_request_size\n", + "battery\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "train_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "train_global_time\n", + "/Device/TrainBatch_response_size\n", + "/Device/EvaluateBatch_response_size\n", + "/Device/TrainEpoch_request_size\n", + "client_train_epoch_time\n", + "client_train_epoch_time\n", + "/Device/GetBatteryStatus_response_size\n", + "battery\n", + "/Device/SetWeights_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/StartExperiment_response_size\n", + "client_evaluate_time\n", + "/Device/Evaluate_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/EndExperiment_request_size\n", + "_runtime\n", + "/Device/GetBatteryStatus_request_size\n", + "_step\n", + "_timestamp\n", + "/Device/SetWeights_request_size\n", + "/Device/TrainEpoch_request_size\n", + "/Device/SetWeights_request_size\n", + "/Device/SetWeights_response_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "client_evaluate_time\n", + "_runtime\n", + "/Device/TrainEpoch_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/Evaluate_response_size\n", + "_step\n", + "client_train_epoch_time\n", + "/Device/Evaluate_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "_timestamp\n", + "battery\n", + "/Device/StartExperiment_response_size\n", + "/Device/TrainEpoch_request_size\n", + "battery\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_request_size\n", + "/Device/TrainEpoch_response_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/Evaluate_request_size\n", + "/Device/EndExperiment_request_size\n", + "/Device/StartExperiment_response_size\n", + "_timestamp\n", + "client_train_epoch_time\n", + "/Device/Evaluate_response_size\n", + "_step\n", + "_runtime\n", + "/Device/SetWeights_response_size\n", + "/Device/SetWeights_request_size\n", + "client_train_epoch_time\n", + "_runtime\n", + "/Device/TrainEpoch_response_size\n", + "battery\n", + "/Device/SetWeights_request_size\n", + "/Device/Evaluate_request_size\n", + "/Device/StartExperiment_response_size\n", + "/Device/Evaluate_response_size\n", + "/Device/SetWeights_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + 
"_step\n", + "/Device/TrainEpoch_request_size\n", + "/Device/EndExperiment_request_size\n", + "_timestamp\n", + "client_evaluate_time\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/StartExperiment_response_size\n", + "client_train_epoch_time\n", + "train_global_time\n", + "/Device/EndExperiment_request_size\n", + "/Device/SetWeights_response_size\n", + "_timestamp\n", + "/Device/SetWeights_request_size\n", + "val_accuracy\n", + "/Device/EvaluateBatch_request_size\n", + "/Device/GetBatteryStatus_response_size\n", + "/Device/TrainBatch_response_size\n", + "/Device/TrainBatch_request_size\n", + "/Device/EvaluateBatch_response_size\n", + "train_accuracy\n", + "/Device/TrainGlobal_response_size\n", + "/Device/GetBatteryStatus_request_size\n", + "client_evaluate_time\n", + "/Device/TrainGlobal_request_size\n", + "_step\n", + "_runtime\n", + "battery\n" + ] + } + ], + "source": [ + "communication_sizes = {}\n", + "for name, group in run_groups.items(): \n", + " payload_size = 0\n", + " for run in group:\n", + " history_df = pd.DataFrame(run.scan_history())\n", + " for col in history_df.columns:\n", + " print(col)\n", + " if col.startswith(\"/Device/\"):\n", + " payload_size += history_df[col].sum()\n", + " communication_sizes[name] = payload_size" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:16.534589300Z", + "start_time": "2023-11-23T14:26:57.275349900Z" + } + }, + "id": "892c37ec9ffc62b2" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'fed_equal_batteries_only_flops': 480236160.0,\n 'fed_unequal_batteries_only_flops': 461026730.0,\n 'fed_equal_batteries_unlimited': 2401180740.0,\n 'split_equal_batteries_only_flops': 22342901919.0,\n 'split_unequal_batteries_only_flops': 28204101940.0,\n 'split_equal_batteries_unlimited': 159592077305.0,\n 'swarm_seq_equal_batteries_only_flops': 31965074916.0,\n 'swarm_seq_unequal_batteries_only_flops': 28928048314.0,\n 'swarm_seq_equal_batteries_unlimited': 159820568850.0,\n 'swarm_rand_equal_batteries_only_flops': 31528249111.0,\n 'swarm_rand_unequal_batteries_only_flops': 27544788384.0,\n 'swarm_rand_equal_batteries_unlimited': 159820492125.0,\n 'swarm_max_equal_batteries_only_flops': 31965076866.0,\n 'swarm_max_unequal_batteries_only_flops': 29812824528.0,\n 'swarm_max_equal_batteries_unlimited': 159823645350.0}" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "communication_sizes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:16.550081700Z", + "start_time": "2023-11-23T14:35:16.536073400Z" + } + }, + "id": "c8c73ee8c74bd181" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "{'fed_equal_batteries_only_flops': 0.48023616,\n 'fed_unequal_batteries_only_flops': 0.46102673,\n 'fed_equal_batteries_unlimited': 2.40118074,\n 'split_equal_batteries_only_flops': 22.342901919,\n 'split_unequal_batteries_only_flops': 28.20410194,\n 'split_equal_batteries_unlimited': 159.592077305,\n 'swarm_seq_equal_batteries_only_flops': 31.965074916,\n 'swarm_seq_unequal_batteries_only_flops': 28.928048314,\n 'swarm_seq_equal_batteries_unlimited': 159.82056885,\n 'swarm_rand_equal_batteries_only_flops': 31.528249111,\n 'swarm_rand_unequal_batteries_only_flops': 27.544788384,\n 'swarm_rand_equal_batteries_unlimited': 159.820492125,\n 'swarm_max_equal_batteries_only_flops': 31.965076866,\n 
'swarm_max_unequal_batteries_only_flops': 29.812824528,\n 'swarm_max_equal_batteries_unlimited': 159.82364535}" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "communication_sizes_gbyte = {name: value/1000000000 for name, value in communication_sizes.items()}\n", + "communication_sizes_gbyte" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:16.616603600Z", + "start_time": "2023-11-23T14:35:16.543080400Z" + } + }, + "id": "72f766f68e45126c" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "data": { + "text/plain": "{'equal': [0.48023616, 22.342901919, 31.965074916, 31.528249111, 31.965076866],\n 'heterogeneous': [0.46102673,\n 28.20410194,\n 28.928048314,\n 27.544788384,\n 29.812824528],\n 'unlimited': [2.40118074,\n 159.592077305,\n 159.82056885,\n 159.820492125,\n 159.82364535]}" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batteries = [\"equal\", \"heterogeneous\", \"unlimited\"]\n", + "groups = {\"equal\": [], \"heterogeneous\": [], \"unlimited\": []}\n", + "for idx, (method, size) in enumerate(communication_sizes_gbyte.items()):\n", + " groups[batteries[idx % 3]].append(size)\n", + "groups" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:16.617602800Z", + "start_time": "2023-11-23T14:35:16.569089200Z" + } + }, + "id": "dac45be91922334e" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "<Figure size 640x480 with 1 Axes>", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAosAAAHrCAYAAACn9tfQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAACbaUlEQVR4nOzdd1gUV9sG8HulLB2lF2n2hogldrGCqNiSYNSoqNHYo+hrN2CPJhqNRo0mgi1qEnsXRbBhFyXRWFEsIFYQUNqe7w8/Jq6wuKuLFO/fde2VzJkz5zw7s4vPnpkzIxNCCBARERER5aFUYQdAREREREUXk0UiIiIiUonJIhERERGpxGSRiIiIiFRiskhEREREKjFZJCIiIiKVmCwSERERkUpMFomIiIhIJSaLRERERKQSk0X66F28eBF9+/aFm5sbDAwMYGJigtq1a2Pu3Ll48uRJYYdXJNy6dQsymQyhoaEF0n5aWhqCg4MRERGRa11oaChkMhlu3bpVIH0XNzKZDMOGDSvsMJTIZDIEBwfnWyciIgIymQx//fXXhwnqLQr6M60OV1dXBAQESMv3799HcHAwoqOjc9UNCAiAiYnJhwuO6DW6hR0AUWFasWIFhgwZgsqVK+N///sfqlWrhszMTJw5cwbLli1DVFQUtmzZUthhFjp7e3tERUWhfPnyBdJ+Wloapk6dCgBo3ry50rr27dsjKioK9vb2BdI3UWHZsmULzMzMpOX79+9j6tSpcHV1Ra1atQovMKI3MFmkj1ZUVBQGDx6MNm3aYOvWrZDL5dK6Nm3aYPTo0di7d28hRlh0yOVyNGjQoFD6tra2hrW1daH0XZSkpaXByMiosMMgLXjx4gUMDQ3h6elZ2KEQqYWnoemjNWvWLMhkMixfvlwpUcyhr6+Pjh07SssKhQJz585FlSpVIJfLYWNjg969e+Pu3btK2zVv3hw1atRAVFQUGjVqBENDQ7i6uiIkJAQAsGvXLtSuXRtGRkZwd3fPlZAGBwdDJpPh4sWL+Pzzz2Fubg4LCwsEBgYiKysLV65cQdu2bWFqagpXV1fMnTtXaXtVp21zTgO+fqo3J9bTp0+jadOmMDIyQrly5fDdd99BoVBI9VSdsvv333/RvXt32NraQi6Xw9nZGb1790Z6ejoA4OHDhxgyZAiqVasGExMT2NjYoGXLljhy5IhS2znJ4NSpUyGTySCTyaTTc6rez8qVK+Hh4QEDAwNYWFigS5cuuHz5slKdnFN3169fR7t27WBiYgInJyeMHj1aijE/6hzzkSNHwtjYGMnJybm279atG2xtbZGZmSmVbdy4EQ0bNoSxsTFMTEzg4+OD8+fP5xl3TEwMvL29YWpqilatWinVWbNmDapWrQojIyN4eHhg586dufq/du0aevToARsbG8jlclStWhU///yzUp2XL19i9OjRqFWrlvRZa9iwIbZt25arveTkZAwYMACWlpYwMTFB27ZtcfXq1bfuxzf7CwwMhJ2dHQwNDeHl5aX0/tesWQOZTIaoqKhc206bNg16enq4f/++yvavX7+Ovn37omLFijAyMoKjoyP8/PwQExOjVnzbtm1DzZo1IZfLUa5cOSxcuFD6Tr75PiZMmAA3Nzfo6+vD0dERQ4cOxbNnz5Tqubq6okOHDti8eTM8PT1hYGAgjaK/fho6IiIC9erVAwD07dtX+h68eXr/bZ/lnO/q999/jzlz5sDV1RWGhoZo3rw5rl69iszMTIwfPx4ODg4wNzdHly5dkJiYqNa+oY+YIPoIZWVlCSMjI1G/fn21txk4cKAAIIYNGyb27t0rli1bJqytrYWTk5N4+PChVM/Ly0tYWlqKypUri99++03s27dPdOjQQQAQU6dOFe7u7mL9+vVi9+7dokGDBkIul4
t79+5J2wcFBQkAonLlymL69OkiLCxMjB07Vuq7SpUq4qeffhJhYWGib9++AoDYtGmTtH1ISIgAIGJjY5XiP3TokAAgDh06lCvWihUrimXLlomwsDAxZMgQAUCsWrVKqhcbGysAiJCQEKksOjpamJiYCFdXV7Fs2TJx8OBBsXbtWuHv7y+Sk5OFEEL8+++/YvDgwWLDhg0iIiJC7Ny5U/Tv31+UKlVKiuPly5di7969AoDo37+/iIqKElFRUeL69esq38+sWbMEANG9e3exa9cusXr1alGuXDlhbm4url69KtXr06eP0NfXF1WrVhU//PCDOHDggPj222+FTCYTU6dO1coxv3DhggAgVqxYobTt06dPhVwuF4GBgVLZzJkzhUwmE/369RM7d+4UmzdvFg0bNhTGxsbin3/+UYpbT09PuLq6itmzZ4uDBw+Kffv2CSGEACBcXV3FJ598Iv744w+xe/du0bx5c6Grqytu3LghtfHPP/8Ic3Nz4e7uLlavXi32798vRo8eLUqVKiWCg4Oles+ePRMBAQFizZo1Ijw8XOzdu1eMGTNGlCpVSukzoFAoRIsWLYRcLhczZ84U+/fvF0FBQaJcuXICgAgKCsp3X+Z8/pycnESnTp3Ejh07xNq1a0WFChWEmZmZFHt6erqws7MTPXv2VNo+MzNTODg4iM8//zzffiIjI8Xo0aPFX3/9JSIjI8WWLVtE586dhaGhofj333+lenl9pvfs2SNKlSolmjdvLrZs2SL+/PNPUb9+feHq6ipe/+dSoVAIHx8foaurK6ZMmSL2798vfvjhB2FsbCw8PT3Fy5cvpbouLi7C3t5elCtXTqxcuVIcOnRInDp1SlrXp08fIYQQSUlJ0md98uTJ0vfgzp07Qgj1P8s578vFxUX4+fmJnTt3irVr1wpbW1tRqVIl0atXL9GvXz+xZ88esWzZMmFiYiL8/Pzy3adETBbpo5SQkCAAiC+++EKt+pcvXxYAxJAhQ5TKT548KQCIiRMnSmVeXl4CgDhz5oxU9vjxY6GjoyMMDQ2VEsPo6GgBQPz0009SWU6yOG/ePKW+atWqJQCIzZs3S2WZmZnC2tpadO3aVSrTNFkEIE6ePKlUt1q1asLHx0dazusf1pYtW4rSpUuLxMTEvHZZnrKyskRmZqZo1aqV6NKli1T+8OFDlQnHm+/n6dOnwtDQULRr106pXlxcnJDL5aJHjx5SWZ8+fQQA8ccffyjVbdeunahcuXK+sWpyzGvXri0aNWqkVG/JkiUCgIiJiZHi09XVFcOHD1eq9/z5c2FnZyf8/f1zxb1y5cpccQEQtra2UkIuxKvPc6lSpcTs2bOlMh8fH1G2bFmRlJSktP2wYcOEgYGBePLkSZ7vO+cY9e/fX3h6ekrle/bsEQDEwoULlerPnDlTo2Sxdu3aQqFQSOW3bt0Senp64quvvpLKgoKChL6+vnjw4IFUtnHjRgFAREZG5ttPXu8nIyNDVKxYUYwaNUoqz+szXa9ePeHk5CTS09OlsufPnwtLS0ulZDHnx83cuXOV+sqJcfny5VKZi4uL0NHREVeuXMkV2+vJohBCnD59OldMOdT9LOe8Lw8PD5GdnS2VL1iwQAAQHTt2VNp+5MiRAkCuzwnR63gamkgNhw4dAgClmYsA8Mknn6Bq1ao4ePCgUrm9vT3q1KkjLVtYWMDGxga1atWCg4ODVF61alUAwO3bt3P12aFDB6XlqlWrQiaTwdfXVyrT1dVFhQoV8txeXXZ2dvjkk0+UymrWrJlvm2lpaYiMjIS/v/9brydctmwZateuDQMDA+jq6kJPTw8HDx7MdcpYXVFRUXjx4kWuY+Hk5ISWLVvmOhYymQx+fn5KZW97f4Bmx7xv3744fvw4rly5IpWFhISgXr16qFGjBgBg3759yMrKQu/evZGVlSW9DAwM4OXlledM8E8//TTP2Fq0aAFTU1Np2dbWFjY2NtJ7evnyJQ4ePIguXbrAyMhIqb927drh5cuXOHHihLT9n3/+icaNG8PExEQ6Rr/99pvSMcrZHz179lSKpUePHir3YV569OihdErXxcUFjRo1ktoHgMGDBwN4NQEtx+LFi+Hu7o5mzZrl235WVhZmzZqFatWqQV9fH7q6utDX18e1a9fy/cylpqbizJkz6Ny5M/T19aVyExOTXJ+f8PBwALk/G59//jmMjY1zfQZr1qyJSpUq5Ru3OjT5LLdr1w6lSv33T3zO35r27dsr1cspj4uLe+/4qORiskgfJSsrKxgZGSE2Nlat+o8fPwaAPGfkOjg4SOtzWFhY5Kqnr6+fqzznH6WXL1/mqp9XXSMjIxgYGOQqz2t7dVlaWuYqk8vlePHihcptnj59iuzsbJQtWzbftufPn4/Bgwejfv362LRpE06cOIHTp0+jbdu2+bafH02PRV77TC6Xv3WfadJPz549IZfLpWs6L126hNOnT6Nv375SnQcPHgAA6tWrBz09PaXXxo0b8ejRo1xxvz5T9nVvO2aPHz9GVlYWFi1alKuvdu3aAYDU3+bNm+Hv7w9HR0esXbsWUVFROH36NPr166e0jx4/fgxdXd1cfdvZ2eUZoyp51bezs1Pan7a2tujWrRt++eUXZGdn4+LFizhy5IhatwwKDAzElClT0LlzZ+zYsQMnT57E6dOn4eHh8dbPtBACtra2uda9WZazL978oSSTyXK9FyDvz9C70OSzrOpvjSZ/g4hycDY0fZR0dHTQqlUr7NmzB3fv3n1r0pPzD2R8fHyuuvfv34eVlVWBxaqpnH9M3pzA8WYy8j4sLCygo6OTa3LPm9auXYvmzZtj6dKlSuXPnz9/575fPxZv0uax0OSYlylTBp06dcLq1asxY8YMhISEwMDAAN27d5fq5NT/66+/4OLi8tb+35xQoYkyZcpAR0cHvXr1wtChQ/Os4+bmBuDVMXJzc8PGjRuV+nzz82NpaYmsrCw8fvxYKWFMSEjQKLa86ickJORKQr/55husWbMG27Ztw969e1G6dOlco5p5Wbt2LXr37o1Zs2YplT969AilS5dWuV2ZMmUgk8mkpD6/mHP2xcOHD5USRiEEEhISpIkqOd7nWBIVBRxZpI/WhAkTIITAgAEDkJGRkWt9ZmYmduzYAQBo2bIlgFf/EL3u9OnTuHz5cq6ZqoXJ1dUVwKubjb9u+/btWusjZxbrn3/+mW8SKpPJcs00v3jxYq6Zrjl11BltbNiwIQwNDXMdi7t37yI8PFxrx0LTY963b1/cv38fu3fvxtq1a9GlSxel5MTHxwe6urq4ceMG6tatm+dLW4yMjNCiRQucP38eNWvWzLOvnORMJpNBX19fKaFJSEjINRu6RYsWAIB169Yplf/+++8axbZ+/XoIIaTl27dv4/jx47nur1mnTh00atQIc+bMwbp16xAQEABjY+O3tp/XZ27Xrl24d+9evtsZGxujbt262Lp1q9Lfg5SUlFwzzXOO/ZufjU2bNiE1NfWdP4OafA+IPiSOLNJHq2HDhli6dCmGDBmCOnXqYPDgwahevToyMzNx/vx5LF++HDVq1
ICfnx8qV66MgQMHYtGiRShVqhR8fX1x69YtTJkyBU5OThg1alRhvx1JvXr1ULlyZYwZMwZZWVkoU6YMtmzZgqNHj2q1n/nz56NJkyaoX78+xo8fjwoVKuDBgwfYvn07fvnlF5iamqJDhw6YPn06goKC4OXlhStXrmDatGlwc3NDVlaW1JapqSlcXFywbds2tGrVChYWFrCyspIS39eVLl0aU6ZMwcSJE9G7d290794djx8/xtSpU2FgYICgoCCtvD9Nj7m3tzfKli2LIUOGICEhQekUNPAqiZ82bRomTZqEmzdvom3btihTpgwePHiAU6dOwdjYWLqlijYsXLgQTZo0QdOmTTF48GC4urri+fPnuH79Onbs2CFdd5dzW5chQ4bgs88+w507dzB9+nTY29vj2rVrSu+vWbNmGDt2LFJTU1G3bl0cO3YMa9as0SiuxMREdOnSBQMGDEBSUhKCgoJgYGCACRMm5Kr7zTffoFu3bpDJZBgyZIha7Xfo0AGhoaGoUqUKatasibNnz+L7779/69kD4NWtedq3bw8fHx988803yM7Oxvfffw8TExOlpzm1adMGPj4+GDduHJKTk9G4cWNcvHgRQUFB8PT0RK9evdTfIa8pX748DA0NsW7dOlStWhUmJiZwcHBQus6ZqFAU7vwaosIXHR0t+vTpI5ydnYW+vr50+4tvv/1WaaZvdna2mDNnjqhUqZLQ09MTVlZW4ssvv5RubZHDy8tLVK9ePVc/Li4uon379rnKAYihQ4dKyzmzoV+/HY8Qr2ZDGhsb59o+r/6uXr0qvL29hZmZmbC2thbDhw8Xu3btynM2dF6x9unTR7i4uEjLec0cFUKIS5cuic8//1xYWloKfX194ezsLAICAqRbh6Snp4sxY8YIR0dHYWBgIGrXri22bt2aq30hhDhw4IDw9PQUcrlcAJBmiaqa3f3rr7+KmjVrCn19fWFubi46deqkdPuZ/PZZzj5+G3WPeY6JEydKt4d5fSbq67Zu3SpatGghzMzMhFwuFy4uLuKzzz4TBw4ceGvcQuT+vOR4c2atEK+OW79+/YSjo6PQ09MT1tbWolGjRmLGjBlK9b777jvh6uoq5HK5qFq1qlixYkWe++jZs2eiX79+onTp0sLIyEi0adNG/PvvvxrNhl6zZo0YMWKEsLa2FnK5XDRt2lTpzgGvS09PF3K5XLRt2zbftl/39OlT0b9/f2FjYyOMjIxEkyZNxJEjR4SXl5fw8vJS2jd5faa3bNki3N3dpc/zd999J0aMGCHKlCmjVO/Fixdi3LhxwsXFRejp6Ql7e3sxePBg8fTpU6V6qr73OevePGbr168XVapUEXp6ekr7Vd3Pcs77+v7775Xq5ez/P//8U6k85/t1+vTpPGMkEkIImRCvnQ8gIiIqInbs2IGOHTti165d0sScDy0zMxO1atWCo6Mj9u/fXygxEBU2JotERFSkXLp0Cbdv38Y333wDY2NjnDt37oNNEunfvz/atGkDe3t7JCQkYNmyZYiMjMT+/fvRunXrDxIDUVHDaxaJiKhIGTJkCI4dO4batWtj1apVH3Q28fPnzzFmzBg8fPgQenp6qF27Nnbv3s1EkT5qHFkkIiIiIpUK9dY5s2fPRr169WBqagobGxt07txZ6QkIwKv7VgUHB8PBwUF6GPo///yjVCc9PR3Dhw+HlZUVjI2N0bFjx7fe/42IiIiI3q5Qk8XIyEgMHToUJ06cQFhYGLKysuDt7Y3U1FSpzty5czF//nwsXrwYp0+fhp2dHdq0aaN0U9+RI0diy5Yt2LBhA44ePYqUlBR06NAB2dnZhfG2iIiIiEqMInUa+uHDh7CxsUFkZCSaNWsGIQQcHBwwcuRIjBs3DsCrUURbW1vMmTMHX3/9NZKSkmBtbY01a9agW7duAF49XcHJyQm7d++Gj49PYb4lIiIiomKtSE1wSUpKAvDfsytjY2ORkJAAb29vqY5cLoeXlxeOHz+Or7/+GmfPnkVmZqZSHQcHB9SoUQPHjx/PM1lMT09XepSVQqHAkydPYGlpyccyERER0UdBCIHnz5/DwcEBpUqpPtlcZJJFIQQCAwPRpEkT1KhRA8B/z+N88yHutra2uH37tlRHX18fZcqUyVVH1TNLZ8+erdUnJRAREREVV3fu3Mn3KUdFJlkcNmwYLl68mOcjyd4c7RNCvHUEML86EyZMQGBgoLSclJQEZ2dn3LlzB2ZmZu8QPREREVHxkpycDCcnJ5iamuZbr0gki8OHD8f27dtx+PBhpczWzs4OwKvRQ3t7e6k8MTFRGm20s7NDRkYGnj59qjS6mJiYiEaNGuXZn1wuz/WgeQAwMzNjskhEREQflbcNwBXqbGghBIYNG4bNmzcjPDwcbm5uSuvd3NxgZ2eHsLAwqSwjIwORkZFSIlinTh3o6ekp1YmPj8fff/+tMlkkIiIiIvUU6sji0KFD8fvvv2Pbtm0wNTWVrjE0NzeHoaEhZDIZRo4ciVmzZqFixYqoWLEiZs2aBSMjI/To0UOq279/f4wePRqWlpawsLDAmDFj4O7uzjvuExEREb2nQk0Wly5dCgBo3ry5UnlISAgCAgIAAGPHjsWLFy8wZMgQPH36FPXr18f+/fuVzq//+OOP0NXVhb+/P168eIFWrVohNDQUOjo6H+qtEBEREZVIReo+i4UlOTkZ5ubmSEpKUnnNohACWVlZvNE3ERFBR0cHurq6vN0aFWvq5D9AEZngUtRlZGQgPj4eaWlphR0KEREVEUZGRrC3t4e+vn5hh0JUoJgsvoVCoUBsbCx0dHTg4OAAfX19/pIkIvqICSGQkZGBhw8fIjY2FhUrVsz3hsZExR2TxbfIyMiAQqGAk5MTjIyMCjscIiIqAgwNDaGnp4fbt28jIyMDBgYGhR0SUYHhTyE18VcjERG9jv8u0MeCn3QiIiIiUonJIhERERGpxGsW35Hr+F0ftL9b37X/oP1pk0wmw5YtW9C5c+cP33mw+QfsK0njTZo3b45atWphwYIF2o+HihX3Ve4frK+YPjEfrK8cERERaNGiBZ4+fYrSpUsjNDQUI0eOxLNnz96r3YL6+3Lr1i24ubnh/PnzqFWrllbbJipuOLJIVIwFBwfzHzIqlrp164arV6++dzvx8fHw9fUF8CrBk8lkiI6Ofu92ieg/HFkkImRmZkJPT6+ww6CPiKGhIQwNDd+7HTs7Oy1EQ0T54chiCSaEwNy5c1GuXDkYGhrCw8MDf/31l7R+9+7dqFSpEgwNDdGiRQuEhoZCJpNJp4XyGrVasGABXF1dpeXTp0+jTZs2sLKygrm5Oby8vHDu3LkP8O5KDoVCgbFjx8LCwgJ2dnYIDg6W1iUlJWHgwIGwsbGBmZkZWrZsiQsXLgAAQkNDMXXqVFy4cAEymQwymQyhoaFv3Q7479iuXLkS5cqVg1wuhxACcXFx6NSpE0xMTGBmZgZ/f388ePBAKd4ZM2bAxsYGpqam+OqrrzB+/Phcn5OQkBBUrVoVBgYGqFKlCpYsWSKtyxn92bx5M1q0aAEjIyN4eHggKipKqY3j
x4+jWbNmMDQ0hJOTE0aMGIHU1FRp/dOnT9G7d2+UKVMGRkZG8PX1xbVr13K9x9e9+fmNiIjAJ598AmNjY5QuXRqNGzfG7du333rMPkaurq65LpeoVauW9HmVyWT49ddf0aVLFxgZGaFixYrYvn27yvZCQ0NRunRpafn1z6SzszNMTEwwePBgZGdnY+7cubCzs4ONjQ1mzpyp1I5MJsPWrVsBAG5ubgAAT09PyGQypUfJ5veZBIBTp07B09MTBgYGqFu3Ls6fP6/ZDiIqwZgslmCTJ09GSEgIli5din/++QejRo3Cl19+icjISNy5cwddu3ZFu3btEB0dLf2jr6nnz5+jT58+OHLkCE6cOIGKFSuiXbt2eP78eQG8o5Jp1apVMDY2xsmTJzF37lxMmzYNYWFhEEKgffv2SEhIwO7du3H27FnUrl0brVq1wpMnT9CtWzeMHj0a1atXR3x8POLj49GtW7e3bpfj+vXr+OOPP7Bp0ybptF3nzp3x5MkTREZGIiwsDDdu3EC3bt2kbdatW4eZM2dizpw5OHv2LJydnaVnvOdYsWIFJk2ahJkzZ+Ly5cuYNWsWpkyZglWrVinVmzRpEsaMGYPo6GhUqlQJ3bt3R1ZWFgAgJiYGPj4+6Nq1Ky5evIiNGzfi6NGjGDZsmLR9QEAAzpw5g+3btyMqKgpCCLRr1w6ZmZlq7fesrCx07twZXl5euHjxIqKiojBw4EDedP89TJ06Ff7+/rh48SLatWuHnj17Kn3m3ubGjRvYs2cP9u7di/Xr12PlypVo37497t69i8jISMyZMweTJ0/GiRMn8tz+1KlTAIADBw4gPj4emzdvBvD2z2Rqaio6dOiAypUr4+zZswgODsaYMWPec28QlRw8DV1CpaamYv78+QgPD0fDhg0BAOXKlcPRo0fxyy+/wNXVFeXKlcOPP/4ImUyGypUrIyYmBnPmzNGon5YtWyot//LLLyhTpgwiIyPRoUMHrb2fkqxmzZoICgoCAFSsWBGLFy/GwYMHoaOjg5iYGCQmJkIulwMAfvjhB2zduhV//fUXBg4cCBMTE+jq6iqdigsPD3/rdsCrG86vWbMG1tbWAICwsDBcvHgRsbGxcHJyAgCsWbMG1atXx+nTp1GvXj0sWrQI/fv3R9++fQEA3377Lfbv34+UlBSp/+nTp2PevHno2rUrgFejPZcuXcIvv/yCPn36SPXGjBmD9u1fTdyaOnUqqlevjuvXr6NKlSr4/vvv0aNHD4wcOVLaLz/99BO8vLywdOlS3LlzB9u3b8exY8fQqFEjAK8SWScnJ2zduhWff/75W/d7cnIykpKS0KFDB5QvXx4AULVqVbWPG+UWEBCA7t27AwBmzZqFRYsW4dSpU2jbtq1a2ysUCqxcuRKmpqaoVq0aWrRogStXrmD37t0oVaoUKleujDlz5iAiIgINGjTItX3OZ9nS0lLpO/G2z+S6deuQnZ2NlStXwsjICNWrV8fdu3cxePDg990lRCUCk8US6tKlS3j58iXatGmjVJ6RkQFPT0+8ePECDRo0UBpFyUkqNZGYmIhvv/0W4eHhePDgAbKzs5GWloa4uLj3fg8fi5o1ayot29vbIzExEWfPnkVKSgosLS2V1r948QI3btxQ2Z6627m4uEj/uALA5cuX4eTkJCWKAFCtWjWULl0aly9fRr169XDlyhUMGTJEqd1PPvkE4eHhAICHDx/izp076N+/PwYMGCDVycrKgrm58sz019+3vb09gFefpypVquDs2bO4fv061q1bJ9URQkiP37x27Rp0dXVRv359ab2lpSUqV66My5cvq9w3r7OwsEBAQAB8fHzQpk0btG7dGv7+/lIspLnXj6mxsTFMTU2RmJio9vaurq4wNTWVlm1tbaGjo6N082tbW1uN2lTnM3n58mV4eHgoPaXrXf4eEpVUTBZLKIVCAQDYtWsXHB0dldbJ5XIMHz78rW2UKlUKQgilsjdP8QUEBODhw4dYsGABXFxcIJfL0bBhQ2RkZLznO/h4vDmxRCaTQaFQQKFQwN7eHhEREbm2ef1arzepu52xsbHSOiFEnqdg3yx/s87rn5Gcz92KFSuUEjkA0NHRUVp+/X3ntJmzvUKhwNdff40RI0bkisfZ2VnlLNrXY1Xn8xsSEoIRI0Zg79692LhxIyZPnoywsLA8R60+dursT1WfZXXltf37tqnOZ/LN90VEypgsllDVqlWDXC5HXFwcvLy88lyfc1F4jjevA7K2tkZCQoLSP8Bv3pLiyJEjWLJkCdq1awcAuHPnDh49eqS9N/IRq127NhISEqCrq6s0KeN1+vr6yM7O1ni7vFSrVg1xcXG4c+eONLp46dIlJCUlSadnK1eujFOnTqFXr17SdmfOnJH+39bWFo6Ojrh58yZ69uypdt9vql27Nv755x9UqFBBZaxZWVk4efKkdBr68ePHuHr1qhSrOp9f4NVkCE9PT0yYMAENGzbE77//zmQxD9bW1oiPj5eWk5OTERsbW4gR5aavrw8ASt8JdT6T1apVw5o1a/DixQtphraq6yKJPkac4FJCmZqaYsyYMRg1ahRWrVqFGzdu4Pz58/j555+xatUqDBo0CDdu3EBgYCCuXLmC33//XZpJm6N58+Z4+PAh5s6dixs3buDnn3/Gnj17lOpUqFABa9asweXLl3Hy5En07NlTK7fDIKB169Zo2LAhOnfujH379uHWrVs4fvw4Jk+eLCVorq6uiI2NRXR0NB49eoT09HS1tlPVX82aNdGzZ0+cO3cOp06dQu/eveHl5YW6desCAIYPH47ffvsNq1atwrVr1zBjxgxcvHhRabQxODgYs2fPxsKFC3H16lXExMQgJCQE8+fPV/u9jxs3DlFRURg6dCiio6Nx7do1bN++XRoRr1ixIjp16oQBAwbg6NGjuHDhAr788ks4OjqiU6dOAN7++Y2NjcWECRMQFRWF27dvY//+/UrJJilr2bIl1qxZgyNHjuDvv/9Gnz59co0WFzYbGxsYGhpi7969ePDgAZKSXt0o/22fyR49eqBUqVLo378/Ll26hN27d+OHH34ozLdCVLQIEklJSQKASEpKyrXuxYsX4tKlS+LFixeFENn7USgUYuHChaJy5cpCT09PWFtbCx8fHxEZGSmEEGLHjh2iQoUKQi6Xi6ZNm4qVK1cKAOLp06dSG0uXLhVOTk7C2NhY9O7dW8ycOVO4uLhI68+dOyfq1q0r5HK5qFixovjzzz+Fi4uL+PHHH6U6AMSWLVs+zJsuZry8vMQ333yjVNapUyfRp08fIYQQycnJYvjw4cLBwUHo6ekJJycn0bNnTxEXFyeEEOLly5fi008/FaVLlxYAREhIiFrbBQUFCQ8Pj1zx3L59W3Ts2FEYGxsLU1NT8fnnn4uEhASlOtOmTRNWVlbCxMRE9OvXT4wYMUI0aNBAqc66detErVq1hL6+vihTpoxo1qyZ2Lx5sxBCiNjYWAFAnD9/Xqr/9OlTAUAcOnRIKjt16pRo06aNMDExEcbGxqJmzZpi5syZ0vonT56IXr16CXNzc2FoaCh8fHz
[base64-encoded PNG output omitted: grouped bar chart "Communication overhead by algorithm"; x: Learning algorithm (Federated, Split, Swarm_seq, Swarm_rand, Swarm_max); y: Sent + Received Data [GBytes]; produced by the plotting cell below]
Hj160IwZM0RlevXqRV26dBG2ra2tafXq1aI83bp1o5CQECLSPC+vtxUeHi6qt1Zjc5Genk4A6NSpU0KeiIgIAkBFRUVC2vTp00mhUGiYCcZYc8BnFhljzVZ2djbS09Mhl8uFh5OTEwAIl5qLioowfvx4dOjQAcbGxmjfvj0AoKSkRFRX165d69TfqVMn6OrqCttWVlYoKytrsE9ubm7C34aGhjAyMhLKqFQqdOvWTZS/e/fujY4zPT0dfn5+aNu2LYyMjDBx4kQ8fPgQlZWV9eZXqVR16n11+8mTJ7h//z569eolytOrVy8olUpRWn3z0lQNzUV9eSwtLWFgYIAOHTqI0hqbc8aYdvENLoyxZkutVmPo0KGIjIyss8/KygoAMHToUNjY2GDbtm2wtraGWq2Gq6srXrx4IcpvaGhYp47Xb3KRSCR1Ll//mTJEBIlEItpPRA3Wd/fuXQwePBjBwcFYuXIlzMzM8OOPP2LKlCmiz0++rint1Jfn9bT65qWpmjJ/r+aRSCT/1ZwzxrSLzywyxpotT09P3LhxA3Z2drC3txc9DA0N8fDhQyiVSoSGhqJ///5wdnbG48ePtdZfJycnXL58WZSWlZXVYJmsrCxUV1dj/fr18PLygoODA+7fv99gGUdHR2RmZmpsx9jYGNbW1vjxxx9Fec6fPw9nZ+emDEWgp6eHmpqaP1WGMdaycLDIGNO6//znP8jJyRE9SkpKMGPGDDx69Ajjxo1DZmYmbt++jRMnTmDy5MmoqalB69atYW5ujm+//RaFhYX4/vvvMW/ePK2NY/r06bh58yYWLVqEgoIC7Nu3T7hh5vUzerX+8Y9/oLq6Gps2bcLt27eRmJiImJiYBtuZOXMmYmNjkZCQgFu3bmHVqlXIy8sTtbFw4UJERkZi7969UKlUWLx4MXJycjB79uw/NSY7OzsUFxcjJycH5eXlqKqq+lPlGWNvPw4WGWNal5GRAQ8PD9EjLCwM1tbWOHfuHGpqaqBQKODq6orZs2fDxMQEOjo60NHRwZ49e5CdnQ1XV1fMnTsXUVFRWhtH+/bt8d133+HgwYNwc3PD1q1bhbuh9fX16y3j7u6ODRs2IDIyEq6urkhOTkZERESD7QQGBmLJkiVYsGABPD09UVxcjKCgIEilUiHPrFmzMH/+fMyfPx+dO3dGWlqa8LU+f8aoUaMwaNAg9O3bFxYWFti9e/efKs8Ye/tJqLEP1DDGGPuvrV69GjExMbh3797f2o6fnx8++OADJCYm/q3tMMbePXyDC2OM/YW2bNmCbt26wdzcHOfOnUNUVBQ+++yzv7SNZ8+eISYmBgqFArq6uti9ezdOnTqFkydP/qXtMMYYwMEiY4z9pWo/Q/jo0SO0a9cO8+fPx5IlS/7SNiQSCY4ePYpVq1ahqqoKjo6OOHDgAAYMGPCXtsMYYwBfhmaMMcYYYw3gG1wYY4wxxphGHCwyxhhjjDGNOFhkjDHGGGMacbDIGGOMMcY04mCRMcYYY4xpxMEiY4wxxhjTiINFxhhjjDGmEQeLjDHGGGNMo/8DUvK9ajOL0fkAAAAASUVORK5CYII=" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# adapted from https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html#sphx-glr-gallery-lines-bars-and-markers-barchart-py\n", + "import matplotlib.pyplot as plt\n", + "\n", + "strategies=[\"Federated\", \"Split\", \"Swarm_seq\", \"Swarm_rand\", \"Swarm_max\"]\n", + "\n", + "x = np.arange(len(strategies)) # the label locations\n", + "width = 0.25 # the width of the bars\n", + "multiplier = 0\n", + "\n", + "fig, ax = plt.subplots(layout='constrained')\n", + "\n", + "for attribute, measurement in groups.items():\n", + " offset = width * multiplier\n", + " rects = ax.bar(x + offset, measurement, width, label=attribute)\n", + " ax.bar_label(rects, padding=3)\n", + " multiplier += 1\n", + "\n", + "ax.set_xlabel('Learning algorithm')\n", + "ax.set_ylabel('Sent + Received Data [GBytes]')\n", + "ax.set_title('Communication overhead by algorithm')\n", + "ax.set_xticks(x + width, strategies)\n", + "ax.legend(loc='upper left', ncols=3)\n", + "ax.set_ylim(0, 200)\n", + "\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:17.319567800Z", + "start_time": "2023-11-23T14:35:16.590542600Z" + } + }, + "id": "9567c2929b5cf534" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "data": { + "text/plain": "dict_keys(['fed_equal_batteries_only_flops', 'fed_unequal_batteries_only_flops', 'fed_equal_batteries_unlimited', 'split_equal_batteries_only_flops', 'split_unequal_batteries_only_flops', 'split_equal_batteries_unlimited', 'swarm_seq_equal_batteries_only_flops', 'swarm_seq_unequal_batteries_only_flops', 'swarm_seq_equal_batteries_unlimited', 'swarm_rand_equal_batteries_only_flops', 'swarm_rand_unequal_batteries_only_flops', 'swarm_rand_equal_batteries_unlimited', 'swarm_max_equal_batteries_only_flops', 'swarm_max_unequal_batteries_only_flops', 
'swarm_max_equal_batteries_unlimited'])" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "run_groups.keys()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:17.343644800Z", + "start_time": "2023-11-23T14:35:17.326092400Z" + } + }, + "id": "2be034350770ae3b" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "def dpm(method):\n", + " data_per_method = {}\n", + " for run in run_groups[method]:\n", + " history_df = pd.DataFrame(run.scan_history())\n", + " for col in history_df.columns:\n", + " if col.startswith(\"/Device/\"):\n", + " if col in data_per_method.keys():\n", + " data_per_method[col] += history_df[col].sum()\n", + " else:\n", + " data_per_method[col] = history_df[col].sum()\n", + " return data_per_method\n", + "\n", + "def dpd(method):\n", + " data_per_device = {}\n", + " for run in run_groups[method]:\n", + " history_df = pd.DataFrame(run.scan_history())\n", + " for col in history_df.columns:\n", + " if col.startswith(\"/Device/\"):\n", + " if col in data_per_device.keys():\n", + " data_per_device[col].append(history_df[col].sum())\n", + " else:\n", + " data_per_device[col] = [history_df[col].sum()]\n", + " return data_per_device" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:19.128896700Z", + "start_time": "2023-11-23T14:35:17.338127200Z" + } + }, + "id": "5d4467da7723b00b" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": " 0 1 \\\n/Device/GetBatteryStatus_response_size 7.500000e+02 7.000000e+02 \n/Device/TrainEpoch_response_size 6.137760e+05 6.137760e+05 \n/Device/Evaluate_response_size 4.000000e+01 4.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 7.082035e+08 \n/Device/TrainBatch_response_size 2.831694e+09 2.831694e+09 \n/Device/EvaluateBatch_response_size 1.520000e+03 1.520000e+03 \n/Device/SetWeights_request_size 1.006521e+07 1.006524e+07 \n/Device/EndExperiment_request_size 5.000000e+00 NaN \n/Device/StartExperiment_response_size 5.000000e+00 5.000000e+00 \n/Device/SetWeights_response_size 5.000000e+01 6.000000e+01 \n/Device/GetBatteryStatus_request_size 1.500000e+02 1.400000e+02 \n/Device/TrainEpoch_request_size 2.240000e+02 2.240000e+02 \n/Device/TrainGlobal_request_size 1.400000e+01 2.100000e+01 \n/Device/TrainGlobal_response_size 9.604898e+06 9.604890e+06 \n/Device/Evaluate_request_size 2.400000e+02 2.400000e+02 \n\n 2 3 \\\n/Device/GetBatteryStatus_response_size 6.500000e+02 6.000000e+02 \n/Device/TrainEpoch_response_size 6.137760e+05 6.137760e+05 \n/Device/Evaluate_response_size 4.000000e+01 4.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 7.082035e+08 \n/Device/TrainBatch_response_size 2.831694e+09 2.831694e+09 \n/Device/EvaluateBatch_response_size 1.520000e+03 1.520000e+03 \n/Device/SetWeights_request_size 1.006524e+07 1.006524e+07 \n/Device/EndExperiment_request_size NaN NaN \n/Device/StartExperiment_response_size 5.000000e+00 5.000000e+00 \n/Device/SetWeights_response_size 6.000000e+01 6.000000e+01 \n/Device/GetBatteryStatus_request_size 1.300000e+02 1.200000e+02 \n/Device/TrainEpoch_request_size 2.240000e+02 2.240000e+02 \n/Device/TrainGlobal_request_size 2.100000e+01 2.100000e+01 \n/Device/TrainGlobal_response_size 9.604890e+06 9.604890e+06 \n/Device/Evaluate_request_size 
2.400000e+02 2.400000e+02 \n\n 4 \n/Device/GetBatteryStatus_response_size 5.500000e+02 \n/Device/TrainEpoch_response_size 6.137760e+05 \n/Device/Evaluate_response_size 4.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 \n/Device/TrainBatch_response_size 2.831694e+09 \n/Device/EvaluateBatch_response_size 1.520000e+03 \n/Device/SetWeights_request_size 1.021859e+07 \n/Device/EndExperiment_request_size NaN \n/Device/StartExperiment_response_size 5.000000e+00 \n/Device/SetWeights_response_size 7.000000e+01 \n/Device/GetBatteryStatus_request_size 1.100000e+02 \n/Device/TrainEpoch_request_size 2.240000e+02 \n/Device/TrainGlobal_request_size 2.100000e+01 \n/Device/TrainGlobal_response_size 9.604890e+06 \n/Device/Evaluate_request_size 2.400000e+02 ", + "text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>0</th>\n <th>1</th>\n <th>2</th>\n <th>3</th>\n <th>4</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>/Device/GetBatteryStatus_response_size</th>\n <td>7.500000e+02</td>\n <td>7.000000e+02</td>\n <td>6.500000e+02</td>\n <td>6.000000e+02</td>\n <td>5.500000e+02</td>\n </tr>\n <tr>\n <th>/Device/TrainEpoch_response_size</th>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n </tr>\n <tr>\n <th>/Device/Evaluate_response_size</th>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n </tr>\n <tr>\n <th>/Device/TrainBatch_request_size</th>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n </tr>\n <tr>\n <th>/Device/EvaluateBatch_request_size</th>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n </tr>\n <tr>\n <th>/Device/TrainBatch_response_size</th>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n </tr>\n <tr>\n <th>/Device/EvaluateBatch_response_size</th>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n </tr>\n <tr>\n <th>/Device/SetWeights_request_size</th>\n <td>1.006521e+07</td>\n <td>1.006524e+07</td>\n <td>1.006524e+07</td>\n <td>1.006524e+07</td>\n <td>1.021859e+07</td>\n </tr>\n <tr>\n <th>/Device/EndExperiment_request_size</th>\n <td>5.000000e+00</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>NaN</td>\n </tr>\n <tr>\n <th>/Device/StartExperiment_response_size</th>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n </tr>\n <tr>\n <th>/Device/SetWeights_response_size</th>\n <td>5.000000e+01</td>\n <td>6.000000e+01</td>\n <td>6.000000e+01</td>\n <td>6.000000e+01</td>\n <td>7.000000e+01</td>\n </tr>\n <tr>\n <th>/Device/GetBatteryStatus_request_size</th>\n <td>1.500000e+02</td>\n <td>1.400000e+02</td>\n <td>1.300000e+02</td>\n <td>1.200000e+02</td>\n <td>1.100000e+02</td>\n </tr>\n <tr>\n <th>/Device/TrainEpoch_request_size</th>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n </tr>\n <tr>\n 
<th>/Device/TrainGlobal_request_size</th>\n <td>1.400000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n </tr>\n <tr>\n <th>/Device/TrainGlobal_response_size</th>\n <td>9.604898e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n </tr>\n <tr>\n <th>/Device/Evaluate_request_size</th>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n </tr>\n </tbody>\n</table>\n</div>" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame.from_dict(dpd('swarm_max_equal_batteries_only_flops'), orient='index')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:35:34.476949800Z", + "start_time": "2023-11-23T14:35:19.038455300Z" + } + }, + "id": "81fc596bb897ca5c" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [ + { + "data": { + "text/plain": " 0 1 \\\n/Device/StartExperiment_response_size 5.000000e+00 5.000000e+00 \n/Device/EvaluateBatch_response_size 1.520000e+03 1.520000e+03 \n/Device/TrainEpoch_request_size 2.240000e+02 2.240000e+02 \n/Device/TrainGlobal_response_size 9.604898e+06 9.604890e+06 \n/Device/TrainGlobal_request_size 1.400000e+01 2.100000e+01 \n/Device/SetWeights_response_size 5.000000e+01 6.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 7.082035e+08 \n/Device/TrainEpoch_response_size 6.137760e+05 6.137760e+05 \n/Device/Evaluate_request_size 2.400000e+02 2.400000e+02 \n/Device/SetWeights_request_size 1.006521e+07 1.006524e+07 \n/Device/Evaluate_response_size 4.000000e+01 4.000000e+01 \n/Device/TrainBatch_response_size 2.831694e+09 2.831694e+09 \n/Device/GetBatteryStatus_request_size 7.500000e+01 7.000000e+01 \n/Device/EndExperiment_request_size 5.000000e+00 NaN \n/Device/GetBatteryStatus_response_size 3.750000e+02 3.500000e+02 \n\n 2 3 \\\n/Device/StartExperiment_response_size 5.000000e+00 5.000000e+00 \n/Device/EvaluateBatch_response_size 1.520000e+03 1.520000e+03 \n/Device/TrainEpoch_request_size 2.240000e+02 2.240000e+02 \n/Device/TrainGlobal_response_size 9.604890e+06 9.604890e+06 \n/Device/TrainGlobal_request_size 2.100000e+01 2.100000e+01 \n/Device/SetWeights_response_size 6.000000e+01 6.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 7.082035e+08 \n/Device/TrainEpoch_response_size 6.137760e+05 6.137760e+05 \n/Device/Evaluate_request_size 2.400000e+02 2.400000e+02 \n/Device/SetWeights_request_size 1.006524e+07 1.006524e+07 \n/Device/Evaluate_response_size 4.000000e+01 4.000000e+01 \n/Device/TrainBatch_response_size 2.831694e+09 2.831694e+09 \n/Device/GetBatteryStatus_request_size 6.500000e+01 6.000000e+01 \n/Device/EndExperiment_request_size NaN NaN \n/Device/GetBatteryStatus_response_size 3.250000e+02 3.000000e+02 \n\n 4 \n/Device/StartExperiment_response_size 5.000000e+00 \n/Device/EvaluateBatch_response_size 1.520000e+03 \n/Device/TrainEpoch_request_size 2.240000e+02 \n/Device/TrainGlobal_response_size 9.604890e+06 \n/Device/TrainGlobal_request_size 2.100000e+01 \n/Device/SetWeights_response_size 7.000000e+01 \n/Device/TrainBatch_request_size 2.832800e+09 \n/Device/EvaluateBatch_request_size 7.082035e+08 \n/Device/TrainEpoch_response_size 6.137760e+05 \n/Device/Evaluate_request_size 2.400000e+02 \n/Device/SetWeights_request_size 
1.021859e+07 \n/Device/Evaluate_response_size 4.000000e+01 \n/Device/TrainBatch_response_size 2.831694e+09 \n/Device/GetBatteryStatus_request_size 5.500000e+01 \n/Device/EndExperiment_request_size NaN \n/Device/GetBatteryStatus_response_size 2.750000e+02 ", + "text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>0</th>\n <th>1</th>\n <th>2</th>\n <th>3</th>\n <th>4</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>/Device/StartExperiment_response_size</th>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n <td>5.000000e+00</td>\n </tr>\n <tr>\n <th>/Device/EvaluateBatch_response_size</th>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n <td>1.520000e+03</td>\n </tr>\n <tr>\n <th>/Device/TrainEpoch_request_size</th>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n <td>2.240000e+02</td>\n </tr>\n <tr>\n <th>/Device/TrainGlobal_response_size</th>\n <td>9.604898e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n <td>9.604890e+06</td>\n </tr>\n <tr>\n <th>/Device/TrainGlobal_request_size</th>\n <td>1.400000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n <td>2.100000e+01</td>\n </tr>\n <tr>\n <th>/Device/SetWeights_response_size</th>\n <td>5.000000e+01</td>\n <td>6.000000e+01</td>\n <td>6.000000e+01</td>\n <td>6.000000e+01</td>\n <td>7.000000e+01</td>\n </tr>\n <tr>\n <th>/Device/TrainBatch_request_size</th>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n <td>2.832800e+09</td>\n </tr>\n <tr>\n <th>/Device/EvaluateBatch_request_size</th>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n <td>7.082035e+08</td>\n </tr>\n <tr>\n <th>/Device/TrainEpoch_response_size</th>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n <td>6.137760e+05</td>\n </tr>\n <tr>\n <th>/Device/Evaluate_request_size</th>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n <td>2.400000e+02</td>\n </tr>\n <tr>\n <th>/Device/SetWeights_request_size</th>\n <td>1.006521e+07</td>\n <td>1.006524e+07</td>\n <td>1.006524e+07</td>\n <td>1.006524e+07</td>\n <td>1.021859e+07</td>\n </tr>\n <tr>\n <th>/Device/Evaluate_response_size</th>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n <td>4.000000e+01</td>\n </tr>\n <tr>\n <th>/Device/TrainBatch_response_size</th>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n <td>2.831694e+09</td>\n </tr>\n <tr>\n <th>/Device/GetBatteryStatus_request_size</th>\n <td>7.500000e+01</td>\n <td>7.000000e+01</td>\n <td>6.500000e+01</td>\n <td>6.000000e+01</td>\n <td>5.500000e+01</td>\n </tr>\n <tr>\n <th>/Device/EndExperiment_request_size</th>\n <td>5.000000e+00</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>NaN</td>\n </tr>\n <tr>\n <th>/Device/GetBatteryStatus_response_size</th>\n <td>3.750000e+02</td>\n <td>3.500000e+02</td>\n <td>3.250000e+02</td>\n <td>3.000000e+02</td>\n <td>2.750000e+02</td>\n </tr>\n </tbody>\n</table>\n</div>" + 
}, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame.from_dict(dpd('swarm_seq_equal_batteries_only_flops'), orient='index')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:36:06.659579700Z", + "start_time": "2023-11-23T14:35:34.518736600Z" + } + }, + "id": "9744e3f848c98bbf" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "data": { + "text/plain": "{'/Device/GetBatteryStatus_response_size': 12500.0,\n '/Device/SetWeights_response_size': 1500.0,\n '/Device/Evaluate_request_size': 6000.0,\n '/Device/StartExperiment_response_size': 25.0,\n '/Device/Evaluate_response_size': 1000.0,\n '/Device/TrainEpoch_response_size': 15344400.0,\n '/Device/EndExperiment_request_size': 25.0,\n '/Device/GetBatteryStatus_request_size': 2500.0,\n '/Device/SetWeights_request_size': 250663400.0,\n '/Device/TrainEpoch_request_size': 5600.0,\n '/Device/EvaluateBatch_request_size': 17705087800.0,\n '/Device/TrainBatch_response_size': 70792350000.0,\n '/Device/TrainBatch_request_size': 70820010000.0,\n '/Device/EvaluateBatch_response_size': 38000.0,\n '/Device/TrainGlobal_response_size': 240122250.0,\n '/Device/TrainGlobal_request_size': 350.0}" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dpm('swarm_max_equal_batteries_unlimited')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:37:14.507147Z", + "start_time": "2023-11-23T14:35:51.928280100Z" + } + }, + "id": "ceadccaf172f3f92" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "data": { + "text/plain": "{'/Device/SetWeights_response_size': 1250.0,\n '/Device/TrainEpoch_response_size': 15344400.0,\n '/Device/Evaluate_request_size': 6000.0,\n '/Device/EndExperiment_request_size': 25.0,\n '/Device/GetBatteryStatus_response_size': 6250.0,\n '/Device/SetWeights_request_size': 19103105.0,\n '/Device/GetBatteryStatus_request_size': 1250.0,\n '/Device/Evaluate_response_size': 1000.0,\n '/Device/TrainEpoch_request_size': 5600.0,\n '/Device/StartExperiment_response_size': 25.0,\n '/Device/EvaluateBatch_request_size': 17705087800.0,\n '/Device/EvaluateBatch_response_size': 38000.0,\n '/Device/TrainGlobal_request_size': 350.0,\n '/Device/TrainBatch_response_size': 70792350000.0,\n '/Device/TrainGlobal_response_size': 240122250.0,\n '/Device/TrainBatch_request_size': 70820010000.0}" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dpm('split_equal_batteries_unlimited')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-23T14:38:35.200762600Z", + "start_time": "2023-11-23T14:37:14.500146600Z" + } + }, + "id": "5247508e730487a" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "dpm('swarm_seq_equal_batteries_unlimited')" + ], + "metadata": { + "collapsed": false, + "is_executing": true, + "ExecuteTime": { + "start_time": "2023-11-23T14:38:35.208821600Z" + } + }, + "id": "bb8f3730adc4db72" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import pandas as pd\n", + "methods = ['swarm_seq_equal_batteries_unlimited', 'swarm_rand_equal_batteries_unlimited', 'split_equal_batteries_unlimited', 'swarm_max_equal_batteries_unlimited']\n", + "times_dict = {}\n", + "for m in methods:\n", + " times = dpm(m)\n", + " for measurement in times:\n", + " 
if measurement in times_dict.keys():\n", + " times_dict[measurement].append(times[measurement])\n", + " else:\n", + " times_dict[measurement] = [times[measurement]]\n", + "\n", + "timetable = pd.DataFrame.from_dict(times_dict)" + ], + "metadata": { + "collapsed": false, + "is_executing": true + }, + "id": "c801e2e71b87f06f" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "timetable" + ], + "metadata": { + "collapsed": false, + "is_executing": true + }, + "id": "48afb3a97554b464" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "is_executing": true + }, + "id": "e9f5273f22b6935" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/results/first_experiments_evaluation.ipynb b/results/first_experiments_evaluation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3e2fc95dd0db584b789f5d487fc1e2d297ae80a8 --- /dev/null +++ b/results/first_experiments_evaluation.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "7189f4d0-f6aa-44f5-be9b-ec094eed2fc9", + "metadata": { + "is_executing": true + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "from plotting_helpers import get_earliest_start_time, get_time_dict, get_device_activity_plot, get_runtime" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5fd48fa8-64c4-4dd2-9a60-3518d9f38bfd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "521.6006650924683\n" + ] + } + ], + "source": [ + "split_client_train = pd.read_csv(\"./csv/split_client_train_times.csv\")\n", + "split_device_train = pd.read_csv(\"./csv/split_device_train_times.csv\")\n", + "split_client_eval = pd.read_csv(\"./csv/split_client_eval_times.csv\")\n", + "split_device_eval = pd.read_csv(\"./csv/split_device_eval_times.csv\")\n", + "\n", + "split_offset = get_earliest_start_time([split_client_train, split_device_train, split_client_eval, split_device_eval])\n", + "print(get_runtime([split_client_train, split_device_train, split_client_eval, split_device_eval]))\n", + "\n", + "split_client_train_times = get_time_dict(dataframe=split_client_train, offset=split_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - client_train_epoch_time\")\n", + "split_device_train_times = get_time_dict(dataframe=split_device_train, offset=split_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - train_global_time\", device_indices=[0])\n", + "split_client_eval_times = get_time_dict(dataframe=split_client_eval, offset=split_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - client_evaluate_time\")\n", + "split_device_eval_times = get_time_dict(dataframe=split_device_eval, offset=split_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - evaluate_global_time\", device_indices=[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e68b912b-3a0d-4b8d-8b66-72bbaba05292", + "metadata": {}, + "outputs": [ + { + "data": { + 
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAGwCAYAAABcnuQpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAyPklEQVR4nO3deXRU9f3/8ddEJiGQjUQDQUB2EGRxxYhVLAEUC2oViloQRT0oLnzBfBWoAtZKLV8olqJ+QQUN4NIaFi1aUzbBYlRCWKoGRAJ+WWQJZAjBMJDP7w9+jA4kMBPmZuYTno9zck7m3pv3fc87NzOvc2cm12WMMQIAALBUVLgbAAAAOBuEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAq9UKdwNOKy8v144dOxQfHy+XyxXudgAAQACMMTp48KAaNmyoqKjTn3up8WFmx44daty4cbjbAAAAVfD999+rUaNGp92mxoeZ+Ph4SceHkZCQENLaXq9XH3/8sXr27Cm32x3S2ucy5uoM5uocZusM5uocG2br8XjUuHFj3/P46dT4MHPipaWEhARHwkydOnWUkJAQsQeDjZirM5irc5itM5irc2yabSBvEeENwAAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8zAp9P/DKCuhXVDzbb77+RcbevZlmPMVrb93s6l44EwAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqLmOMCXcTTvJ4PEpMTFRxcbESEhJCWtvr9WrRokXq3bu33G53SGufy5irM5irc5itM5irc2yYbTDP35yZAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsFnFhplu3bho+fHi42wAAAJaIuDBTmaFDh8rlcmnKlCnhbgUAAEQQK8LMvHnz9Nlnn6lhw4bhbgUAAESYsIaZQ4cOadCgQYqLi1NaWpomTZp0yjbbt2/Xo48+qjlz5kTs9SMAAED41ArnzjMzM7V8+XItWLBAqampGj16tPLy8tS5c2dJUnl5uQYOHKjMzEy1b98+oJplZWUqKyvz3fZ4PJKOX1TL6/WGtP8T9UJd91zHXJ3BXJ3DbJ3BXJ1jw2yD6S1sYaakpESvvfaaZs+ere7du0uS3njjDTVq1Mi3zQsvvKBatWrpscceC7juhAkTNH78+FOWf/zxx6pTp87ZN16BnJwcR+qe65irM5irc5itM5ircyJ5tqWlpQFvG7Yws3nzZh05ckRdunTxLUtOTlabNm0kSatXr9aLL76ovLw8uVyugOuOGjVKI0aM8N32eDxq3LixevbsecZLiAfL6/UqJydHPXr0cOYlsKmJFS9/tNiZ+hFSt9K5Rmi/1V63irVPe7z+vF6k3/8Im6sUwGOBbbOo6LEnFLWDrBvUY6wTs7D99+bUbM9QO1ROvLISiLC+zHQ6K1as0O7du9WkSRPfsmPHjmnkyJGaMmWKCgsLK/y5mJgYxcTEnLLc7XY79p4bx2qbw5Xt0Jn6EVb3lLlGeL/VVvcsa1d4vP68XqTf/wid6/HNK3kssG0WFT32hKJ2FesG9BjrxCxs/705NdsAa5+tYJ5Xw/YG4BYtWsjtdis3N9e3bP/+/dq4caMkaeDAgVq3bp3y8/N9Xw0bNlRmZqb++c9/hqttAAAQYcJ2ZiYuLk5DhgxRZmamUlJSlJqaqjFjxigq6ni+SklJUUpKit/PuN1uNWjQwPdSFAAAQFhfZpo4caJKSkrUp08fxcfHa+TIkSoudv51OAAAUHOENczExcUpKytLWVlZvmWZmZmVbl/Z+2QAAMC5y4r/AAwAAFAZwgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUi9tpMkDTS2Fmfus7WdaK2E70yV+pWV12nats2B9vqhhBnZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCTITq9D8DrKzvdN8I/Yyd+J3ZeBzwN2Evm45h2+pG6n5PRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDWXMcaEuwkneTweJSYmqri4WAkJCSGt7fV6tWjRIvXu3Vtutzuktc9lzNUZzNU5zNYZzNU5Nsw2mOdvzswAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFaLuDDTrVs3DR8+PNxtAAAAS0RcmPm5cePGqW3btqpbt67q1aunjIwM5ebmhrstAAAQQSI6zLRu3Vp//etftX79eq1cuVJNmzZVz549tWfPnnC3BgAAIkRYw8yhQ4c0aNAgxcXFKS0tTZMmTfJbf9dddykjI0PNmzdX+/btNXnyZHk8Hq1bty5MHQMAgEhTK5w7z8zM1PLly7VgwQKlpqZq9OjRysvLU+fOnU/Z9siRI5o+fboSExPVqVOnSmuWlZWprKzMd9vj8Ug6flEtr9cb0v5P1At13XMdc3UGc3UOs3UGc3WODbMNprewXTW7pKREKSkpmj17tvr16ydJKioqUqNGjfTggw9qypQpkqQPPvhAAwYMUGlpqdLS0jR//nxdeeWVldYdN26cxo8ff8ryuXPnqk6dOo7cFwAAEFqlpaW66667ArpqdtjCzNq1a9W5c2dt3bpVTZo08S2/9NJLdf311/vCzKFDh7Rz507t3btXM2bM0JIlS5Sbm6vU1NQK61Z0ZqZx48bau3fvGYcRLK/Xq5ycHPXo0ePsLqE+NfHUZY8WV71eIPsIVX0H6vrm+t19cpvDIa0tyapZhLKuo3OtrvvuZO2zqOv3WPDK+SGre4oIP8ZCXbfSx1gn+o3QGThV97TPX071HCSPx6Pzzz8/oDAT1peZAlG3bl21bNlSLVu21NVXX61WrVrptdde06hRoyrcPiYmRjExMacsd7vdZxc4TuOsa//8ieWnolWvF8g+QlXfqbqS3Oaw/5NupPdsSV1H5lpd993J2iGo63a7/Wcboro+lhxjoa57ymOsE/1G+Aycqlvh85eTx3AQgnleDdsbgFu0aCG32+33Uev9+/dr48aNp/258vJyvzMvAADg3Ba2MzNxcXEaMmSIMjMzlZKSotTUVI0ZM0ZRUcfz1aFDh/SHP/xBffv2VVpamvbu3atp06Zp+/btvvfYAAAAh
PVlpokTJ6qkpER9+vRRfHy8Ro4cqeLi46/NnXfeefrmm2/0xhtvaO/evUpJSdGVV16pFStWqH379uFsGwAARJCwhpm4uDhlZWUpKyvLtywzM9P3fXZ2djjaAgAAFono/wAMAABwJoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrRfy1mc4JI6vhWp9O7cPJ3h8tduaaILbNItR1nZirLfe9OmrTs511berVybpO13YIZ2YAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACr1Qr2B5YsWaLs7GwVFhbK5XKpWbNmuuOOO3Tdddc50R8AAMBpBXVmZujQocrIyNBbb72lffv2ac+ePZozZ45uuOEGPfroo071CAAAUKmAw8y8efM0c+ZMvf7669q7d69WrVqlzz77THv27NGMGTM0ffp0LVy40MleAQAAThFwmJk5c6ZGjBihwYMHy+Vy/VQgKkr33Xefhg8frtdee82RJmuyTv8zoEbswxZOzcK2uqFkQ48nc6rnrlPvdaSuZN8xZlNdm3p1WjA9R9L9CzjM5OXl6bbbbqt0/a9//WutXr06JE0BAAAEKuAws3fvXjVq1KjS9Y0aNdK+fftC0hQAAECgAg4zR44ckdvtrnR9rVq1dOTIkZA0BQAAEKigPpr99NNPq06dOhWuKy0tDUlDAAAAwQg4zFx33XUqKCg44zYAAADVKeAws2zZMgfbAAAAqBouZwAAAKwW8JmZESNGBLTd5MmTq9wMAABAsAIOM2vWrDnjNj//Z3oAAADVIeAws3TpUif7AAAAqBLeMwMAAKxGmAEAAFYjzAAAAKu5jDEm3E04yePxKDExUcXFxUpISAhpba/Xq0WLFql3796nvdQDgsNcncFcncNsncFcnWPDbIN5/ubMDAAAsFqVwsyKFSv029/+Vunp6dq+fbskKSsrSytXrgxpcwAAAGcSdJh577331KtXL8XGxmrNmjUqKyuTJBUXF+v5558PeYMAAACnE3SYee655/TKK69oxowZfq+zde3aVXl5eSFtDgAA4EyCDjMFBQUVXh07MTFRBw4cCEVPAAAAAQs6zDRo0EDffvvtKctXrlyp5s2bh6QpAACAQAUdZh544AE9/vjjys3Nlcvl0o4dOzRnzhw98cQTeuihh5zoEQAAoFIBX5vphKeeekrl5eXq3r27SktLdd111ykmJkZPPPGEHn30USd6BAAAqFTQZ2ZcLpfGjBmjoqIibdiwQZ999pn27Nmj3//+9yFpqFu3bho+fHhIagEAgJov6DBTXFysoqIiRUdHq127drrqqqsUFxenoqIieTyekDXm9Xr15JNPqkOHDqpbt64aNmyoQYMGaceOHSHbBwAAsF/QYWbAgAF6++23T1n+7rvvasCAASFpSpJKS0uVl5enp59+Wnl5ecrOzlZBQYH69u0bsn0AAAD7BR1mcnNzdcMNN5yyvFu3bsrNzQ2q1qFDhzRo0CDFxcUpLS1NkyZN8q1LTExUTk6O+vfvrzZt2ujqq6/WX//6V61evVrbtm0Ltm0AAFBDBf0G4LKyMh09evSU5V6vV4cPHw6qVmZmppYvX64FCxYoNTVVo0ePVl5enjp37lzh9sXFxXK5XEpKSjptfyf+K7Ek30tfXq9XXq83qP7O5ES9UNc91zFXZzBX5zBbZzBX59gw22B6C/qq2TfccIMuueQSTZ061W/5sGHDtG7dOq1YsSKgOiUlJUpJSdHs2bPVr18/SVJRUZEaNWqkBx98UFOmTPHb/scff1TXrl3Vtm1bzZkzp9K648aN0/jx409ZPnfuXNWpUyeg3gAAQHiVlpbqrrvuCuiq2UGfmXnuueeUkZGhtWvXqnv37pKkxYsX64svvtDHH38ccJ3NmzfryJEj6tKli29ZcnKy2rRpc8q2Xq9X/fv3lzFGL7/88mnrjho1SiNGjPDd9ng8aty4sXr27HnGYQTL6/UqJydHPXr0OPMl1KcmVrz80eKQ9lThvkK1D6fqnlTb64pVTvPXA5trEHUlRf4snKj7/2uGbK7VdByEtLbDdUN6zJ5U28eSWYSyru8x9rv75DaHQ1ZXkn+/ETwDp2pXONsQ1A2lYD5UFHSY6dq1q1atWqWJEyfq3XffVWxsrDp27KjXXntNrVq1CrbcGZ0IMlu3btWSJUvOGEhiYmIUExNzynK32x26B5mq1DaVvATnRE8n7ytU+3CqbkW1FaLfmW2zcKLuSTXPeq7VeRxE8lwrqBvSxxnLZxHK48JtDv/0hOtEvxbMwKnafrMNYd1QCOZvKegwI0mdO3c+7Us9gWjRooXcbrdyc3PVpEkTSdL+/fu1ceNGXX/99ZJ+CjKbNm3S0qVLlZKSclb7BAAANU9AYcbj8fjOiJzptE+gL+XExcVpyJAhyszMVEpKilJTUzVmzBhFRR3/gJXX69Udd9yhvLw8ffDBBzp27Jh27dol6fjLUdHR0QHtBwAA1GwBhZl69epp586dSk1NVVJSklwu1ynbGGPkcrl07NixgHc+ceJElZSUqE+fPoqPj9fIkSNVXHz89brt27dr4cKFknTKp5uWLl2qbt26BbwfAABQcwUUZpYsWaLk5GTf9xWFmaqIi4tTVlaWsrKyfMsyMzN93wf5QSsAAHAOCijMnHgPiyTOiAAAgIgS9H8AbtWqlcaNG6dNmzY50Q8AAEBQgg4zDz/8sP7xj3+obdu2uvLKK/Xiiy/63pgLAABQ3YIOM//1X/+lL774Ql9//bV69+6tadOm+f4p3ZtvvulEjwAAAJUKOsyc0Lp1a40fP14bN27UihUrtGfPHt17772h7A0AAOCMqvRP8074/PPPNXfuXL3zzjvyeDy+aywBAABUl6DDzMaNGzVnzhy99dZb2rJli375y1/qhRde0K9//WvFxcU50SMAAEClgg4zJ974O2zYMA0YMED169d3oq+aZ2Q1/s8cp/bl5H34eW2vV1q0KPR1Q8mmuidqhmqu1XUc2FQ3lMfsybVDzba60vGLH4b6mkFO/q05wanaTsw2DIIOMwUFBY5cUBIAAKAqqvR/Zg4cOKBXX31Vo0aNUlFRkSQpLy9P27dvD3mDAAAApxP0mZl169ape/fuSkpKUmFhoR544AElJycrOztb27Zt4+PZAACgWlXp/8zce++92rRpk2rXru1b3rt3b33yySchbQ4AAOBMgj4z8+WXX2r69OmnLL/wwgv5T8AAAKDaBX1mJiYmRh6P55TlGzdu1AUXXBCSpgAAAAIVdJjp27evnn32WXm9XkmSy+XStm3b9OSTT+r2228PeYMAAACnE3SYmTRpkkpKSpSamqrDhw/r+uuvV8uWLRUfH68//OEPTvQIAABQqaDfM5OYmKicnBytXLlS69atU0lJiS677DJlZGQ4
0R8AAMBpVfnaTNdee62uvfbaUPYCAAAQtKDCTHl5uWbNmqXs7GwVFhbK5XKpWbNmuuOOOzRw4EC5XC6n+gQAAKhQwO+ZMcaob9++uv/++7V9+3Z16NBB7du319atWzV48GDddtttTvYJAABQoYDPzMyaNUuffPKJFi9erBtuuMFv3ZIlS3TrrbfqzTff1KBBg0LeJAAAQGUCPjPz1ltvafTo0acEGUn65S9/qaeeekpz5swJaXMAAABnEnCYWbdunW688cZK1990001au3ZtSJoCAAAIVMBhpqioSPXr1690ff369bV///6QNAUAABCogMPMsWPHVKtW5W+xOe+883T06NGQNAUAABCogN8AbIzR4MGDFRMTU+H6srKykDUFAAAQqIDDzD333HPGbfgkEwAAqG4Bh5mZM2c62QcAAECVBH2hSQAAgEhCmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphphp0+p8BNXJfoeJUz9S1C3P9iW2zcKpu16n3OlLXiX6dPM5sPIarG2EGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNVcxhgT7iac5PF4lJiYqOLiYiUkJIS0ttfr1aJFi9S7d2+53e6Q1j6XMVdnMFfnMFtnMFfn2DDbYJ6/OTMDAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFgt4sJMt27dNHz48HC3AQAALBFxYebnsrOz1bNnT6WkpMjlcik/Pz/cLQEAgAgT0WHm0KFDuvbaa/XCCy+EuxUAABChaoVz54cOHdJDDz2k7OxsxcfH64knnvBbP3DgQElSYWFhGLoDAAA2CGuYyczM1PLly7VgwQKlpqZq9OjRysvLU+fOnatcs6ysTGVlZb7bHo9H0vGLanm93rNt2c+JeqGue65jrs5grs5hts5grs6xYbbB9Ba2q2aXlJQoJSVFs2fPVr9+/SRJRUVFatSokR588EFNmTLFt21hYaGaNWumNWvWnDHojBs3TuPHjz9l+dy5c1WnTp1Q3gUAAOCQ0tJS3XXXXQFdNTtsZ2Y2b96sI0eOqEuXLr5lycnJatOmzVnVHTVqlEaMGOG77fF41LhxY/Xs2fOMwwiW1+tVTk6OevTo8dMl1Kcmnrrho8Uh3a+fk/cXyn05VfsMdSucawjqVlkNqVuluTp5PNs+15/VrvIxW1lt22bhUF3v0L1nN9dK6oakX8sfe72uWOU0f/3sZ+ugE6+sBCKsLzM5ISYmRjExMacsd7vdjv3C/GqbwxVt4Mh+K9xfKPflVO0A6wb9Owtzv7bUDWquTh7Pts+1gtpVfpyxfRYO1z3rx28n+q0Jj71y9rnxbAXTV9g+zdSiRQu53W7l5ub6lu3fv18bN24MV0sAAMBCYTszExcXpyFDhigzM1MpKSlKTU3VmDFjFBX1U74qKirStm3btGPHDklSQUGBJKlBgwZq0KBBWPoGAACRJawvM02cOFElJSXq06eP4uPjNXLkSBUX//Ta4MKFC3Xvvff6bg8YMECSNHbsWI0bN6662wUAABEorGEmLi5OWVlZysrK8i3LzMz0fT948GANHjw4DJ0BAABbRPR/AAYAADgTwgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNVq3LWZwm5kNV+E3Mn9OVWbuvbU5fhyvq6TtW2v6/U6UzdSazpd++d1vV5p0SJn9hMGnJkBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMh1Ol/BoS7hZBx6r5QFxK/r59jFsfZ1q9TbDkeIu33RZgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDWXMcaEuwkneTweJSYmqri4WAkJCSGt7fV6tWjRIvXu3Vtutzuktc9lzNUZzNU5zNYZzNU5ZzvbDRs2+L6/5JJLQtmaTzDP35yZAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsFnFhplu3bho+fHi42wAAAJaIuDDzc8YYPfPMM0pLS1NsbKwyMjK0adOmcLcFAAAiSESHmT/96U/6y1/+oldeeUW5ubmqW7euevXqpR9//DHcrQEAgAgR1jBz6NAhDRo0SHFxcUpLS9OkSZN864wxmjJlin73u9/plltuUceOHfXmm29qx44dmj9/fviaBgAAEaVWOHeemZmp5cuXa8GCBUpNTdXo0aOVl5enzp07a8uWLdq1a5cyMjJ82ycmJqpLly5atWqVBgwYUGHNsrIylZWV+W57PB5Jxy+q5fV6Q9r/iXqhrnuuY67OYK7OYbbOYK7OOdvZlpeXn1Ir1IKpG7arZpeUlCglJUWzZ89Wv379JElFRUVq1KiRHnzwQfXv319du3bVjh07lJaW5vu5/v37y+Vy6Z133qmw7rhx4zR+/PhTls+dO1d16tRx5s4AAICQKi0t1V133RXQVbPDdmZm8+bNOnLkiLp06eJblpycrDZt2pxV3VGjRmnEiBG+2x6PR40bN1bPnj3POIxgeb1e5eTkqMd398ltDkuPFoe0fqWmJv70faj3eaJ2GOv65tqjx5kvTR8B/VapbqhrB1D3nJirE3+DAdQOarZB1K0SW4/dCmpWaa4n162kdpXZ+ns7qe5ZzbaanHhlJRBhfZnpdBo0aCBJ+uGHH/zOzPzwww/q3LlzpT8XExOjmJiYU5a73W7HfmFuc/h4mKmuA8Ic/tnOQ7zPE7UjoG5Av7MI6jeouqGuHUTdGj1XJ/4Gg6gd1OOMbbNw+tg9Tc0qPX6HsV+b6jr53Hi2gukrbG8AbtGihdxut3Jzc33L9u/fr40bN0qSmjVrpgYNGmjx4sW+9R6PR7m5uUpPT6/2fgEAQGQK25mZuLg4DRkyRJmZmUpJSVFqaqrGjBmjqKjj+crlcmn48OF67rnn1KpVKzVr1kxPP/20GjZsqFtvvTVcbQMAgAgT1peZJk6cqJKSEvXp00fx8fEaOXKkiot/el3vv//7v3Xo0CE9+OCDOnDggK699lp99NFHql27dhi7BgAAkSSsYSYuLk5ZWVnKysryLcvMzPR973K59Oyzz+rZZ58NR3sAAMACEf0fgAEAAM6EMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDdTwxUXFxtJpri4OOS1jxw5YubPn29iY48YyfAVoq/YWObKXO36YrbM1bavs51tdQjm+ZszMwAAwGq
EGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsVivcDdQExcWS2x3uLmoOr1datIi5hhpzdQ6zdQZzdU5Nmy1nZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWqxXuBpxmjJEkeTyekNf2er0qLS2Vx+OR2+0Oef1zFXN1BnN1DrN1BnN1jg2zPfG8feJ5/HRqfJg5ePCgJKlx48Zh7gQAAATr4MGDSkxMPO02LhNI5LFYeXm5duzYofj4eLlcrpDW9ng8aty4sb7//nslJCSEtPa5jLk6g7k6h9k6g7k6x4bZGmN08OBBNWzYUFFRp39XTI0/MxMVFaVGjRo5uo+EhISIPRhsxlydwVydw2ydwVydE+mzPdMZmRN4AzAAALAaYQYAAFiNMHMWYmJiNHbsWMXExIS7lRqFuTqDuTqH2TqDuTqnps22xr8BGAAA1GycmQEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEmSqaNm2amjZtqtq1a6tLly76/PPPw91SRPvkk0/Up08fNWzYUC6XS/Pnz/dbb4zRM888o7S0NMXGxiojI0ObNm3y26aoqEh33323EhISlJSUpCFDhqikpKQa70XkmTBhgq688krFx8crNTVVt956qwoKCvy2+fHHHzVs2DClpKQoLi5Ot99+u3744Qe/bbZt26abb75ZderUUWpqqjIzM3X06NHqvCsR5+WXX1bHjh19/1QsPT1dH374oW89cw2NP/7xj3K5XBo+fLhvGbOtmnHjxsnlcvl9tW3b1re+Rs/VIGhvv/22iY6ONq+//rr5z3/+Yx544AGTlJRkfvjhh3C3FrEWLVpkxowZY7Kzs40kM2/ePL/1f/zjH01iYqKZP3++Wbt2renbt69p1qyZOXz4sG+bG2+80XTq1Ml89tlnZsWKFaZly5bmzjvvrOZ7Ell69eplZs6caTZs2GDy8/NN7969TZMmTUxJSYlvm6FDh5rGjRubxYsXmy+//NJcffXV5pprrvGtP3r0qLnkkktMRkaGWbNmjVm0aJE5//zzzahRo8JxlyLGwoULzT/+8Q+zceNGU1BQYEaPHm3cbrfZsGGDMYa5hsLnn39umjZtajp27Ggef/xx33JmWzVjx4417du3Nzt37vR97dmzx7e+Js+VMFMFV111lRk2bJjv9rFjx0zDhg3NhAkTwtiVPU4OM+Xl5aZBgwZm4sSJvmUHDhwwMTEx5q233jLGGPPVV18ZSeaLL77wbfPhhx8al8tltm/fXm29R7rdu3cbSWb58uXGmONzdLvd5m9/+5tvm6+//tpIMqtWrTLGHA+aUVFRZteuXb5tXn75ZZOQkGDKysqq9w5EuHr16plXX32VuYbAwYMHTatWrUxOTo65/vrrfWGG2Vbd2LFjTadOnSpcV9PnystMQTpy5IhWr16tjIwM37KoqChlZGRo1apVYezMXlu2bNGuXbv8ZpqYmKguXbr4Zrpq1SolJSXpiiuu8G2TkZGhqKgo5ebmVnvPkaq4uFiSlJycLElavXq1vF6v32zbtm2rJk2a+M22Q4cOql+/vm+bXr16yePx6D//+U81dh+5jh07prfffluHDh1Seno6cw2BYcOG6eabb/abocQxe7Y2bdqkhg0bqnnz5rr77ru1bds2STV/rjX+QpOhtnfvXh07dszvly1J9evX1zfffBOmruy2a9cuSapwpifW7dq1S6mpqX7ra9WqpeTkZN8257ry8nINHz5cXbt21SWXXCLp+Nyio6OVlJTkt+3Js61o9ifWncvWr1+v9PR0/fjjj4qLi9O8efPUrl075efnM9ez8PbbbysvL09ffPHFKes4ZquuS5cumjVrltq0aaOdO3dq/Pjx+sUvfqENGzbU+LkSZoAaYtiwYdqwYYNWrlwZ7lZqjDZt2ig/P1/FxcX6+9//rnvuuUfLly8Pd1tW+/777/X4448rJydHtWvXDnc7NcpNN93k+75jx47q0qWLLrroIr377ruKjY0NY2fO42WmIJ1//vk677zzTnkH+A8//KAGDRqEqSu7nZjb6WbaoEED7d6922/90aNHVVRUxNwlPfLII/rggw+0dOlSNWrUyLe8QYMGOnLkiA4cOOC3/cmzrWj2J9ady6Kjo9WyZUtdfvnlmjBhgjp16qQXX3yRuZ6F1atXa/fu3brssstUq1Yt1apVS8uXL9df/vIX1apVS/Xr12e2IZKUlKTWrVvr22+/rfHHLGEmSNHR0br88su1ePFi37Ly8nItXrxY6enpYezMXs2aNVODBg38ZurxeJSbm+ubaXp6ug4cOKDVq1f7tlmyZInKy8vVpUuXau85Uhhj9Mgjj2jevHlasmSJmjVr5rf+8ssvl9vt9pttQUGBtm3b5jfb9evX+4XFnJwcJSQkqF27dtVzRyxRXl6usrIy5noWunfvrvXr1ys/P9/3dcUVV+juu+/2fc9sQ6OkpESbN29WWlpazT9mw/0OZBu9/fbbJiYmxsyaNct89dVX5sEHHzRJSUl+7wCHv4MHD5o1a9aYNWvWGElm8uTJZs2aNWbr1q3GmOMfzU5KSjILFiww69atM7fcckuFH82+9NJLTW5urlm5cqVp1arVOf/R7IceesgkJiaaZcuW+X0cs7S01LfN0KFDTZMmTcySJUvMl19+adLT0016erpv/YmPY/bs2dPk5+ebjz76yFxwwQVWfBzTSU899ZRZvny52bJli1m3bp156qmnjMvlMh9//LExhrmG0s8/zWQMs62qkSNHmmXLlpktW7aYTz/91GRkZJjzzz/f7N692xhTs+dKmKmiqVOnmiZNmpjo6Ghz1VVXmc8++yzcLUW0pUuXGkmnfN1zzz3GmOMfz3766adN/fr1TUxMjOnevbspKCjwq7Fv3z5z5513mri4OJOQkGDuvfdec/DgwTDcm8hR0UwlmZkzZ/q2OXz4sHn44YdNvXr1TJ06dcxtt91mdu7c6VensLDQ3HTTTSY2Ntacf/75ZuTIkcbr9VbzvYks9913n7noootMdHS0ueCCC0z37t19QcYY5hpKJ4cZZls1v/nNb0xaWpqJjo42F154ofnNb35jvv32W9/6mjxXlzHGhOecEAAAwNnjPTMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwCqzbJly+RyuU652F0odOvWTcOHDw95XQCRr1a4GwCAUMjOzpbb7Q53G5UaN26c5s+fr/z8/JDUGzx4sA4cOKD58+eHpB5gM8IMgBohOTk53C1Ui2PHjsnlcoW7DSCi8DITYKm///3v6tChg2JjY5WSkqKMjAwdOnTIt/7VV1/VxRdfrNq1a6tt27Z66aWX/H7+//7v/3TnnXcqOTlZdevW1RVXXKHc3Fzf+pdfflktWrRQdHS02rRpo6
ysLL+fd7lcevXVV3XbbbepTp06atWqlRYuXOi3zaJFi9S6dWvFxsbqhhtuUGFhod/6rVu3qk+fPqpXr57q1q2r9u3ba9GiRZXe55deekmtWrVS7dq1Vb9+fd1xxx2+dSe/zNS0aVM9//zzuu+++xQfH68mTZpo+vTpQc1gwYIFuuyyy1S7dm01b95c48eP19GjRyvtb9myZbrqqqtUt25dJSUlqWvXrtq6datmzZql8ePHa+3atXK5XHK5XJo1a5YkafLkyerQoYPq1q2rxo0b6+GHH1ZJSYmv5qxZs5SUlKSFCxeqXbt2iomJ0X333ac33nhDCxYs8NVbtmxZpX0BNV64r3QJIHg7duwwtWrVMpMnTzZbtmwx69atM9OmTfNdRXz27NkmLS3NvPfee+a7774z7733nklOTjazZs0yxhhz8OBB07x5c/OLX/zCrFixwmzatMm888475t///rcxxpjs7GzjdrvNtGnTTEFBgZk0aZI577zzzJIlS3w9SDKNGjUyc+fONZs2bTKPPfaYiYuLM/v27TPGGLNt2zYTExNjRowYYb755hsze/ZsU79+fSPJ7N+/3xhjzM0332x69Ohh1q1bZzZv3mzef/99s3z58grv8xdffGHOO+88M3fuXFNYWGjy8vLMiy++6Ft/8pWXL7roIpOcnGymTZtmNm3aZCZMmGCioqLMN998E9AMPvnkE5OQkGBmzZplNm/ebD7++GPTtGlTM27cuAr783q9JjEx0TzxxBPm22+/NV999ZWZNWuW2bp1qyktLTUjR4407du3Nzt37jQ7d+40paWlxhhj/vznP5slS5aYLVu2mMWLF5s2bdqYhx56yFd35syZxu12m2uuucZ8+umn5ptvvjHFxcWmf//+5sYbb/TVKysrC+zgAWogwgxgodWrVxtJprCwsML1LVq0MHPnzvVb9vvf/96kp6cbY4z53//9XxMfH+8LHie75pprzAMPPOC3rF+/fqZ3796+25LM7373O9/tkpISI8l8+OGHxhhjRo0aZdq1a+dX48knn/QLMx06dKg0HJzsvffeMwkJCcbj8VS4vqIw89vf/tZ3u7y83KSmppqXX37ZGHPmGXTv3t08//zzfsuysrJMWlpahdvv27fPSDLLli2rcP3YsWNNp06dKrt7Pn/7299MSkqK7/bMmTONJJOfn++33T333GNuueWWM9YDzgW8zARYqFOnTurevbs6dOigfv36acaMGdq/f78k6dChQ9q8ebOGDBmiuLg439dzzz2nzZs3S5Ly8/N16aWXVvo+k6+//lpdu3b1W9a1a1d9/fXXfss6duzo+75u3bpKSEjQ7t27fTW6dOnit316errf7ccee0zPPfecunbtqrFjx2rdunWV3ucePXrooosuUvPmzTVw4EDNmTNHpaWlpxuTX38ul0sNGjTw9XemGaxdu1bPPvus3wwfeOAB7dy5s8L9Jicna/DgwerVq5f69OmjF198UTt37jxtf5L0r3/9S927d9eFF16o+Ph4DRw4UPv27fPbR3R0tN99AeCPMANY6LzzzlNOTo4+/PBDtWvXTlOnTlWbNm20ZcsW3/stZsyYofz8fN/Xhg0b9Nlnn0mSYmNjQ9LHyZ8ecrlcKi8vD/jn77//fn333XcaOHCg1q9fryuuuEJTp06tcNv4+Hjl5eXprbfeUlpamp555hl16tTptB/zPl1/Z5pBSUmJxo8f7zfD9evXa9OmTapdu3aFPzNz5kytWrVK11xzjd555x21bt3aN/OKFBYW6le/+pU6duyo9957T6tXr9a0adMkSUeOHPFtFxsby5t+gdMgzACWcrlc6tq1q8aPH681a9YoOjpa8+bNU/369dWwYUN99913atmypd9Xs2bNJB0/Y5Gfn6+ioqIKa1988cX69NNP/ZZ9+umnateuXcD9XXzxxfr888/9llX0xN64cWMNHTpU2dnZGjlypGbMmFFpzVq1aikjI0N/+tOftG7dOhUWFmrJkiUB9/RzZ5rBZZddpoKCglNm2LJlS0VFVf7Qeemll2rUqFH697//rUsuuURz586VdPzsyrFjx/y2Xb16tcrLyzVp0iRdffXVat26tXbs2BFQ/xXVA85VfDQbsFBubq4WL16snj17KjU1Vbm5udqzZ48uvvhiSdL48eP12GOPKTExUTfeeKPKysr05Zdfav/+/RoxYoTuvPNOPf/887r11ls1YcIEpaWlac2aNWrYsKHS09OVmZmp/v3769JLL1VGRobef/99ZWdn61//+lfAPQ4dOlSTJk1SZmam7r//fq1evdr3CZ4Thg8frptuukmtW7fW/v37tXTpUt99ONkHH3yg7777Ttddd53q1aunRYsWqby8XG3atKnSDM80g2eeeUa/+tWv1KRJE91xxx2KiorS2rVrtWHDBj333HOn1NuyZYumT5+uvn37qmHDhiooKNCmTZs0aNAgScc/XbVlyxbl5+erUaNGio+PV8uWLeX1ejV16lT16dNHn376qV555ZWA+m/atKn++c9/qqCgQCkpKUpMTIzo/7MDOCrcb9oBELyvvvrK9OrVy1xwwQUmJibGtG7d2kydOtVvmzlz5pjOnTub6OhoU69ePXPdddeZ7Oxs3/rCwkJz++23m4SEBFOnTh1zxRVXmNzcXN/6l156yTRv3ty43W7TunVr8+abb/rVl2TmzZvntywxMdHMnDnTd/v99983LVu2NDExMeYXv/iFef311/3eAPzII4+YFi1amJiYGHPBBReYgQMHmr1791Z4n1esWGGuv/56U69ePRMbG2s6duxo3nnnHd/6it4A/Oc//9mvRqdOnczYsWMDnsFHH31krrnmGhMbG2sSEhLMVVddZaZPn15hf7t27TK33nqrSUtLM9HR0eaiiy4yzzzzjDl27Jgxxpgff/zR3H777SYpKclI8s1p8uTJJi0tzcTGxppevXqZN998029GM2fONImJiafsb/fu3aZHjx4mLi7OSDJLly6tsC/gXOAyxpiwpikAAICzwHtmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGC1/wcKTobypAnMHwAAAABJRU5ErkJggg==", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "get_device_activity_plot(\n", + " [split_device_train_times, split_client_train_times, split_device_eval_times, split_client_eval_times],\n", + " num_devices=5).show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": 
"f37d4d98-fb88-438f-8494-de1f82380e38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "328.0836374759674\n" + ] + } + ], + "source": [ + "fed_client_train = pd.read_csv(\"./csv/fed_client_train_times.csv\")\n", + "fed_device_train = pd.read_csv(\"./csv/fed_device_train_times.csv\")\n", + "fed_client_eval = pd.read_csv(\"./csv/fed_client_eval_times.csv\")\n", + "fed_device_eval = pd.read_csv(\"./csv/fed_device_eval_times.csv\")\n", + "\n", + "fed_offset = get_earliest_start_time([fed_client_train, fed_device_train, fed_client_eval, fed_device_eval])\n", + "print(get_runtime([fed_client_train, fed_device_train, fed_client_eval, fed_device_eval]))\n", + "\n", + "fed_client_train_times = get_time_dict(dataframe=fed_client_train, offset=fed_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - client_train_epoch_time\")\n", + "fed_device_train_times = get_time_dict(dataframe=fed_device_train, offset=fed_offset,\n", + " format_string_fn=lambda device_idx: f\"Group: d{device_idx} - fed_train_time\")\n", + "fed_client_eval_times = get_time_dict(dataframe=fed_client_eval, offset=fed_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - client_evaluate_time\")\n", + "fed_device_eval_times = get_time_dict(dataframe=fed_device_eval, offset=fed_offset, format_string_fn=lambda\n", + " device_idx: f\"Group: d{device_idx} - evaluate_global_time\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0d4011c8-2374-47d0-81ce-767279549a83", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAGwCAYAAABcnuQpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1/klEQVR4nO3deXhU9dn/8c8EJiEhCyQakjSA7MquVWnEIpawGIu7FmhBlMqFdSkF81OwClirrRSrpVQfN9CwaFuDqE9sTdlEi6kSIFI1IBJoWVQIZEgCyUi+vz94GB0my0wyk8k3vF/XlUvmnJN77nPnzMzHM5MchzHGCAAAwFIR4W4AAACgKQgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWaxvuBkKtpqZG+/btU1xcnBwOR7jbAQAAfjDG6OjRo0pLS1NERP3nXlp9mNm3b586d+4c7jYAAEAj/Oc//1F6enq927T6MBMXFyfp5DDi4+ODWtvtduvtt9/WqFGj5HQ6g1rbRszDFzPxxUy8MQ9fzMTXmTgTl8ulzp07e17H69Pqw8ypt5bi4+NDEmZiYmIUHx9/xhxc9WEevpiJL2bijXn4Yia+zuSZ+PMRET4ADAAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wswZbtDvxoW7hRbDxlmEqmfb6oa6dqjYNudQ1A12TY4De+s2BWEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIM2e4rfe8HO4WWgwbZxGqnm2rG+raoWLbnENRN9g1OQ7srdsUhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGotLswMHz5c06dPD3cbAADAEi0uzNRl2rRpcjgceuKJJ8LdCgAAaEGsCDMrV67U+++/r7S0tHC3AgAAWpiwhpmKigpNmjRJsbGxSk1N1YIFC3y22bt3r+666y4tW7ZMTqczDF0CAICWrG047zw7O1vr16/XqlWrlJycrNmzZ6uwsFCDBw+WJNXU1GjixInKzs5Wv379/KpZVVWlqqoqz22XyyVJcrvdcrvdQe3/VL1g17UV8/DFTHwxE2/Mwxcz8XUmziSQfQ1bmCkvL9fzzz+vpUuXasSIEZKkF198Uenp6Z5tfvvb36pt27a6++67/a776KOPat68eT7L3377bcXExDS98Vrk5+eHpK6tmIcvZuKLmXhjHr6Yia8zaSaVlZV+bxu2MLNz505VV1dryJAhnmWJiYnq06ePJGnTpk168sknVVhYKIfD4XfdWbNmacaMGZ7bLpdLnTt31qhRoxQfHx+8HdDJ1Jifn6+RI0cG/y2whQm+y+4qC37tYNWU5P5jivK7v6CRn98qpzkWmn6lll/3W7XdjuiTMwnWMdIMPYe6btAeN801i2DWrqVuSObR0o+LBmoGPJPm2Pdg1m5EXb9nYtNx0IBT76z4I6xvM9Vnw4YN+vLLL9WlSxfPshMnTmjmzJl64oknVFJSUuv3RUVFKSoqyme50+kM2WduQlLbHKvtjoJfO5h9/19dpzl2MsyEol+p5detpXbQjpHm6rkZ6jZ5Js3482uOYy6o82jpx4WfNf2eSXPsezBrN6FugzOx6ThoQCCPh7B9ALhHjx5yOp0qKCjwLDt8+LC2b98uSZo4caKKioq0ZcsWz1daWpqys7P197//PVxtAwCAFiZsZ2ZiY2M1ZcoUZWdnKykpScnJybr//vsVEXEyXyUlJSkpKcnre5xOp1JSUjxvRQEAAIT1bab
58+ervLxcY8eOVVxcnGbOnKmysuZ5Lw4AALQOYQ0zsbGxysnJUU5OjmdZdnZ2ndvX9TkZAABw5rLiLwADAADUhTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1hzHGhLuJUHK5XEpISFBZWVlILjSZl5enrKysoF+bqbZrawbrJ3V67WDVjYlxa8WKPI0fn6Vjx5wtvt/mmHF09MmZBOsYsfG4OL1udXVwHjetYRZS6ObRkmfRUM1An1tt2vfa6vpT29+Z2DaL+gTy+s2ZGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8y0UMb4foWqdrCUlX3zXxv6bY4Zn5pJsOvadFzY/POj59DUDXZNm/a9tro29tzSEGYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0w0wIN+t0462pTN/S16bl50HNo6wa7pk37HmrhmEVLmRNhBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVHMYYE+4mQsnlcikhIUFlZWWKj48Pam232628vDxlZWXJ6XQGtbaNmIcvZuKLmXhjHr6Yia+WOpNt27apf//+IakdyOs3Z2YAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKu1uDAzfPhwTZ8+PdxtAAAAS7S4MPNtc+fO1bnnnqv27durY8eOyszMVEFBQbjbAgAALUiLDjO9e/fWH//4R3300Ud69913dc4552jUqFH66quvwt0aAABoIcIaZioqKjRp0iTFxsYqNTVVCxYs8Fo/YcIEZWZmqnv37urXr58ef/xxuVwuFRUVhaljAADQ0rQN551nZ2dr/fr1WrVqlZKTkzV79mwVFhZq8ODBPttWV1frmWeeUUJCggYNGlRnzaqqKlVVVXluu1wuSScv0uV2u4Pa/6l6wa5rK+bhi5n4YibemIcvZuKrpc6kpqYmZD0FUjdsV80uLy9XUlKSli5dqhtvvFGSVFpaqvT0dE2dOlVPPPGEJOnNN9/UuHHjVFlZqdTUVL322mu66KKL6qw7d+5czZs3z2f58uXLFRMTE5J9AQAAwVVZWakJEyb4ddXssIWZrVu3avDgwdq9e7e6dOniWX7++efrsssu84SZiooK7d+/XwcPHtSzzz6rNWvWqKCgQMnJybXWre3MTOfOnXXw4MEGhxEot9ut/Px8jRw5smmXZF+Y4H37rrKmNVZf/WDWPq1uSOYRwn5DUve02k2eSRh6DnXdRs8klI+TMM65UfMI48+vyXX9qNngTFrRc4S/ap2JbcdBgFwul8466yy/wkxY32byR/v27dWzZ0/17NlT3/ve99SrVy89//zzmjVrVq3bR0VFKSoqyme50+ls2gtsPZpc2xw7vWDTGqqvfjBr11E3qPNohn6DWreO2o2eSRh7DnXdgGcSysdJC5hzQPNoAT+/RtcNoGadM2mFzxH+8pqJbcdBgAJ5fgjbB4B79Oghp9Pp9avWhw8f1vbt2+v9vpqaGq8zLwAA4MwWtjMzsbGxmjJlirKzs5WUlKTk5GTdf//9iog4ma8qKir061//WldddZVSU1N18OBBLVq0SHv37vV8xgYAACCsbzPNnz9f5eXlGjt2rOLi4jRz5kyVlZ18b65Nmzb69NNP9eKLL+rgwYNKSkrSRRddpA0bNqhfv37hbBsAALQgYQ0zsbGxysnJUU5OjmdZdna259+5ubnhaAsAAFikRf8FYAAAgIYQZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArBa2C002F5fLpYSEBL8uVBUot9utvLw8ZWVlNelaRA6H77Jg/lROrx+s2qfXra4OzTxC1W+o6n67dlOPkXD0HOq6jZ1JqGYRytr+1G3MPML58wtG3YZqNjQT258jGlO3tpnYdhwEKpDXb87MAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALBa23A3gNBfjTRU9X2vhhyausFiW91Q1ratrq21qRuaurb0Sd3mw5kZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALBawJczWLNmjXJzc1VSUiKHw6Fu3brphhtu0LBhw0LRHwAAQL0COjMzbdo0ZWZmasWKFTp06JC++uorLVu2TJdffrnuuuuuUPUIAABQJ7/DzMqVK7V48WK98MILOnjwoDZu3Kj3339fX331lZ599lk988wzev3110PZKwAAgA+/w8zixYs1Y8YMTZ48WQ6H45sCERG69dZbNX36dD3//PMhabI1G/S7cVbWp25o64aytm11ba1tW11b2DbX1ly3JR2LfoeZwsJCXXvttXWuv+6667Rp06agNAUAAOAvv8PMwYMHlZ6eXuf69PR0HTp0KChNAQAA+MvvMFNdXS2n01nn+rZt26q6ujooTQEAAPgroF/NfuCBBxQTE1PrusrKyqA0BAAAEAi/w8ywYcNUXFzc4DYAAADNye8ws27duhC2AQAA0DhczgAAAFjN7zMzM2bM8Gu7xx9/vNHNAAAABMrvMLN58+YGt/n2H9MDAABoDn6HmbVr14ayDwAAgEbhMzMAAMBqhBkAAGA1wgwAALCawxhjwt1EKLlcLiUkJKisrEzx8fFBre12u5WXl6esrKx6L/VwpmAevpiJL2bijXn4Yia+WuJMtm3bpv79+4esfiCv35yZAQAAVmtUmNmwYYN+8pOfKCMjQ3v37pUk5eTk6N133w1qcwAAAA0JOMy8+uqrGj16tKKjo7V582ZVVVVJksrKyvTII48EvUEAAID6BBxmHn74YT399NN69tlnvd63Gzp0qAoLC4PaHAAAQEMCDjPFxcW1Xh07ISFBR44cCUZPAAAAfgs4zKSkpOizzz7zWf7uu++qe/fuQWkKAADAXwGHmdtuu00///nPVVBQIIfDoX379mnZsmW65557dPvtt4eiRwAAgDr5fW2mU+677z7V1NRoxIgRqqys1LBhwxQVFaV77rlHd911Vyh6BAAAqFPAZ2YcDofuv/9+lZaWatu2bXr//ff11Vdf6Ve/+lVQGho+fLimT58elFoAAKD1CzjMlJWVqbS0VJGRkerbt68uvvhixcbGqrS0VC6XK2iNud1u3XvvvRowYIDat2+vtLQ0TZo0Sfv27QvafQAAAPsFHGbGjRunl19+2Wf5n//8Z40bNy4oTUlSZWWlCgsL9cADD6iwsFC5ubkqLi7WVVddFbT7AAAA9gs4zBQUFOjyyy/3WT58+HAVFB
QEVKuiokKTJk1SbGysUlNTtWDBAs+6hIQE5efn66abblKfPn30ve99T3/84x+1adMm7dmzJ9C2AQBAKxXwB4Crqqr09ddf+yx3u906duxYQLWys7O1fv16rVq1SsnJyZo9e7YKCws1ePDgWrcvKyuTw+FQhw4d6u3v1F8lluR568vtdsvtdgfUX0NO1Qt2XVsxD1/MxBcz8cY8fDETXy1xJjU1NSHtJ5DaAV81+/LLL1f//v21cOFCr+V33HGHioqKtGHDBr/qlJeXKykpSUuXLtWNN94oSSotLVV6erqmTp2qJ554wmv748ePa+jQoTr33HO1bNmyOuvOnTtX8+bN81m+fPlyxcTE+NUbAAAIr8rKSk2YMMGvq2YHfGbm4YcfVmZmprZu3aoRI0ZIklavXq0PPvhAb7/9tt91du7cqerqag0ZMsSzLDExUX369PHZ1u1266abbpIxRk899VS9dWfNmqUZM2Z4brtcLnXu3FmjRo1qcBiBcrvdys/P18iRI/27JPvCBN9ld5UFtSef+whWfT/qBjwPP+s22rdrh6luQDNp5bM4xa+ZNNcsglm7kXUDnkcrnsUpPjMJxf5b9njzzOTzW+U0x1p8v8EQyC8VBRxmhg4dqo0bN2r+/Pn685//rOjoaA0cOFDPP/+8evXqFWi5Bp0KMrt379aaNWsaDCRRUVGKioryWe50Ov1/gQ2Q37VNLW/DBbun0+8jWPUDqBvQrEPV7+m1w1zXr5mcIbP4ZvN6ZtJcswhm7SbW9XseZ8Asvvm2/5tJKPbfxsebJKc5djLMWNJvUwTymh1wmJGkwYMH1/tWjz969Oghp9OpgoICdenSRZJ0+PBhbd++XZdddpmkb4LMjh07tHbtWiUlJTXpPgEAQOvjV5hxuVyeMyINnfbx962c2NhYTZkyRdnZ2UpKSlJycrLuv/9+RUSc/AUrt9utG264QYWFhXrzzTd14sQJHThwQNLJt6MiIyP9uh8AANC6+RVmOnbsqP379ys5OVkdOnSQw+Hw2cYYI4fDoRMnTvh95/Pnz1d5ebnGjh2ruLg4zZw5U2VlJ9+v27t3r15//XVJ8vntprVr12r48OF+3w8AAGi9/Aoza9asUWJioufftYWZxoiNjVVOTo5ycnI8y7Kzsz3/DvAXrQAAwBnIrzBz6jMskjgjAgAAWpSA/wJwr169NHfuXO3YsSMU/QAAAAQk4DDzs5/9TP/7v/+rc889VxdddJGefPJJzwdzAQAAmlvAYeYXv/iFPvjgA33yySfKysrSokWLPH+U7qWXXgpFjwAAAHUKOMyc0rt3b82bN0/bt2/Xhg0b9NVXX+mWW24JZm8AAAANatQfzTvlX//6l5YvX65XXnlFLpfLc40lAACA5hJwmNm+fbuWLVumFStWaNeuXfrBD36g3/72t7ruuusUGxsbih4BAADqFPBVsyMiInTRRRdpwoQJGjdunDp16hSq3oLC5XIpISHBr6tuBsrtdisvL09ZWVl+XUOirj/PE8w/p1PbfQSjvj91A52Hv3Ub6/Ta4agbyExa+yxO8WcmzTmLYNVubN3GzKM1HhffdvpMQtGnbY+3UzMZPz5Lx445g1ZXCt1x0FSBvH4HfGamuLg4JBeUBAAAaIxG/Z2ZI0eO6LnnntOsWbNUWloqSSosLNTevXuD3iAAAEB9Aj4zU1RUpBEjRqhDhw4qKSnRbbfdpsTEROXm5mrPnj38ejYAAGhWjfo7M7fccot27Nihdu3aeZZnZWXpnXfeCWpzAAAADQn4zMyHH36oZ555xmf5d77zHf4SMAAAaHYBn5mJioqSy+XyWb59+3adffbZQWkKAADAXwGHmauuukoPPfSQ3G63JMnhcGjPnj269957df311we9QQAAgPoEHGYWLFig8vJyJScn69ixY7rsssvUs2dPxcXF6de//nUoegQAAKhTwJ+ZSUhIUH5+vt59910VFRWpvLxcF1xwgTIzM0PRHwAAQL0afW2mSy+9VJdeemkwewEAAAhYQGGmpqZGS5YsUW5urkpKSuRwONStWzfdcMMNmjhxohx1/b1+AACAEPH7MzPGGF111VX66U9/qr1792rAgAHq16+fdu/ercmTJ+vaa68NZZ8AAAC18vvMzJIlS/TOO+9o9erVuvzyy73WrVmzRtdcc41eeuklTZo0KehNAgAA1MXvMzMrVqzQ7NmzfYKMJP3gBz/Qfffdp2XLlgW1udbGmNq/Qn0fZ2Ld2mrbVtfGnkNV14ba/PxCV5ef1zfKyuzqt7n4HWaKioo0ZsyYOtdfccUV2rp1a1CaAgAA8JffYaa0tFSdOnWqc32nTp10+PDhoDQFAADgL7/DzIkTJ9S2bd0fsWnTpo2+/vrroDQFAADgL78/AGyM0eTJkxUVFVXr+qqqqqA1BQAA4C+/w8zNN9/c4Db8JhMAAGhufoeZxYsXh7IPAACARgn4QpMAAAAtCWEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhJlWZNDvxlE3xLVD2XOo2DYLG4+LULJtzsGuG4o+bTzGhi68JSR1bXxM1IYwAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqDmOMCXcToeRyuZSQkKCysjLFx8cHtbbb7VZeXp6ysrLkdDqDWttGzMMXM/HFTLwxD1/MxNeZOJNAXr85MwMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWK3FhZnhw4dr+vTp4W4DAABYosWFmW/Lzc3VqFGjlJSUJIfDoS1btoS7JQAA0MK06DBTUVGhSy+9VL/97W/D3QoAAGih2obzzisqKnT77bcrNzdXcXFxuueee7zWT5w4UZJUUlIShu4AAIANwhpmsrOztX79eq1atUrJycmaPXu2CgsLNXjw4EbXrKqqUlVVlee2y+WSdPIiXW63u6ktezlVL9h1bcU8fDETX8zEG/PwxUx8nYkzCWRfw3bV7PLyciUlJWnp0qW68cYbJUmlpaVKT0/X1KlT9cQTT3i2LSkpUbdu3bR58+YGg87cuXM1b948n+XLly9XTExMMHcBAACESGVlpSZMmODXVbPDdmZm586dqq6u1pAhQzzLEhMT1adPnybVnTVrlmbMmOG57XK51LlzZ40aNarBYQTK7XYrPz9fI0eO9L0k+8IE32+4qyyo9+9zP8Gs34i69c6jCXX9cvq8g1W7iXUDOkZaSM+hrlvrTEJ1XISydqjmEaqfn9Sini/q4/5jivK7v6CRn98q550HmlxPUos/Dhqq6552sOHn18bUDsXrUpCcemfFH2F9mykUoqKiFBUV5bPc6XQ2/QCoQ621zbHaNgz+nX/7foJZvwl16511c/QbzNpBquvXMdLCeg51Xa+ZhOq4CGXtUM0jVD8/qUU+X9RXz2mOBe9525LjoKG6QXktC+XjLYgC2c+w/TZTjx495HQ6VVBQ4Fl2+PBhbd++P
VwtAQAAC4XtzExsbKymTJmi7OxsJSUlKTk5Wffff78iIr7JV6WlpdqzZ4/27dsnSSouLpYkpaSkKCUlJSx9AwCAliWsbzPNnz9f5eXlGjt2rOLi4jRz5kyVlX3z/t3rr7+uW265xXN73LhxkqQ5c+Zo7ty5zd0uAABogcIaZmJjY5WTk6OcnBzPsuzsbM+/J0+erMmTJ4ehMwAAYIsW/ReAAQAAGkKYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwWtiumt1cXC6XEhIS/LrqZqDcbrfy8vKUlZXlcw0Jh8N3+1BM+vT7CdZ9NKZuffNoSl1/hGreTa0byDHS2mdxSm0zCdUsQlk7WHVPn0conzta+ixOiYlxa8WKPI0fn6Vjx5whOX5te7xFR5+cSX3Pr42t3VJTQCCv35yZAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wkyIGOP71Rz3Q93g1g7lz5FZ1F07mGzr2cY5B7tuWdk3/w3V8RsszfV4OzWTUNRuDQgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAmBQb8bZ/X9UDe0dUNZu7mOvWBizs3DllmEok9b9t0mLW3fCTMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArOYwxphwNxFKLpdLCQkJKisrU3x8fFBru91u5eXlKSsrS06nM6i1bcQ8fDETX8zEG/PwxUx8tbSZbNu2Tf379w/pfQTy+s2ZGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAai0uzAwfPlzTp08PdxsAAMASLS7MfJsxRg8++KBSU1MVHR2tzMxM7dixI9xtAQCAFqRFh5nHHntMf/jDH/T000+roKBA7du31+jRo3X8+PFwtwYAAFqIsIaZiooKTZo0SbGxsUpNTdWCBQs864wxeuKJJ/TLX/5SV199tQYOHKiXXnpJ+/bt02uvvRa+pgEAQIvSNpx3np2drfXr12vVqlVKTk7W7NmzVVhYqMGDB2vXrl06cOCAMjMzPdsnJCRoyJAh2rhxo8aNG1drzaqqKlVVVXluu1wuSScv0uV2u4Pa/6l6wa5rK+bhi5n4YibemIcvZuKrpc2kpqYm5L0EUj9sV80uLy9XUlKSli5dqhtvvFGSVFpaqvT0dE2dOlU33XSThg4dqn379ik1NdXzfTfddJMcDodeeeWVWuvOnTtX8+bN81m+fPlyxcTEhGZnAABAUFVWVmrChAl+XTU7bGdmdu7cqerqag0ZMsSzLDExUX369GlS3VmzZmnGjBme2y6XS507d9aoUaMaHEag3G638vPzNfLzW+U0x75ZcVdZUO/Hy8KE4N/Pt2s2oa5nHiNHnrxEfZDq1ioUcwhB3WY5RppjFkGs7f5jivK7v/DNTCzoOZTHcsjmYdMsTjuGfZ5LmlLv/2oGTZgeb42eSShnEWKn3lnxR1jfZqpPSkqKJOmLL77wOjPzxRdfaPDgwXV+X1RUlKKionyWO53Oxj0o/OA0x7xfqEJ0P5KkUNzPt2sGoa5n1kGu6yVU8w5R3ZAeI81RN5i1/6+uZyYW9Rz0ut+qHfR52DSLOo7hRj9v2/7cU0/tgGcSylmEWCD7GbYPAPfo0UNOp1MFBQWeZYcPH9b27dslSd26dVNKSopWr17tWe9yuVRQUKCMjIxm7xcAALRMYTszExsbqylTpig7O1tJSUlKTk7W/fffr4iIk/nK4XBo+vTpevjhh9WrVy9169ZNDzzwgNLS0nTNNdeEq20AANDChPVtpvnz56u8vFxjx45VXFycZs6cqbKyb97P+3//7/+poqJCU6dO1ZEjR3TppZfqb3/7m9q1axfGrgEAQEsS1jATGxurnJwc5eTkeJZlZ2d7/u1wOPTQQw/poYceCkd7AADAAi36LwADAAA0hDADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGC1sF01u7m4XC4lJCT4ddXNQLndbuXl5Wn8+CwdO+Z9DYlQTdXh8L4djPs5vWZj656aR1ZWlpxOZ9Dq1iYUcwhF3eY4RmyZxSkxMW6tWOE9k1D1HKzaoTyWT5/HmTiL02tWV3s/lzS1ntTyHxcN1T39+TVYdVuyQF6/OTMDAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEmSAoKzt5JdJvf4VKKO7n9JotvW5ttVt63VAeI7bVLSv75r+h7tmGY/n0eQSLTbMIdT0bHhe21W1pCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAq7UNdwOhZoyRJLlcrqDXdrvdqqyslMvlktPpDHp92zAPX8zEFzPxxjx8MRNfZ+JMTr1un3odr0+rDzNHjx6VJHXu3DnMnQAAgEAdPXpUCQkJ9W7jMP5EHovV1NRo3759iouLk8PhCGptl8ulzp076z//+Y/i4+ODWttGzMMXM/HFTLwxD1/MxNeZOBNjjI4ePaq0tDRFRNT/qZhWf2YmIiJC6enpIb2P+Pj4M+bg8gfz8MVMfDETb8zDFzPxdabNpKEzMqfwAWAAAGA1wgwAALAaYaYJoqKiNGfOHEVFRYW7lRaBefhiJr6YiTfm4YuZ+GIm9Wv1HwAGAACtG2dmAACA1QgzAADAaoQZAABgNcIMAACwGmGmkRYtWqRzzjlH7dq105AhQ/Svf/0r3C01m7lz58rhcHh9nXvuuZ71x48f1x133KGkpCTFxsbq+uuv1xdffBHGjoPrnXfe0dixY5WWliaHw6HXXnvNa70xRg8++KBSU1MVHR2tzMxM7dixw2ub0tJS/fjHP1Z8fLw6dOigKVOmqLy8vBn3IrgamsnkyZN9jpkxY8Z4bdOaZvLoo4/qoosuUlxcnJKTk3XNNdeouLjYaxt/Hid79uzRlVdeqZiYGCUnJys7O1tff/11c+5K0Pgzk+HDh/scJ9OmTfPapjXN5KmnntLAgQM9fwgvIyNDb731lmf9mXaMNAVhphFeeeUVzZgxQ3PmzFFhYaEGDRqk0aNH68svvwx3a82mX79+2r9/v+fr3Xff9az7xS9+oTfeeEN/+ctftH79eu3bt0/XXXddGLsNroqKCg0aNEiLFi2qdf1jjz2mP/zhD3r66adVUFCg9u3ba/To0Tp+/Lhnmx//+Mf697//rfz8fL355pt65513NHXq1ObahaBraCaSNGbMGK9j
ZsWKFV7rW9NM1q9frzvuuEPvv/++8vPz5Xa7NWrUKFVUVHi2aehxcuLECV155ZWqrq7WP//5T7344otasmSJHnzwwXDsUpP5MxNJuu2227yOk8cee8yzrrXNJD09Xb/5zW+0adMmffjhh/rBD36gq6++Wv/+978lnXnHSJMYBOziiy82d9xxh+f2iRMnTFpamnn00UfD2FXzmTNnjhk0aFCt644cOWKcTqf5y1/+4ln2ySefGElm48aNzdRh85FkVq5c6bldU1NjUlJSzPz58z3Ljhw5YqKiosyKFSuMMcZ8/PHHRpL54IMPPNu89dZbxuFwmL179zZb76Fy+kyMMebmm282V199dZ3f09pn8uWXXxpJZv369cYY/x4neXl5JiIiwhw4cMCzzVNPPWXi4+NNVVVV8+5ACJw+E2OMueyyy8zPf/7zOr+ntc/EGGM6duxonnvuOY6RAHFmJkDV1dXatGmTMjMzPcsiIiKUmZmpjRs3hrGz5rVjxw6lpaWpe/fu+vGPf6w9e/ZIkjZt2iS32+01n3PPPVddunQ5I+aza9cuHThwwGv/ExISNGTIEM/+b9y4UR06dNCFF17o2SYzM1MREREqKCho9p6by7p165ScnKw+ffro9ttv16FDhzzrWvtMysrKJEmJiYmS/HucbNy4UQMGDFCnTp0824wePVoul8vzf+42O30mpyxbtkxnnXWW+vfvr1mzZqmystKzrjXP5MSJE3r55ZdVUVGhjIwMjpEAtfoLTQbbwYMHdeLECa+DR5I6deqkTz/9NExdNa8hQ4ZoyZIl6tOnj/bv36958+bp+9//vrZt26YDBw4oMjJSHTp08PqeTp066cCBA+FpuBmd2sfajo9T6w4cOKDk5GSv9W3btlViYmKrndGYMWN03XXXqVu3btq5c6dmz56tK664Qhs3blSbNm1a9Uxqamo0ffp0DR06VP3795ckvx4nBw4cqPU4OrXOZrXNRJImTJigrl27Ki0tTUVFRbr33ntVXFys3NxcSa1zJh999JEyMjJ0/PhxxcbGauXKlerbt6+2bNlyRh8jgSLMIGBXXHGF598DBw7UkCFD1LVrV/35z39WdHR0GDtDSzVu3DjPvwcMGKCBAweqR48eWrdunUaMGBHGzkLvjjvu0LZt27w+V3amq2sm3/6M1IABA5SamqoRI0Zo586d6tGjR3O32Sz69OmjLVu2qKysTH/961918803a/369eFuyzq8zRSgs846S23atPH5RPkXX3yhlJSUMHUVXh06dFDv3r312WefKSUlRdXV1Tpy5IjXNmfKfE7tY33HR0pKis+Hxb/++muVlpaeETOSpO7du+uss87SZ599Jqn1zuTOO+/Um2++qbVr1yo9Pd2z3J/HSUpKSq3H0al1tqprJrUZMmSIJHkdJ61tJpGRkerZs6e++93v6tFHH9WgQYP05JNPntHHSGMQZgIUGRmp7373u1q9erVnWU1NjVavXq2MjIwwdhY+5eXl2rlzp1JTU/Xd735XTqfTaz7FxcXas2fPGTGfbt26KSUlxWv/XS6XCgoKPPufkZGhI0eOaNOmTZ5t1qxZo5qaGs+Td2v33//+V4cOHVJqaqqk1jcTY4zuvPNOrVy5UmvWrFG3bt281vvzOMnIyNBHH33kFfLy8/MVHx+vvn37Ns+OBFFDM6nNli1bJMnrOGlNM6lNTU2NqqqqzshjpEnC/QlkG7388ssmKirKLFmyxHz88cdm6tSppkOHDl6fKG/NZs6cadatW2d27dpl3nvvPZOZmWnOOuss8+WXXxpjjJk2bZrp0qWLWbNmjfnwww9NRkaGycjICHPXwXP06FGzefNms3nzZiPJPP7442bz5s1m9+7dxhhjfvOb35gOHTqYVatWmaKiInP11Vebbt26mWPHjnlqjBkzxpx//vmmoKDAvPvuu6ZXr15m/Pjx4dqlJqtvJkePHjX33HOP2bhxo9m1a5f5xz/+YS644ALTq1cvc/z4cU+N1jST22+/3SQkJJh169aZ/fv3e74qKys92zT0OPn6669N//79zahRo8yWLVvM3/72N3P22WebWbNmhWOXmqyhmXz22WfmoYceMh9++KHZtWuXWbVqlenevbsZNmyYp0Zrm8l9991n1q9fb3bt2mWKiorMfffdZxwOh3n77beNMWfeMdIUhJlGWrhwoenSpYuJjIw0F198sXn//ffD3VKz+dGPfmRSU1NNZGSk+c53vmN+9KMfmc8++8yz/tixY+ZnP/uZ6dixo4mJiTHXXnut2b9/fxg7Dq61a9caST5fN998szHm5K9nP/DAA6ZTp04mKirKjBgxwhQXF3vVOHTokBk/fryJjY018fHx5pZbbjFHjx4Nw94ER30zqaysNKNGjTJnn322cTqdpmvXrua2227zCf+taSa1zUKSWbx4sWcbfx4nJSUl5oorrjDR0dHmrLPOMjNnzjRut7uZ9yY4GprJnj17zLBhw0xiYqKJiooyPXv2NNnZ2aasrMyrTmuaya233mq6du1qIiMjzdlnn21GjBjhCTLGnHnHSFM4jDGm+c4DAQAABBefmQEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAdBs1q1bJ4fD4XPxvGAYPny4pk+fHvS6AFq+tuFuAACCITc3V06nM9xt1Gnu3Ll67bXXPBdPbKrJkyfryJEjeu2114JSD7AZYQZAq5CYmBjuFprFiRMn5HA4wt0G0KLwNhNgqb/+9a8aMGCAoqOjlZSUpMzMTFVUVHjWP/fcczrvvPPUrl07nXvuufrTn/7k9f3//e9/NX78eCUmJqp9+/a68MILVVBQ4Fn/1FNPqUePHoqMjFSfPn2Uk5Pj9f0Oh0PPPfecrr32WsXExKhXr156/fXXvbbJy8tT7969FR0drcsvv1wlJSVe63fv3q2xY8eqY8eOat++vfr166e8vLw69/lPf/qTevXqpXbt2qlTp0664YYbPOtOf5vpnHPO0SOPPKJbb71VcXFx6tKli5555pmAZrBq1SpdcMEFateunbp376558+bp66+/rrO/devW6eKLL1b79u3VoUMHDR06VLt379aSJUs0b948bd26VQ6HQw6HQ0uWLJEkPf744xowYIDat2+vzp0762c/+5nKy8s9NZcsWaIOHTro9ddfV9++fRUVFaVbb71VL774olatWuWpt27dujr7Alq9cF/pEkDg9u3bZ9q2bWsef/xxs2vXLlNUVGQWLVrkucr00qVLTWpqqnn11VfN559/bl599VWTmJholixZYowx5ujRo6Z79+7m+9//vtmwYYPZsWOHeeWVV8w///lPY4wxubm5xul0mkWLFpni4mKzYMEC06ZNG7NmzRpPD5JMenq6Wb58udmxY4e5++67TWxsrDl06JAx5uRVkKOiosyMGTPMp59+apYuXWo6depkJJnDhw8bY4y58sorzciRI01RUZHZuXOneeONN8z69etr3ecPPvjAtGnTxixfvtyUlJSYwsJC8+STT3rWX3bZZebnP/+
553bXrl1NYmKiWbRokdmxY4d59NFHTUREhPn000/9msE777xj4uPjzZIlS8zOnTvN22+/bc455xwzd+7cWvtzu90mISHB3HPPPeazzz4zH3/8sVmyZInZvXu3qaysNDNnzjT9+vUz+/fvN/v37zeVlZXGGGN+//vfmzVr1phdu3aZ1atXmz59+pjbb7/dU3fx4sXG6XSaSy65xLz33nvm008/NWVlZeamm24yY8aM8dSrqqry7+ABWiHCDGChTZs2GUmmpKSk1vU9evQwy5cv91r2q1/9ymRkZBhjjPmf//kfExcX5wkep7vkkkvMbbfd5rXsxhtvNFlZWZ7bkswvf/lLz+3y8nIjybz11lvGGGNmzZpl+vbt61Xj3nvv9QozAwYMqDMcnO7VV1818fHxxuVy1bq+tjDzk5/8xHO7pqbGJCcnm6eeesoY0/AMRowYYR555BGvZTk5OSY1NbXW7Q8dOmQkmXXr1tW6fs6cOWbQoEF17Z7HX/7yF5OUlOS5vXjxYiPJbNmyxWu7m2++2Vx99dUN1gPOBLzNBFho0KBBGjFihAYMGKAbb7xRzz77rA4fPixJqqio0M6dOzVlyhTFxsZ6vh5++GHt3LlTkrRlyxadf/75dX7O5JNPPtHQoUO9lg0dOlSffPKJ17KBAwd6/t2+fXvFx8fryy+/9NQYMmSI1/YZGRlet++++249/PDDGjp0qObMmaOioqI693nkyJHq2rWrunfvrokTJ2rZsmWqrKysb0xe/TkcDqWkpHj6a2gGW7du1UMPPeQ1w9tuu0379++v9X4TExM1efJkjR49WmPHjtWTTz6p/fv319ufJP3jH//QiBEj9J3vfEdxcXGaOHGiDh065HUfkZGRXvsCwBthBrBQmzZtlJ+fr7feekt9+/bVwoUL1adPH+3atcvzeYtnn31WW7Zs8Xxt27ZN77//viQpOjo6KH2c/ttDDodDNTU1fn//T3/6U33++eeaOHGiPvroI1144YVauHBhrdvGxcWpsLBQK1asUGpqqh588EENGjSo3l/zrq+/hmZQXl6uefPmec3wo48+0o4dO9SuXbtav2fx4sXauHGjLrnkEr3yyivq3bu3Z+a1KSkp0Q9/+EMNHDhQr776qjZt2qRFixZJkqqrqz3bRUdH86FfoB6EGcBSDodDQ4cO1bx587R582ZFRkZq5cqV6tSpk9LS0vT555+rZ8+eXl/dunWTdPKMxZYtW1RaWlpr7fPOO0/vvfee17L33ntPffv29bu/8847T//617+8ltX2wt65c2dNmzZNubm5mjlzpp599tk6a7Zt21aZmZl67LHHVFRUpJKSEq1Zs8bvnr6toRlccMEFKi4u9plhz549FRFR91Pn+eefr1mzZumf//yn+vfvr+XLl0s6eXblxIkTXttu2rRJNTU1WrBggb73ve+pd+/e2rdvn1/911YPOFPxq9mAhQoKCrR69WqNGjVKycnJKigo0FdffaXzzjtPkjRv3jzdfffdSkhI0JgxY1RVVaUPP/xQhw8f1owZMzR+/Hg98sgjuuaaa/Too48qNTVVmzdvVlpamjIyMpSdna2bbrpJ559/vjIzM/XGG28oNzdX//jHP/zucdq0aVqwYIGys7P105/+VJs2bfL8Bs8p06dP1xVXXKHevXvr8OHDWrt2rWcfTvfmm2/q888/17Bhw9SxY0fl5eWppqZGffr0adQMG5rBgw8+qB/+8Ifq0qWLbrjhBkVERGjr1q3atm2bHn74YZ96u3bt0jPPPKOrrrpKaWlpKi4u1o4dOzRp0iRJJ3+7ateuXdqyZYvS09MVFxennj17yu12a+HChRo7dqzee+89Pf300371f8455+jvf/+7iouLlZSUpISEhBb9d3aAkAr3h3YABO7jjz82o0ePNmeffbaJiooyvXv3NgsXLvTaZtmyZWbw4MEmMjLSdOzY0QwbNszk5uZ61peUlJjrr7/exMfHm5iYGHPhhReagoICz/o//elPpnv37sbpdJrevXubl156yau+JLNy5UqvZQkJCWbx4sWe22+88Ybp2bOniYqKMt///vfNCy+84PUB4DvvvNP06NHDREVFmbPPPttMnDjRHDx4sNZ93rBhg7nssstMx44dTXR0tBk4cKB55ZVXPOtr+wDw73//e68agwYNMnPmzPF7Bn/729/MJZdcYqKjo018fLy5+OKLzTPPPFNrfwcOHDDXXHONSU1NNZGRkaZr167mwQcfNCdOnDDGGHP8+HFz/fXXmw4dOhhJnjk9/vjjJjU11URHR5vRo0ebl156yWtGixcvNgkJCT739+WXX5qRI0ea2NhYI8msXbu21r6AM4HDGGPCmqYAAACagM/MAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBq/x9jqlwDoKmSeAAAAABJRU5ErkJggg==", + "text/plain": [ + "<Figure size 640x480 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "get_device_activity_plot([fed_device_train_times, fed_client_train_times, fed_device_eval_times, fed_client_eval_times],\n", + " num_devices=5).show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6fa00a73-34f0-4d6f-b9e8-8bd655019772", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "513.5029487609863\n" + ] + } + ], + "source": [ + "swarm_client_train = pd.read_csv(\"./csv/swarm_client_train_times.csv\")\n", + "swarm_device_train = pd.read_csv(\"./csv/swarm_device_train_times.csv\")\n", + "swarm_client_eval = pd.read_csv(\"./csv/swarm_client_eval_times.csv\")\n", + "swarm_device_eval = pd.read_csv(\"./csv/swarm_device_eval_times.csv\")\n", + "\n", + "swarm_offset = get_earliest_start_time([swarm_client_train, swarm_device_train, swarm_client_eval, swarm_device_eval])\n", + "print(get_runtime([swarm_client_train, swarm_device_train, 
swarm_client_eval, swarm_device_eval]))\n",
+    "\n",
+    "swarm_client_train_times = get_time_dict(dataframe=swarm_client_train, offset=swarm_offset, format_string_fn=lambda\n",
+    "    device_idx: f\"Group: d{device_idx} - client_train_epoch_time\")\n",
+    "swarm_device_train_times = get_time_dict(dataframe=swarm_device_train, offset=swarm_offset, format_string_fn=lambda\n",
+    "    device_idx: f\"Group: d{device_idx} - train_global_time\")\n",
+    "swarm_client_eval_times = get_time_dict(dataframe=swarm_client_eval, offset=swarm_offset, format_string_fn=lambda\n",
+    "    device_idx: f\"Group: d{device_idx} - client_evaluate_time\")\n",
+    "swarm_device_eval_times = get_time_dict(dataframe=swarm_device_eval, offset=swarm_offset, format_string_fn=lambda\n",
+    "    device_idx: f\"Group: d{device_idx} - evaluate_global_time\", device_indices=[4])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "3b1ba2a4-aadc-42d9-b0a1-b016e5046c98",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "<base64-encoded PNG omitted: matplotlib device-activity plot, Figure 640x480 with 1 Axes>",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "get_device_activity_plot(\n",
+    "    [swarm_device_train_times, swarm_client_train_times, swarm_device_eval_times, swarm_client_eval_times],\n",
+    "    num_devices=5).show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7ce5f2dc-414a-407e-9ed4-0e99c65fbb34",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "39029960-0f3c-4a1d-9841-e49f1cb28ca8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aee2027a-45a5-46ad-a86c-f54a63139601",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "05a72678-0cfb-437b-8834-e4a478a95d03",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/results/plotting_helpers.py b/results/plotting_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..519b164d66adc214ebf5198db3cfb7e6627fe55b
--- /dev/null
+++ b/results/plotting_helpers.py
@@ -0,0 +1,111 @@
+import math
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+A set of helpers to process and plot the method duration data logged to wandb.
+"""
+
+
+def get_earliest_start_time(dataframes):
+    """Given dataframes that contain a 'Wall Time' column (the logging timestamp) and several method-duration columns,
+    return the earliest start time of the first method execution across all dataframes.
+    The start of a run is estimated as the first logged 'Wall Time' minus the longest duration found in that first row.
+ """ + earliest_start = math.inf + for df in dataframes: + first_end = df["Wall Time"][0] + time_columns = [ + col + for col in df.columns + if not "_MAX" in col + and not "_MIN" in col + and not "_step" in col + and not "Wall" in col + ] + max_runtime = np.nanmax(df[time_columns].loc[0].values) + if earliest_start > first_end - max_runtime: + earliest_start = first_end - max_runtime + return earliest_start + + +def get_last_end_time(dataframes): + """Returns the largest Wall Time of the given dataframes.""" + last_end = 0 + for df in dataframes: + last_wall_time = df["Wall Time"][df.shape[0] - 1] + if last_end < last_wall_time: + last_end = last_wall_time + return last_end + + +def get_runtime(dataframes): + return get_last_end_time(dataframes) - get_earliest_start_time(dataframes) + + +def get_time_dict( + dataframe, + offset, + format_string_fn=lambda device_idx: f"Group: d{device_idx} - train_global_time", + device_indices=range(5), +): + """Returns a dictionary with device_idx as keys and a list of tuples for each device + where each tuple consists of the method start time and duration. + Example: + {0: [(0.0002541542053222656, 4.316676139831543), + (34.571627140045166, 2.5839357376098637), + (60.37545871734619, 2.6242687702178955)], + 1: [(4.9257495403289795, 6.421962022781372), + (37.768064975738525, 4.779199123382568), + (63.61771893501282, 4.6916210651397705)]""" + time_dict = {} + for i in device_indices: + time_dict[i] = [] + num_devices = len(device_indices) + for idx, row in dataframe.iterrows(): + device_idx = device_indices[ + idx % num_devices + ] # in case there is only one device with index > 0 + time_dict[device_idx].append( + ( + (row["Wall Time"] - offset) - row[format_string_fn(device_idx)], + row[format_string_fn(device_idx)], + ) + ) + return time_dict + + +def get_device_activity_plot( + data, bar_height=2, num_devices=5, device_space=10, y_offset=5, colors=None +): + """Plots the activity of the devices over time. 
The data is a list of dictionaries, where each dictionary + corresponds to a method""" + if colors is None: + colors = ["blue", "darkorange", "lightgrey", "seagreen"] + fig, ax = plt.subplots() + num_methods = len(data) + + for method_idx, device_dict in enumerate(data): + for device_idx in device_dict.keys(): + ax.broken_barh( + device_dict[device_idx], # (xmin, xwidth) + ( + device_space * device_idx + + bar_height * (method_idx - num_methods / 2) + + y_offset, + bar_height, + ), # (ymin, yheight) + facecolors=f"{colors[method_idx % len(colors)]}", + ) + + # ax.set_xlim(left=350) + # ax.set_ylim(bottom=0) + ax.set_xlabel("seconds since start") + ax.set_ylabel("Device ID") + ax.set_yticks( + [(x * device_space + y_offset) for x in range(num_devices)], + labels=["d0", "d1", "d2", "d3", "d4"][:num_devices], + ) # Modify y-axis tick labels + ax.grid(True) # Make grid lines visible + + return plt diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..85b37ddb464b85b1e634dcbcde25c0a221baa9fb --- /dev/null +++ b/results/result_generation.ipynb @@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "26fa58dbda0ddc40", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from result_generation import save_dataframes, load_dataframes, generate_metric_files, generate_plots" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Downloading the dataframes (takes time)" + ], + "metadata": { + "collapsed": false + }, + "id": "3e72d0fe76a55e56" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "df_base_dir = \"./dataframes\"\n", + "projects_with_model = [\n", + " (\"greedy_vs_smart_ecg-non-iid_RESULT\", \"tcn\"),\n", + " (\"greedy_vs_smart_cifar100_resnet20_RESULT\", \"resnet20\"),\n", + " (\"greedy_vs_smart_ecg-iid_RESULT\", \"tcn\"),\n", + " (\"greedy_vs_smart_PTBXL_equal_devices_RESULT\", \"tcn\"),\n", + " (\"greedy_vs_smart_PTBXL_unequal_processors_RESULT\", \"tcn\"),\n", + " (\"greedy_vs_smart_MNIST_unequal_processors_RESULT\", \"simple_conv\"),\n", + " (\"greedy_vs_smart_MNIST_unequal_batteries_unequal_partition_RESULT\", \"simple_conv\"),\n", + " (\"greedy_vs_smart_MNIST_equal_devices_RESULT\", \"simple_conv\"),\n", + " (\"greedy_vs_smart_MNIST_unequal_batteries_RESULT\", \"simple_conv\"),\n", + " (\"fed_vs_split_MNIST_limited_batteries_RESULT\", \"simple_conv\"),\n", + " (\"fed_vs_split_MNIST_unlimited_batteries_RESULT\", \"simple_conv\"),\n", + " (\"fed_vs_split_PTBXL_limited_batteries_RESULT\", \"tcn\"),\n", + " (\"fed_vs_split_PTBXL_unlimited_batteries_RESULT\", \"tcn\"),\n", + " (\"fed_vs_split_cifar100_unlimited_batteries_RESULT\", \"resnet20\"),\n", + " (\"fed_vs_split_CIFAR100_limited_batteries_RESULT\", \"resnet20\"),\n", + " (\"fed_vs_split_50_devices_RESULT\", \"resnet110\"),\n", + " (\"greedy_vs_smart_CIFAR100_equal_devices_RESULT\", \"resnet20\"),\n", + " (\"greedy_vs_smart_CIFAR100_unequal_processors_RESULT\", \"resnet20\"),\n", + "]" + ], + "metadata": { + "collapsed": false + }, + "id": "6695251b9af7ea4b" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "for project_name, _ in projects_with_model:\n", + " save_dataframes(project_name=project_name, strategies=[\n", + " \"swarm_seq\",\n", + " \"fed\",\n", + " \"swarm_max\",\n", + " \"swarm_rand\",\n", + " \"swarm_smart\",\n", + " \"split\"\n", + " ])" + ], + "metadata": { + "collapsed": 
false + }, + "id": "b07913828b33ffcc" + }, + { + "cell_type": "markdown", + "source": [ + "# Generating Results from saved dataframes" + ], + "metadata": { + "collapsed": false + }, + "id": "d8269abd823cdcc7" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Required for total number of FLOPs computation\n", + "model_flops = {\n", + " \"resnet20\": 41498880,\n", + " \"resnet110\": 258136320,\n", + " \"tcn\": 27240000,\n", + " \"simple_conv\": 16621560\n", + "}" + ], + "metadata": { + "collapsed": false + }, + "id": "828bcb4737b21c6d" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "plots_base_path=\"./plots\"\n", + "metrics_base_path=\"./metrics\"" + ], + "metadata": { + "collapsed": false + }, + "id": "b13f9e0e98b7ac5b" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "def generate_result_files(projects_with_model):\n", + " for proj_name, model_name in projects_with_model:\n", + " print(proj_name)\n", + " print(\" loading data from disk\")\n", + " dataframes = load_dataframes(proj_name, df_base_dir)\n", + " print(\" generating metrics\")\n", + " generate_metric_files(dataframes, proj_name, model_flops[model_name])\n", + " print(\" generating plots\")\n", + " generate_plots(dataframes, proj_name)" + ], + "metadata": { + "collapsed": false + }, + "id": "c4b1ed2d809c54e2" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "generate_result_files(projects_with_model)" + ], + "metadata": { + "collapsed": false + }, + "id": "378bf3365dd9fde2" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/results/result_generation.py b/results/result_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..b95183ebae7218c3c6ae7453b6bc103d69ba163e --- /dev/null +++ b/results/result_generation.py @@ -0,0 +1,1109 @@ +import os +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import wandb + +# For plotting +STRATEGY_MAPPING = { + "split": "Vanilla SL", + "swarm_smart": "Swarm SL (Smart)", + "swarm_seq": "Swarm SL (Seq)", + "swarm_rand": "Swarm SL (Rand)", + "swarm_max": "Swarm SL (Greedy)", + "fed": "Vanilla FL", +} + +LABEL_MAPPING = { + "num_devices": "Number of Devices", + "runtime": "Runtime [s]", + "round": "Round", + "total network battery": "Total System Battery", + "device battery": "Device Battery", + "accuracy": "Accuracy", + "train accuracy": "Train Accuracy", + "val accuracy": "Validation Accuracy", + "device": "Device", +} + + +def save_dataframes(project_name, strategies, base_dir="./dataframes"): + """ + Fetches the dataframes from wandb and saves them to the base_dir. + Args: + project_name: (str) the name of the project in wandb + strategies: (list) a list of strategies to fetch the dataframes for + base_dir: (str) the base directory to save the dataframes to + Notes: + - The wandb project is expected to have the following structure: + project_name + ├── split + │ ├── job: train + │ │ ├── d0 + │ │ ├── d1 + │ │ ├── ... 
+ │ ├── job: test + │ │ ├── d0 + │ │ ├── controller + ├── swarm_smart + │ ├── ... + If an experiment was run multiple times, there may be several runs for each device_id. + + - The dataframes are saved in the following structure: + base_dir + ├── project_name + ├── strategy e.g. split + ├── job e.g. train + ├── device_id e.g. d0 + ├── df_0.csv + ├── df_1.csv + ├── ... + """ + print(project_name) + wandb.login() + api = wandb.Api(timeout=120) + run_groups = {} + runs = api.runs(project_name) + + for strategy in strategies: + for job in ["train", "test"]: + filtered_runs = [ + run for run in runs if run.group == strategy and run.job_type == job + ] + runs_by_name = defaultdict(list) + for run in filtered_runs: + runs_by_name[run.name].append(run) + run_groups[(strategy, job)] = runs_by_name + # then fetch dataframes for each run + print("downloading data") + history_groups = {} + for (strategy, job), group in run_groups.items(): + print(f" {strategy} {job}") + history = defaultdict(list) + for name, runs in group.items(): + print(f" {name}") + for run in runs: + history_df = pd.DataFrame(run.scan_history()) + history[name].append(history_df) + history_groups[(strategy, job)] = history + # save dataframe + print("saving data") + for (strategy, job), group in history_groups.items(): + print(f" {strategy} {job}") + for device_id, runs in group.items(): + print(f" {device_id}") + for idx, run_df in enumerate(runs): + dir_path = os.path.join( + base_dir, project_name, strategy, job, device_id + ) + os.makedirs(dir_path, exist_ok=True) + file_path = os.path.join(dir_path, f"df_{idx}.csv") + run_df.to_csv(file_path) + + +def load_dataframes(project_name, base_dir="./dataframes"): + """ + Loades saved dataframes from the given project. + Args: + project_name: (str) the name of the project folder + base_dir: (str) the base directory to fetch the dataframes from + Notes: + - The dataframes are assumed to be stored in the following structure: + base_dir + ├── project_name + ├── strategy e.g. split + ├── job e.g. train + ├── device_id e.g. d0 + ├── df_0.csv + ├── df_1.csv + ├── ... + """ + history_groups = {} + project_path = os.path.join(base_dir, project_name) + for root, dirs, files in os.walk(project_path): + for file in files: + if file.endswith(".csv"): + # Extract keys from the directory structure + path = os.path.relpath(root, base_dir) + project_dir, strategy, job, device_id = path.split(os.sep) + + # Load dataframe from csv + df = pd.read_csv(os.path.join(root, file)) + + # Add dataframe to dictionary + if (strategy, job) not in history_groups: + history_groups[(strategy, job)] = {} + if device_id not in history_groups[(strategy, job)]: + history_groups[(strategy, job)][device_id] = [] + history_groups[(strategy, job)][device_id].append(df) + return history_groups + + +def get_total_flops(groups, total_model_flops): + """ + Returns the total number of FLOPs for each group. 
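+ Training samples are counted as 3x the model FLOPs (one forward plus two backward passes)
+ and validation samples as 1x, matching the inline comments in the loop below.
+ Worked example with the resnet20 value of 41,498,880 FLOPs from result_generation.ipynb:
+ 1,000 training samples contribute 1,000 * 3 * 41,498,880 = 124,496,640,000 FLOPs,
+ i.e. roughly 124.5 GFLOPs.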
+ Args: + groups: The runs of one project, according to the structure of the wandb project + total_model_flops: (int) the total number of FLOPs of the model + Returns: + flops_per_group: (dict) the total number of FLOPs for each group + """ + flops_per_group = {"strategy": [], "flops": []} + for (strategy, job), group in groups.items(): + if job == "train": + flops = 0 + num_runs = 1 # avoid division by 0 + for name, runs in group.items(): + if ( + name != "controller" + ): # exclude controller to not count the FLOPs multiple times + num_runs = len(runs) # runs have equal length anyway + for run_df in runs: + for col_name in run_df.columns: + if col_name == "train_accuracy.num_samples": + flops += ( + run_df[col_name].sum() * total_model_flops * 3 + ) # 1x forward + 2x backward + if col_name == "val_accuracy.num_samples": + flops += ( + run_df[col_name].sum() * total_model_flops + ) # 1x forward + flops = flops / num_runs + flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy]) + flops_per_group["flops"].append(round(flops / 1000000000, 3)) # in GFLOPs + return flops_per_group + + +def get_communication_overhead(groups): + """ + Returns the total communication size in GB for each group. + Args: + groups: The runs of one project, according to the structure of the wandb project + Returns: + communication_sizes: (dict) the total communication size for each group + """ + communication_sizes = {"strategy": [], "communicationsize": []} + for (strategy, job), group in groups.items(): + if job == "train": + payload_size = 0 + num_runs = 1 + for name, runs in group.items(): + num_runs = len(runs) + for run_df in runs: + for col in run_df.columns: + if col.startswith("/Device/"): + payload_size += run_df[col].sum() + payload_size = payload_size / num_runs + communication_sizes["strategy"].append(STRATEGY_MAPPING[strategy]) + communication_sizes["communicationsize"].append( + round(payload_size / 1000000000, 3) + ) # in GB + return communication_sizes + + +def get_test_accuracy(history_groups): + """ + Returns the test accuracy for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + Returns: + test_acc: (dict) the test accuracy for each group + """ + result = {"strategy": [], "testaccuracy": []} + for (strategy, job), group in history_groups.items(): + if job == "test": + result["strategy"].append(STRATEGY_MAPPING[strategy]) + test_acc = [] + for run_df in group["d0"]: + test_acc.append( + run_df["test_accuracy.value"].max() + ) # max to filter NaN + result["testaccuracy"].append( + round(sum(test_acc) / len(test_acc), 3) + ) # take the average for multiple runs + return result + + +def remaining_devices_per_round(history_groups): + """ + Returns the remaining devices per round for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + Returns: + results: (dict) the remaining devices (list(int)) per round (list(int)) for each group + """ + results = {} + for (strategy, job), group in history_groups.items(): + if job == "train": + round_cols, value_cols = [], [] + for run_df in group["controller"]: + round_cols.append(list(run_df["remaining_devices.round"].dropna())) + value_cols.append(list(run_df["remaining_devices.devices"].dropna())) + # if multiple columns exist (i.e. multiple runs) average in each round and if one run was shorter, use last value + max_rounds = [] + # get the maximum number of rounds + # for federated learning sometimes zero devices are logged. 
In that case, break after the first zero + for col in value_cols: + zeros = [(idx, i) for idx, i in enumerate(col) if int(i) == 0] + if len(zeros) > 1: + max_rounds.append(zeros[1][0]) + else: + max_rounds.append(len(col)) + max_rounds = max(max_rounds) + round = range(0, max_rounds) + remaining_devices = [] + for i in round: + values_in_round = [] + for j in range(len(value_cols)): + if len(value_cols[j]) > i: + values_in_round.append(value_cols[j][i]) + else: + values_in_round.append(value_cols[j][-1]) + remaining_devices.append(sum(values_in_round) / len(values_in_round)) + if ( + remaining_devices[-1] == 0 + ): # stop if there was no device left in any of the runs + break + results[(strategy, job)] = (round, remaining_devices) + return results + + +def plot_remaining_devices(devices_per_round, save_path=None): + """ + Plots the remaining devices per round for each group. + Args: + devices_per_round: (dict) the remaining devices (list(int)) per round (list(int)) for each group + save_path: (str) the path to save the plot to + """ + + plt.figure() + num_rounds = [0] + max_devices = [] + for (strategy, job), (rounds, num_devices) in devices_per_round.items(): + num_rounds.append(len(rounds)) + max_devices.append(max(num_devices)) + plt.plot(rounds, num_devices, label=f"{STRATEGY_MAPPING[strategy]}") + plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) + plt.yticks( + range(0, int(max(max_devices) + 1), max(1, int((max(max_devices) + 1)) // 5)) + ) + plt.xlabel(LABEL_MAPPING["round"]) + plt.ylabel(LABEL_MAPPING["num_devices"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}.pdf", format="pdf") + plt.savefig(f"{save_path}.png", format="png") + plt.close() + + +def accuracy_over_epoch(history_groups, phase="train"): + """ + Returns the accuracy over the epoch for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + phase: (str) the phase to get the accuracy for, either "train" or "val" + Returns: + results: (dict) the accuracy (list(float)) per round (list(int)) for each group + """ + results = {} + for (strategy, job), group in history_groups.items(): + if job == "train": + round_cols, value_cols = [], [] + for run_df in group["controller"]: + round_cols.append(list(run_df[f"{phase}_accuracy.round"].dropna())) + value_cols.append(list(run_df[f"{phase}_accuracy.value"].dropna())) + # if multiple columns exist (i.e. multiple runs) average in each round and if one run was shorter, use last value + max_rounds = max([int(col[-1]) for col in round_cols]) + 1 + mean_acc, round_no = [], [] + for i in range(max_rounds): + single_run_accs = [] + for run_idx, value_col in enumerate( + value_cols + ): # round_cols should have same length + if ( + round_cols[run_idx].count(i) > 0 + ): # check if values were logged in this round + round_idx = round_cols[run_idx].index( + i + ) # get the index of the round (could be less than the round number due to missing values) + single_run_accs.append(value_col[round_idx]) + # print(i, single_run_accs) + if len(single_run_accs) > 0: + mean_acc.append(sum(single_run_accs) / len(single_run_accs)) + round_no.append(i) + + results[(strategy, job)] = (round_no, mean_acc) + return results + + +def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): + """ + Plots the accuracy over the epoch for each group. 
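+ The phase argument only selects the y-axis label (train or validation accuracy); the
+ plotted values themselves come from accuracies_per_round.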
+ Args: + accuracies_per_round: (dict) the accuracy (list(float)) per round (list(int)) for each group + save_path: (str) the path to save the plot to + """ + plt.figure() + num_rounds = [0] + for (strategy, job), (rounds, accs) in accuracies_per_round.items(): + plt.plot(rounds, accs, label=f"{STRATEGY_MAPPING[strategy]}") + num_rounds.append(len(rounds)) + plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) + plt.xlabel(LABEL_MAPPING["round"]) + plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}.pdf", format="pdf") + plt.savefig(f"{save_path}.png", format="png") + plt.close() + + +def battery_over_time(history_groups, num_intervals=1000): + """ + Returns the average battery over time for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + num_intervals: (int) the number of intervals to sample the battery for (required for averaging or aggregation) + Returns: + results: (dict) the average battery (list(float)) per interval (list(int)) for each group + max_runtime_per_group: (dict) the maximum runtime per group + """ + results = {} + max_runtime_per_group = {} + for (strategy, job), group in history_groups.items(): + if job == "train": + average_batteries_per_device_group = {} + max_runtime = [] + # get min start and max end time per group + start_times = [[] for _ in range(len(list(group.values())[0]))] + end_times = [[] for _ in range(len(list(group.values())[0]))] + for device_id, device_group in group.items(): # device_id e.g. "d0" + if device_id != "controller": + for i, run_df in enumerate(device_group): + start_times[i].append(run_df["_timestamp"].min()) + end_times[i].append(run_df["_timestamp"].max()) + min_start_time = [min(time_list) for time_list in start_times] + max_end_time = [max(time_list) for time_list in end_times] + for device_id, device_group in group.items(): # device_id e.g. "d0" + if device_id != "controller": # controller does not record batteries + avg_battery_per_interval = [] + # chunk the runtime into intervals and sample the battery for each interval and run + for i, run_df in enumerate(device_group): + df = run_df.copy() + df["_timestamp"] = pd.to_datetime(df["_timestamp"], unit="s") + df = df.set_index("_timestamp") + max_runtime.append( + df["_runtime"].max() + ) # to match the battery values to runtime again + intervals = pd.date_range( + start=pd.to_datetime(min_start_time[i], unit="s"), + end=pd.to_datetime(max_end_time[i], unit="s"), + periods=num_intervals + 1, + ) + df["interval"] = pd.cut( + df.index, bins=intervals, labels=np.arange(num_intervals) + ) + # average per interval and fill empty values before/after with the first/last values + df_resampled = ( + df.groupby("interval")["battery"].mean().bfill().ffill() + ) + avg_battery_per_interval.append(df_resampled) + # take average battery for each device_id and append to batteries per group + average_batteries_per_device_group[device_id] = pd.concat( + avg_battery_per_interval, axis=1 + ).mean(axis=1) + max_runtime_per_group[(strategy, job)] = max(max_runtime) + results[(strategy, job)] = average_batteries_per_device_group + return results, max_runtime_per_group + + +def aggregated_battery_over_time(history_groups, num_intervals=100): + """ + Returns the aggregated battery over time for each group. 
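+ The aggregate is formed by summing, per interval, the per-device average battery series
+ returned by battery_over_time, yielding one system-wide series per group.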
+ Args: + history_groups: The runs of one project, according to the structure of the wandb project + num_intervals: (int) the number of intervals to sample the battery for (required for averaging or aggregation) + Returns: + results: (dict) the aggregated battery (list(float)) per interval (list(int)) for each group + max_runtime_per_group: (dict) the maximum runtime per group + """ + batteries_per_device, max_runtime_per_group = battery_over_time( + history_groups, num_intervals=num_intervals + ) + results = {} + for (strategy, job), series in batteries_per_device.items(): + results[(strategy, job)] = pd.concat(series.values(), axis=1).sum(axis=1) + return results, max_runtime_per_group + + +def battery_over_epoch(history_groups, num_intervals=1000): + """ + Returns the battery over the epoch for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + num_intervals: (int) the number of intervals to sample the battery for (required for averaging or aggregation) + Returns: + battery_over_epoch: (dict) the battery (list(float)) per round (list(int)) for each group + """ + batteries, _ = battery_over_time(history_groups, num_intervals) + battery_over_epoch = {} + for (strategy, job), group in history_groups.items(): + if job == "train": + max_epochs = 0 + round_runtime = None + device_batteries = {} + for device_id, device_group in group.items(): # device_id e.g. "d0" + if device_id == "controller": # controller does rounds over time + for run_df in device_group: + end_idx = run_df[ + run_df["remaining_devices.devices"] > 0 + ].last_valid_index() + round_per_time_df = run_df[: end_idx + 1][ + ["remaining_devices.round", "_runtime"] + ].dropna() + if ( + round_per_time_df["remaining_devices.round"].max() + > max_epochs + ): + round_runtime = round_per_time_df + for device_id, battery_series in batteries[(strategy, job)].items(): + round_rt = round_runtime.copy() + round_rt["_runtime"] = pd.to_datetime(round_rt["_runtime"], unit="s") + round_rt = round_rt.set_index("_runtime") + time_index = pd.date_range( + start=round_rt.index.min(), + end=round_rt.index.max(), + periods=len(battery_series), + ) + battery_series.index = time_index + round_rt["battery"] = battery_series.asof(round_rt.index) + device_batteries[device_id] = round_rt.set_index( + "remaining_devices.round" + ) + battery_over_epoch[(strategy, job)] = device_batteries + return battery_over_epoch + + +def aggregated_battery_over_epoch(history_groups, num_intervals=1000): + """ + Returns the aggregated battery over the epoch for each group. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + num_intervals: (int) the number of intervals to sample the battery for (required for averaging or aggregation) + Returns: + results: (dict) the aggregated battery (list(float)) per round (list(int)) for each group + """ + batteries_per_device = battery_over_epoch( + history_groups, num_intervals=num_intervals + ) + results = {} + for (strategy, job), series in batteries_per_device.items(): + results[(strategy, job)] = pd.concat(series.values(), axis=1).sum(axis=1) + return results + + +def plot_batteries_over_time( + batteries_over_time, max_runtimes, save_path=None, aggregated=True +): + """ + Plots the battery over time for each group. 
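+ With aggregated=True a single figure comparing the strategies is produced; with
+ aggregated=False one figure per strategy is produced, with one line per device.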
+ Args: + batteries_over_time: (dict) the battery (list(float)) per interval (list(int)) for each group + max_runtimes: (dict) the maximum runtime per group + save_path: (str) the path to save the plot to + aggregated: (bool) whether the battery is aggregated or not + """ + if aggregated: + plt.figure() + plt.rcParams.update({"font.size": 13}) + for (strategy, job), series in batteries_over_time.items(): + runtime = max_runtimes[(strategy, job)] + x_values = [ + runtime / series.size * i for i in range(1, series.size + 1) + ] # match series index with runtime + y_values = series.values + plt.plot(x_values, y_values, label=f"{STRATEGY_MAPPING[strategy]}") + plt.xlabel(LABEL_MAPPING["runtime"]) + plt.ylabel(LABEL_MAPPING["total network battery"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}.pdf", format="pdf") + plt.savefig(f"{save_path}.png", format="png") + plt.close() + else: + for (strategy, job), series_dict in batteries_over_time.items(): + plt.figure() + plt.rcParams.update({"font.size": 13}) + for device_id, series in series_dict.items(): + runtime = max_runtimes[(strategy, job)] + x_values = [ + runtime / series.size * i for i in range(1, series.size + 1) + ] # match series index with runtime + y_values = series.values + plt.plot(x_values, y_values, label=f"{device_id}") + plt.xlabel(LABEL_MAPPING["runtime"]) + plt.ylabel(LABEL_MAPPING["device battery"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}_{strategy}.pdf", format="pdf") + plt.savefig(f"{save_path}_{strategy}.png", format="png") + plt.close() + + +def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=True): + """ + Plots the battery over the epoch for each group. 
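+ Analogous to plot_batteries_over_time, but the x-axis is the training round rather than
+ the elapsed runtime.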
+ Args: + batteries_over_epoch: (dict) the battery (list(float)) per round (list(int)) for each group + save_path: (str) the path to save the plot to + aggregated: (bool) whether the battery is aggregated or not + """ + if aggregated: + plt.figure() + plt.rcParams.update({"font.size": 13}) + num_rounds = [0] + for (strategy, job), series in batteries_over_epoch.items(): + x_values = series.index + y_values = series.values + plt.plot(x_values, y_values, label=f"{STRATEGY_MAPPING[strategy]}") + num_rounds.append(len(series)) + plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) + plt.ylabel(LABEL_MAPPING["total network battery"]) + plt.xlabel(LABEL_MAPPING["round"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}.pdf", format="pdf") + plt.savefig(f"{save_path}.png", format="png") + plt.close() + else: + for (strategy, job), series_dict in batteries_over_epoch.items(): + plt.figure() + plt.rcParams.update({"font.size": 13}) + num_rounds = [0] + for device_id, series in series_dict.items(): + x_values = series.index + y_values = series.values + plt.plot(x_values, y_values, label=f"{device_id}") + num_rounds.append(len(series)) + plt.xticks( + range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8)) + ) + plt.ylabel(LABEL_MAPPING["device battery"]) + plt.xlabel(LABEL_MAPPING["round"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}_{strategy}.pdf", format="pdf") + plt.savefig(f"{save_path}_{strategy}.png", format="png") + plt.close() + + +def get_train_times(run_groups): + results = {} + for (strategy, job), group in run_groups.items(): + if job == "test": + continue + start_times = [] + server_train_times = {} + client_train_times = {} + for device_id, device_group in group.items(): # device_id e.g. "d0" + if device_id != "controller": + start_times.append(device_group[0]["_timestamp"][0]) + min_start_time = min(start_times) + for device_id, device_group in group.items(): # device_id e.g. 
"d0" + if device_id != "controller": + run_df = device_group[0].copy() + # create an empty dataframe in case device was never server + server_train_times[device_id] = pd.DataFrame(columns=["start", "end"]) + if "train_global_time.start" in run_df.columns: + valid_train_global_times = run_df["train_global_time.start"] > 0 + server_train_times[device_id] = ( + run_df[valid_train_global_times][ + ["train_global_time.start", "train_global_time.end"] + ] + .reset_index(drop=True) + .rename( + columns={ + "train_global_time.start": "start", + "train_global_time.end": "end", + } + ) + - min_start_time + ) + + if "client_train_epoch_time.start" in run_df.columns: + valid_train_epoch_times = ( + run_df["client_train_epoch_time.start"] > 0 + ) + valid_eval_times = run_df["client_evaluate_time.end"] > 0 + train_epoch_start_times = run_df[valid_train_epoch_times][ + ["client_train_epoch_time.start"] + ].reset_index(drop=True) + train_epoch_end_times = run_df[valid_train_epoch_times][ + ["client_train_epoch_time.end"] + ].reset_index(drop=True) + eval_start_times = run_df[valid_eval_times][ + ["client_evaluate_time.start"] + ].reset_index(drop=True) + eval_end_times = run_df[valid_eval_times][ + ["client_evaluate_time.end"] + ].reset_index(drop=True) + client_train_times[device_id] = pd.concat( + [ + pd.concat( + [train_epoch_start_times, train_epoch_end_times], axis=1 + ).rename( + columns={ + "client_train_epoch_time.start": "start", + "client_train_epoch_time.end": "end", + } + ) + - min_start_time, + pd.concat( + [eval_start_times, eval_end_times], axis=1 + ).rename( + columns={ + "client_evaluate_time.start": "start", + "client_evaluate_time.end": "end", + } + ) + - min_start_time, + ] + ).reset_index(drop=True) + + if "fed_train_time.start" in run_df.columns: + valid_fed_train_times = run_df["fed_train_time.start"] > 0 + server_train_times[device_id] = ( + run_df[valid_fed_train_times][ + ["fed_train_time.start", "fed_train_time.end"] + ] + .reset_index(drop=True) + .rename( + columns={ + "fed_train_time.start": "start", + "fed_train_time.end": "end", + } + ) + - min_start_time + ) + client_train_times[device_id] = server_train_times[device_id].copy() + results[(strategy, job)] = server_train_times, client_train_times + return results + + +def train_time_end(server_train_times, client_train_times): + max_timestamp = 0 + for idx, (device_id, df) in enumerate(server_train_times.items()): + if df["end"].max() > max_timestamp: + max_timestamp = df["end"].max() + for idx, (device_id, df) in enumerate(client_train_times.items()): + if df["end"].max() > max_timestamp: + max_timestamp = df["end"].max() + return max_timestamp + + +def plot_batteries_over_time_with_activity( + batteries_over_time, max_runtimes, training_times, save_path=None +): + device_space = 200 + bar_height = 50 + y_offset = 10 + for (strategy, job), series_dict in batteries_over_time.items(): + server_train_times, client_train_times = training_times[(strategy, job)] + xlim = ( + 0, + train_time_end(server_train_times, client_train_times) * 1.05, + ) # set end 5% after last activity timestamp + plt.figure() + plt.rcParams.update({"font.size": 13}) + # battery_plot + plt.subplot(2, 1, 1) + for device_id, series in series_dict.items(): + runtime = max_runtimes[(strategy, job)] + x_values = [ + runtime / series.size * i for i in range(1, series.size + 1) + ] # match series index with runtime + y_values = series.values + plt.plot(x_values, y_values, label=f"{device_id}") + plt.ylabel(LABEL_MAPPING["device battery"]) + plt.legend() + 
plt.tight_layout() + plt.xlim(xlim) + # plt.ylim(3000, 3800) + # activity plot + plt.subplot(2, 1, 2) + for idx, (device_id, df) in enumerate(server_train_times.items()): + x_values = [] + for _, row in df.iterrows(): + x_values.append((row["start"], row["end"] - row["start"])) + plt.broken_barh( + x_values, # (xmin, xwidth) + (device_space * idx + bar_height * 1 + y_offset, bar_height), + label="Server" if idx == 0 else "", + ) # avoid duplicate labels + for idx, (device_id, df) in enumerate(client_train_times.items()): + x_values = [] + for _, row in df.iterrows(): + x_values.append((row["start"], row["end"] - row["start"])) + plt.broken_barh( + x_values, # (xmin, xwidth) + (device_space * idx + bar_height * 2 + y_offset, bar_height), + facecolors="darkorange", + label="Client" if idx == 0 else "", + ) # avoid duplicate labels + plt.xlabel(LABEL_MAPPING["runtime"]) + plt.ylabel(LABEL_MAPPING["device"]) + plt.legend(loc="upper left") + plt.tight_layout() + plt.xlim(xlim) + plt.yticks( + [(device_space * x + 2 * bar_height + y_offset) for x in range(0, 5)], + labels=["d0", "d1", "d2", "d3", "d4"][:5], + ) + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}_{strategy}.pdf", format="pdf") + plt.savefig(f"{save_path}_{strategy}.png", format="png") + plt.close() + + +def plot_batteries_over_epoch_with_activity_at_time_scale( + batteries_over_time, max_runtimes, training_times, save_path=None +): + device_space = 200 + bar_height = 50 + y_offset = 10 + for (strategy, job), series_dict in batteries_over_time.items(): + server_train_times, client_train_times = training_times[(strategy, job)] + train_times = pd.concat( + [device_server_times for device_server_times in server_train_times.values()] + ).sort_values("start") + start_time, end_time = train_times["start"].min(), train_times["end"].max() + xlim = (start_time, end_time) + start_times = list( + train_times["start"] + ) # list(zip(train_times["start"], train_times["end"])) + xticks = ( + [ + start_times[i] + for i in range(0, len(start_times), max(1, len(start_times) // 8)) + ], + [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))], + ) + plt.figure() + plt.rcParams.update({"font.size": 13}) + # battery_plot + plt.subplot(2, 1, 1) + for device_id, series in series_dict.items(): + runtime = max_runtimes[(strategy, job)] + x_values = [ + runtime / series.size * i for i in range(1, series.size + 1) + ] # match series index with runtime + y_values = series.values + plt.plot(x_values, y_values, label=f"{device_id}") + plt.ylabel(LABEL_MAPPING["device battery"]) + plt.legend() + plt.tight_layout() + plt.xlim(xlim) + plt.xticks(xticks[0], labels=xticks[1]) + # plt.ylim(3000, 3800) + # activity plot + plt.subplot(2, 1, 2) + for idx, (device_id, df) in enumerate(server_train_times.items()): + x_values = [] + for _, row in df.iterrows(): + x_values.append((row["start"], row["end"] - row["start"])) + plt.broken_barh( + x_values, # (xmin, xwidth) + (device_space * idx + bar_height * 1 + y_offset, bar_height), + label="Server" if idx == 0 else "", + ) # avoid duplicate labels + for idx, (device_id, df) in enumerate(client_train_times.items()): + x_values = [] + for _, row in df.iterrows(): + x_values.append((row["start"], row["end"] - row["start"])) + plt.broken_barh( + x_values, # (xmin, xwidth) + (device_space * idx + bar_height * 2 + y_offset, bar_height), + facecolors="darkorange", + label="Client" if idx == 0 else "", + ) # avoid duplicate labels + plt.xlabel(LABEL_MAPPING["round"]) + 
plt.ylabel(LABEL_MAPPING["device"]) + plt.legend(loc="upper left") + plt.tight_layout() + plt.xlim(xlim) + plt.xticks(xticks[0], labels=xticks[1]) + plt.yticks( + [(device_space * x + 2 * bar_height + y_offset) for x in range(0, 5)], + labels=["d0", "d1", "d2", "d3", "d4"][:5], + ) + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}_{strategy}.pdf", format="pdf") + plt.savefig(f"{save_path}_{strategy}.png", format="png") + plt.close() + + +def plot_batteries_over_epoch_with_activity_at_epoch_scale( + batteries_over_epoch, training_times, save_path=None +): + device_space = 200 + bar_height = 50 + y_offset = 10 + for (strategy, job), series_dict in batteries_over_epoch.items(): + server_train_times, client_train_times = training_times[(strategy, job)] + train_times = pd.concat( + [ + device_server_times.assign(device_id=device_id) + for device_id, device_server_times in server_train_times.items() + ] + ).sort_values("start") + + # battery plot + plt.figure() + plt.rcParams.update({"font.size": 13}) + plt.subplot(2, 1, 1) + num_rounds = [] + for device_id, series in series_dict.items(): + x_values = series.index + y_values = series.values + plt.plot(x_values, y_values, label=f"{device_id}") + num_rounds.append(len(series)) + plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) + plt.ylabel(LABEL_MAPPING["device battery"]) + plt.legend() + plt.tight_layout() + + # device plot + plt.subplot(2, 1, 2) + if strategy == "fed": + for device_idx, device_id in enumerate(server_train_times.keys()): + for round_idx, row in train_times.iterrows(): + current_device_id = row["device_id"] + if device_id == current_device_id: + plt.broken_barh( + [(round_idx, 1)], # (xmin, xwidth) + ( + device_space * device_idx + bar_height * 1 + y_offset, + bar_height, + ), + ) + + plt.broken_barh( + [(round_idx, 1)], + ( + device_space * device_idx + bar_height * 2 + y_offset, + bar_height, + ), + facecolors="darkorange", + ) + else: + server_train_times_with_device = [] + for idx, (device_id, df) in enumerate(server_train_times.items()): + server_train_times_with_device += [ + (row["start"], row["end"], device_id) for _, row in df.iterrows() + ] + server_train_times_with_device.sort(key=lambda x: x[0]) + + client_train_rounds_with_device = {} + for round, (start, end, device_id) in enumerate( + server_train_times_with_device + ): + clients_in_round = [] + for client_device_idx, (client_device_id, df) in enumerate( + client_train_times.items() + ): + if len(df[df["start"] > start]) > 0: + clients_in_round.append((client_device_idx, client_device_id)) + client_train_rounds_with_device[round] = [ + (client_device_idx, client_device_id, len(clients_in_round)) + for client_device_idx, client_device_id in clients_in_round + ] + + for device_idx, device_id in enumerate(server_train_times.keys()): + for round_idx, (_, _, current_device_id) in enumerate( + server_train_times_with_device + ): + if device_id == current_device_id: + plt.broken_barh( + [(round_idx, 1)], # (xmin, xwidth) + ( + device_space * device_idx + bar_height * 1 + y_offset, + bar_height, + ), + ) + + for ( + client_device_idx, + client_device_id, + num_clients_in_round, + ) in client_train_rounds_with_device[round_idx]: + plt.broken_barh( + [ + ( + round_idx + + (1 / num_clients_in_round) * client_device_idx, + (1 / num_clients_in_round), + ) + ], + ( + device_space * client_device_idx + + bar_height * 2 + + y_offset, + bar_height, + ), + facecolors="darkorange", + ) + plt.broken_barh( + [(0, 0)], (0, 0), 
facecolors="darkorange", label="Client" + ) # avoid duplicate labels + plt.broken_barh([(0, 0)], (0, 0), label="Server") + plt.xlabel(LABEL_MAPPING["round"]) + plt.ylabel(LABEL_MAPPING["device"]) + plt.legend(loc="upper left") + plt.tight_layout() + plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) + plt.yticks( + [(device_space * x + 2 * bar_height + y_offset) for x in range(0, 5)], + labels=["d0", "d1", "d2", "d3", "d4"][:5], + ) + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}_{strategy}.pdf", format="pdf") + plt.savefig(f"{save_path}_{strategy}.png", format="png") + plt.close() + + +def generate_plots(history_groups, project_name, base_path="./plots"): + """ + Generates plots for the given history groups and saves them to the project_name folder. + Args: + history_groups: The runs of one project, according to the structure of the wandb project + project_name: (str) the name of the project + base_path: (str) the base path to save the plots to + """ + project_path = f"{base_path}/{project_name}" + os.makedirs(project_path, exist_ok=True) + # batteries over time + batteries_over_time, max_runtimes = aggregated_battery_over_time( + history_groups, num_intervals=1000 + ) + plot_batteries_over_time( + batteries_over_time, + max_runtimes, + save_path=f"{project_path}/total_batteries_over_time", + aggregated=True, + ) + + batteries_over_time, max_runtimes = battery_over_time( + history_groups, num_intervals=1000 + ) + plot_batteries_over_time( + batteries_over_time, + max_runtimes, + save_path=f"{project_path}/batteries_over_time", + aggregated=False, + ) + + train_times = get_train_times(history_groups) + plot_batteries_over_time_with_activity( + batteries_over_time, + max_runtimes, + train_times, + save_path=f"{project_path}/activity_over_time", + ) + plot_batteries_over_epoch_with_activity_at_time_scale( + batteries_over_time, + max_runtimes, + train_times, + save_path=f"{project_path}/activity_over_time_with_epoch", + ) + + # batteries over epoch + batteries_over_epoch = aggregated_battery_over_epoch( + history_groups, num_intervals=1000 + ) + plot_batteries_over_epoch( + batteries_over_epoch, + save_path=f"{project_path}/total_batteries_over_epoch", + aggregated=True, + ) + + batteries_over_epoch = battery_over_epoch(history_groups, num_intervals=1000) + plot_batteries_over_epoch( + batteries_over_epoch, + save_path=f"{project_path}/batteries_over_epoch", + aggregated=False, + ) + training_times = get_train_times(history_groups) + plot_batteries_over_epoch_with_activity_at_epoch_scale( + batteries_over_epoch, + training_times=training_times, + save_path=f"{project_path}/activity_over_epoch", + ) + + # remaining devices + remaining_devices = remaining_devices_per_round(history_groups) + plot_remaining_devices( + remaining_devices, save_path=f"{project_path}/remaining_devices_per_epoch" + ) + + # accuracies + train_accs = accuracy_over_epoch(history_groups, "train") + plot_accuracies( + train_accs, save_path=f"{project_path}/train_accuracy", phase="train" + ) + + val_accs = accuracy_over_epoch(history_groups, "val") + plot_accuracies(val_accs, save_path=f"{project_path}/val_accuracy", phase="val") + + +def generate_metric_files( + history_groups, project_name, model_flops, base_path="./metrics" +): + """ + Generates metric file for the given history groups and saves them to the project_name folder. 
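+ The resulting metrics.csv contains one row per strategy with the columns testaccuracy,
+ communicationsize (in GB) and flops (in GFLOPs).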
+ Args: + history_groups: The runs of one project, according to the structure of the wandb project + project_name: (str) the name of the project + model_flops: (int) the total number of FLOPs of the model + base_path: (str) the base path to save the metrics to + """ + project_path = f"{base_path}/{project_name}" + os.makedirs(project_path, exist_ok=True) + + # FLOPs, Communication Overhead, Test Accuracy + test_acc = pd.DataFrame.from_dict(get_test_accuracy(history_groups)).set_index( + "strategy" + ) + comm_overhead = pd.DataFrame.from_dict( + get_communication_overhead(history_groups) + ).set_index("strategy") + total_flops = pd.DataFrame.from_dict( + get_total_flops(history_groups, model_flops) + ).set_index("strategy") + df = pd.concat([test_acc, comm_overhead, total_flops], axis=1) + df.to_csv(f"{project_path}/metrics.csv") diff --git a/results/results.py b/results/results.py new file mode 100644 index 0000000000000000000000000000000000000000..d27b8b8c19a0ad635df02f554fea6caeda620cb1 --- /dev/null +++ b/results/results.py @@ -0,0 +1,153 @@ +import os + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import wandb + + +def plot_communication_overhead( + plot_groups: dict[str, list[float]], + strategies: list[str], + batteries: list[str], + save_path: str, +): + # adapted from https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html#sphx-glr-gallery-lines-bars-and-markers-barchart-py + + x = np.arange(len(strategies)) # the label locations + bar_width = 0.25 + multiplier = 0 + + fig, ax = plt.subplots(layout="constrained") + + for attribute, measurement in plot_groups.items(): + offset = bar_width * multiplier + rects = ax.bar(x + offset, measurement, bar_width, label=attribute) + ax.bar_label(rects, padding=3) + multiplier += 1 + + ax.set_xlabel("Learning algorithm") + ax.set_ylabel("Sent + Received Data [GBytes]") + ax.set_title("Communication overhead by algorithm") + ax.set_xticks( + x + bar_width * (len(batteries) - 1) / 2, strategies + ) # center x ticks among groups + ax.legend(loc="upper left", ncols=len(batteries)) + ax.set_ylim( + 0, max([item for sublist in plot_groups.values() for item in sublist]) * 1.15 + ) # add 15% margin to max value + plt.tight_layout() + plt.savefig(save_path) + + +def get_communication_overhead( + history_group: dict[str, list[pd.DataFrame]], batteries: list[str] +) -> dict[str, list[float]]: + communication_sizes = {} + for name, group in history_group.items(): + payload_size = 0 + for history_df in group: + for col in history_df.columns: + if col.startswith("/Device/"): + payload_size += history_df[col].sum() + communication_sizes[name] = payload_size / 1000000000 # in GB + plot_groups = {battery: [] for battery in batteries} + for idx, (method, size) in enumerate(communication_sizes.items()): + plot_groups[batteries[idx % len(batteries)]].append(size) + return plot_groups + + +def get_run_history( + run_groups: dict[str, wandb.apis.public.Runs] +) -> dict[str, list[pd.DataFrame]]: + """ + Retrieve the history of all runs in the run groups + """ + history_groups = {} + for name, group in run_groups.items(): + history = [] + for run in group: + history_df = pd.DataFrame(run.scan_history()) + history.append(history_df) + history_groups[name] = history + return history_groups + + +def plot_batteries(history_groups: dict[str, list[pd.DataFrame]], save_path: str): + for name, group in history_groups.items(): + plt.figure() + for idx, run_df in enumerate(group): + plt.step( + 
run_df.dropna(subset=["battery"])["_runtime"], + run_df.dropna(subset=["battery"])["battery"], + label=f"Device {idx}", + ) + + plt.xlabel("Runtime [s]") + plt.ylabel("Battery capacity") + plt.title("Battery capacity over time") + plt.ylim(bottom=0) + plt.legend(loc="upper right") + plt.tight_layout() + plt.savefig(f"{save_path}/{name}_battery.png") + + +def retrieve_and_plot_results( + strategies: list[str], + batteries: list[str], + entity: str = "swarmsl", + project: str = "baseline", +): + """ + Retrieve results from wandb and plot them + """ + # fetch results by group + wandb.login() + print("retrieving results...") + run_groups = get_run_groups(batteries, entity, project, strategies) + history_groups = get_run_history(run_groups) + + # make data plottable + print("plotting results...") + plot_groups = get_communication_overhead(history_groups, batteries) + + # plot results + cwd = os.getcwd() + if not os.path.exists(f"{cwd}/{project}"): + os.makedirs(f"{cwd}/{project}") # create directory if it does not exist + + plot_communication_overhead( + plot_groups, + strategies, + batteries, + f"{cwd}/{project}/communication_overhead.png", + ) + plot_batteries(history_groups, f"{cwd}/{project}") + + +def get_run_groups( + batteries: list[str], entity: str, project: str, strategies: list[str] +) -> dict[str, wandb.apis.public.Runs]: + api = wandb.Api(timeout=30) + run_groups = {} + for strategy in strategies: + for battery in batteries: + group = api.runs( + f"{entity}/{project}", filters={"group": f"{strategy}_{battery}"} + ) + run_groups[f"{strategy}_{battery}"] = group + return run_groups + + +@hydra.main(config_path=None, version_base=None) # use only for command line arguments +def main(cfg): + batteries = cfg.batteries + strategies = cfg.strategies + entity = cfg.entity + project = cfg.project + retrieve_and_plot_results(strategies, batteries, entity, project) + + +if __name__ == "__main__": + main() diff --git a/run_test.sh b/run_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..31492ed541a96aa13df4652b913e48cba41a63da --- /dev/null +++ b/run_test.sh @@ -0,0 +1,42 @@ +# shell script to run a test with a number 'n' of devices experiment 'e' and with model from round 'r' +# bash -i run_test.sh -n 2 -e "default_experiment" -r 1 + +conda init bash +conda activate slenv + +# use trap to kill all processes in the subshell on ctrl-c +(trap 'kill 0' SIGINT; + +# Parse arguments +while getopts ":n:e:r:" flag +do + case "${flag}" in + n) num_devices=${OPTARG};; + e) experiment=${OPTARG};; + r) round=${OPTARG};; + *) echo "usage: $0 [-n] [-e] [-r]" >&2 + # invalid option passed + exit 1 ;; + esac +done + +# Start devices +device_pids=() +echo "Starting $num_devices devices" +for (( i = 0; i < num_devices; i++ )); do + python3 main.py own_device_id="$i" num_devices="$num_devices" battery="unlimited" experiment.job="test" & + device_pids+=($!) +done + +# run controller and wait for its termination +python3 main.py +method="test" +best_round="$round" num_devices="$num_devices" experiment="$experiment" battery="unlimited" experiment.job="test" & +controller_pid=$! +wait $controller_pid + +# kill device processes after controller has finished +for pid in "${device_pids[@]}"; do + kill "$pid" +done + +) +exit 1; diff --git a/wandb_key.txt b/wandb_key.txt new file mode 100644 index 0000000000000000000000000000000000000000..51b339faa697566baf378bce797badb94920a287 --- /dev/null +++ b/wandb_key.txt @@ -0,0 +1 @@ +# Place (only) your W&B API key here
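A minimal usage sketch for the plotting entry point added in results/results.py, assuming W&B run groups named <strategy>_<battery> as built in get_run_groups; the strategy and battery values below are placeholders. When results.py is invoked as a script, the same values have to be supplied as Hydra "+key=value" overrides, because it is decorated with @hydra.main(config_path=None) and ships without a config file.

    # Sketch only; run from the results/ directory so that `results` resolves to results.py.
    from results import retrieve_and_plot_results

    retrieve_and_plot_results(
        strategies=["split", "fed", "swarm_seq"],  # placeholder strategy names
        batteries=["limited", "unlimited"],        # placeholder battery group suffixes
        entity="swarmsl",                          # defaults of retrieve_and_plot_results
        project="baseline",
    )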