From fcc154c62c93fe50519b9fea0947cfc77b2af4db Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Tue, 16 Jul 2024 17:08:17 +0200
Subject: [PATCH] experiment setups

---
 .../battery/resnet110_cifar100_cost.yaml      |  4 ++
 .../topology/equal_batteries_10_devices.yaml  | 62 +++++++++++++++++++
 .../resnet110_cifar100_batteries.yaml         | 28 +++++++++
 edml/core/server.py                           |  4 +-
 4 files changed, 97 insertions(+), 1 deletion(-)
 create mode 100644 edml/config/battery/resnet110_cifar100_cost.yaml
 create mode 100644 edml/config/topology/equal_batteries_10_devices.yaml
 create mode 100644 edml/config/topology/resnet110_cifar100_batteries.yaml

diff --git a/edml/config/battery/resnet110_cifar100_cost.yaml b/edml/config/battery/resnet110_cifar100_cost.yaml
new file mode 100644
index 0000000..4e7f11b
--- /dev/null
+++ b/edml/config/battery/resnet110_cifar100_cost.yaml
@@ -0,0 +1,4 @@
+deduction_per_second: 0.02
+deduction_per_mflop: 0.00000005
+deduction_per_mbyte_received: 0.0002
+deduction_per_mbyte_sent: 0.0002
diff --git a/edml/config/topology/equal_batteries_10_devices.yaml b/edml/config/topology/equal_batteries_10_devices.yaml
new file mode 100644
index 0000000..ea81cd3
--- /dev/null
+++ b/edml/config/topology/equal_batteries_10_devices.yaml
@@ -0,0 +1,62 @@
+devices: [
+    {
+        device_id: "d0",
+        address: "localhost:50051",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d1",
+        address: "localhost:50052",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d2",
+        address: "localhost:50053",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d3",
+        address: "localhost:50054",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d4",
+        address: "localhost:50055",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d5",
+        address: "localhost:50056",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d6",
+        address: "localhost:50057",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d7",
+        address: "localhost:50058",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d8",
+        address: "localhost:50059",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d9",
+        address: "localhost:50060",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    }
+]
diff --git a/edml/config/topology/resnet110_cifar100_batteries.yaml b/edml/config/topology/resnet110_cifar100_batteries.yaml
new file mode 100644
index 0000000..b05e705
--- /dev/null
+++ b/edml/config/topology/resnet110_cifar100_batteries.yaml
@@ -0,0 +1,28 @@
+devices: [
+  {
+    device_id: "d0",
+    address: "localhost:50051",
+    battery_capacity: 400,
+
+  },
+  {
+    device_id: "d1",
+    address: "localhost:50052",
+    battery_capacity: 400
+  },
+  {
+    device_id: "d2",
+    address: "localhost:50053",
+    battery_capacity: 300
+  },
+  {
+    device_id: "d3",
+    address: "localhost:50054",
+    battery_capacity: 200
+  },
+  {
+    device_id: "d4",
+    address: "localhost:50055",
+    battery_capacity: 200
+  }
+]
diff --git a/edml/core/server.py b/edml/core/server.py
index 7503389..4b9ef80 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -298,7 +298,9 @@ class DeviceServer:
                 print(
                     f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                 )
-                self.node_device.log({"adaptive_learning_threshold_applied": True})
+                self.node_device.log(
+                    {"adaptive_learning_threshold_applied": server_gradients.size(0)}
+                )
                 continue
 
             num_client_gradients = len(client_forward_pass_responses)
-- 
GitLab