From 5d4b3302f2646323936bf000db038dc0d8d6ac89 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Wed, 21 Aug 2024 16:19:42 +0200
Subject: [PATCH] Fixed autoencoder: skipping AE layers during backprop

---
 edml/core/client.py                           | 13 ++-
 edml/core/server.py                           | 11 ++-
 edml/models/autoencoder.py                    | 19 +++-
 edml/models/provider/base.py                  | 38 +++++++-
 edml/tests/controllers/fed_controller_test.py |  1 +
 edml/tests/controllers/sample_config.yaml     |  1 +
 edml/tests/core/server_test.py                | 10 ++
 edml/tests/core/start_device_test.py          | 95 +++++++++++--------
 edml/tests/models/model_loading_helpers.py    | 15 +++
 edml/tests/models/model_provider_test.py      | 43 +++++++++
 10 files changed, 196 insertions(+), 50 deletions(-)
 create mode 100644 edml/tests/models/model_loading_helpers.py
 create mode 100644 edml/tests/models/model_provider_test.py

diff --git a/edml/core/client.py b/edml/core/client.py
index d8096c7..7dc1617 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -19,6 +19,7 @@ from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
 from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
 from edml.helpers.types import StateDict
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -69,7 +70,7 @@ class DeviceClient:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
         # get first sample from train data to estimate model flops
         sample = self._train_data.dataset.__getitem__(0)[0]
@@ -208,7 +209,10 @@ class DeviceClient:
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
         gradients = gradients.to(self._device)
-        smashed_data.backward(gradients)
+        if has_autoencoder(self._model):
+            self._model.trainable_layers_output.backward(gradients)
+        else:
+            smashed_data.backward(gradients)
         # self._optimizer.step()
 
         # We need to store a reference to the smashed_data to make it possible to finalize the training step.
@@ -303,7 +307,10 @@ class DeviceClient:
                     self._model_flops * len(batch_data) * 2
                 )  # 2x for backward pass
                 server_grad = server_grad.to(self._device)
-                smashed_data.backward(server_grad)
+                if has_autoencoder(self._model):
+                    self._model.trainable_layers_output.backward(server_grad)
+                else:
+                    smashed_data.backward(server_grad)
                 self._optimizer.step()
 
         client_train_time = (
diff --git a/edml/core/server.py b/edml/core/server.py
index 4279f40..d13893c 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -20,6 +20,7 @@ from edml.helpers.metrics import (
     DiagnosticMetricResultContainer,
 )
 from edml.helpers.types import StateDict, LossFn
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -40,7 +41,7 @@ class DeviceServer:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
         self._model_flops = 0  # determine later
         self._metrics = create_metrics(
@@ -156,8 +157,12 @@ class DeviceServer:
         # Capturing training metrics for the current batch.
         self.node_device.log({"loss": loss_train.item()})
         self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
-
-        return smashed_data.grad, loss_train.item()
+        if has_autoencoder(self._model):
+            return self._model.trainable_layers_input.grad, loss_train.item()
+        return (
+            smashed_data.grad,
+            loss_train.item(),
+        )  # hier sollten beim AE die gradients vom server model vor dem decoder zurückgegeben werden
 
     def _set_model_flops(self, smashed_data):
         """Helper to determine the model flops when smashed data are available for the first time."""
diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py
index ed16f30..7814b7e 100644
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -1,5 +1,5 @@
-import torch
 from torch import nn
+from torch.autograd import Variable
 
 
 class ClientWithAutoencoder(nn.Module):
@@ -7,10 +7,13 @@ class ClientWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_output = (
+            None  # needed to continue the backprop skipping the AE
+        )
 
     def forward(self, x):
-        x = self.model(x)
-        return self.autoencoder(x)
+        self.trainable_layers_output = self.model(x)
+        return self.autoencoder(self.trainable_layers_output)
 
 
 class ServerWithAutoencoder(nn.Module):
@@ -18,7 +21,13 @@ class ServerWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_input = (
+            None  # needed to stop backprop before AE and send the gradients to client
+        )
 
     def forward(self, x):
-        x = self.autoencoder(x)
-        return self.model(x)
+        self.trainable_layers_input = self.autoencoder(x)
+        self.trainable_layers_input = Variable(
+            self.trainable_layers_input, requires_grad=True
+        )
+        return self.model(self.trainable_layers_input)
diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py
index e60736b..2b28c7c 100644
--- a/edml/models/provider/base.py
+++ b/edml/models/provider/base.py
@@ -1,10 +1,44 @@
+import torch
 from torch import nn
 
 
+def has_autoencoder(model: nn.Module):
+    if hasattr(model, "model") and hasattr(model, "autoencoder"):
+        return True
+    return False
+
+
+def get_grads(model: nn.Module):
+    gradients = []
+    for param in model.parameters():
+        if param.grad is not None:
+            gradients.append(param.grad)
+        else:
+            gradients.append(torch.zeros_like(param))
+    return gradients
+
+
+def add_optimizer_params_function(model: nn.Module):
+    # exclude AE params
+    if has_autoencoder(model):
+
+        def get_optimizer_params(model: nn.Module):
+            return model.model.parameters
+
+        model.get_optimizer_params = get_optimizer_params(model)
+    else:
+
+        def get_optimizer_params(model: nn.Module):
+            return model.parameters
+
+        model.get_optimizer_params = get_optimizer_params(model)
+    return model
+
+
 class ModelProvider:
     def __init__(self, client: nn.Module, server: nn.Module):
-        self._client = client
-        self._server = server
+        self._client = add_optimizer_params_function(client)
+        self._server = add_optimizer_params_function(server)
 
     @property
     def models(self) -> tuple[nn.Module, nn.Module]:
diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py
index a44090e..45c4ef9 100644
--- a/edml/tests/controllers/fed_controller_test.py
+++ b/edml/tests/controllers/fed_controller_test.py
@@ -149,6 +149,7 @@ class FedTrainRoundThreadingTest(unittest.TestCase):
                     "num_devices": 0,
                     "wandb": {"enabled": False},
                     "experiment": {"early_stopping": False},
+                    "simulate_parallelism": False,
                 }
             )
         )
diff --git a/edml/tests/controllers/sample_config.yaml b/edml/tests/controllers/sample_config.yaml
index 505bdae..ad8e917 100644
--- a/edml/tests/controllers/sample_config.yaml
+++ b/edml/tests/controllers/sample_config.yaml
@@ -24,3 +24,4 @@ battery:
   deduction_per_mbyte_sent: 1
   deduction_per_mbyte_received: 1
   deduction_per_mflop: 1
+simulate_parallelism: False
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
index 1331f8a..5ed6d1c 100644
--- a/edml/tests/core/server_test.py
+++ b/edml/tests/core/server_test.py
@@ -76,6 +76,7 @@ class PSLTest(unittest.TestCase):
                     ]
                 },
                 "own_device_id": "d0",
+                "simulate_parallelism": False,
             }
         )
         # init models with fixed weights for repeatability
@@ -90,10 +91,19 @@ class PSLTest(unittest.TestCase):
         )
         server_model = ServerModel()
         server_model.load_state_dict(server_state_dict)
+        server_model.get_optimizer_params = (
+            server_model.parameters
+        )  # using a model provider, this is created automatically
         client_model1 = ClientModel()
         client_model1.load_state_dict(client_state_dict)
+        client_model1.get_optimizer_params = (
+            server_model.parameters
+        )  # using a model provider, this is created automatically
         client_model2 = ClientModel()
         client_model2.load_state_dict(client_state_dict)
+        client_model2.get_optimizer_params = (
+            server_model.parameters
+        )  # using a model provider, this is created automatically
         self.server = DeviceServer(
             model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
         )
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
index abb5d9d..dd1f409 100644
--- a/edml/tests/core/start_device_test.py
+++ b/edml/tests/core/start_device_test.py
@@ -4,11 +4,51 @@ from copy import deepcopy
 
 import torch
 from omegaconf import OmegaConf
-from torch.autograd import Variable
 
-from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
 from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+def plot_grad_flow(named_parameters):
+    from matplotlib import pyplot as plt
+    import numpy as np
+    from matplotlib.lines import Line2D
+
+    """Plots the gradients flowing through different layers in the net during training.
+    Can be used for checking for possible gradient vanishing / exploding problems.
+
+    Usage: Plug this function in Trainer class after loss.backwards() as
+    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow"""
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if (p.requires_grad) and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean())
+            max_grads.append(p.grad.abs().max())
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001)  # , top=0.02)  # zoom in on the lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("average gradient")
+    plt.title("Gradient flow")
+    plt.grid(True)
+    plt.legend(
+        [
+            Line2D([0], [0], color="c", lw=4),
+            Line2D([0], [0], color="b", lw=4),
+            Line2D([0], [0], color="k", lw=4),
+        ],
+        ["max-gradient", "mean-gradient", "zero-gradient"],
+    )
+    plt.show()
 
 
 class GetModelsTest(unittest.TestCase):
@@ -22,17 +62,8 @@ class GetModelsTest(unittest.TestCase):
             )
         )
 
-    def _get_model_from_model_provider_config(self, config_name):
-        self.cfg.model_provider = OmegaConf.load(
-            os.path.join(
-                os.path.dirname(__file__),
-                f"../../config/model_provider/{config_name}.yaml",
-            )
-        )
-        return _get_models(self.cfg)
-
     def test_load_resnet20(self):
-        client, server = self._get_model_from_model_provider_config("resnet20")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -40,29 +71,18 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet20_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
-        optimizer = torch.optim.Adam(server.parameters())
-        smashed_data = client(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server(server_smashed_data)
-        loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.zeros((1, 100))
-        )
-        loss_train.backward()
-        optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
-        optimizer.step()
 
     def test_training_resnet20_with_ae_as_non_trainable_layers(self):
-        client_encoder, server_decoder = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client_encoder, server_decoder = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         client_params = deepcopy(str(client_encoder.model.state_dict()))
         encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
@@ -70,17 +90,18 @@ class GetModelsTest(unittest.TestCase):
         decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
 
         # Training loop
-        client_optimizer = torch.optim.Adam(client_encoder.parameters())
-        server_optimizer = torch.optim.Adam(server_decoder.parameters())
-        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server_decoder(server_smashed_data)
+        client_optimizer = torch.optim.Adam(client_encoder.get_optimizer_params())
+        server_optimizer = torch.optim.Adam(server_decoder.get_optimizer_params())
+        smashed_data = client_encoder(torch.zeros(7, 3, 32, 32))
+        output_train = server_decoder(smashed_data)
         loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.rand((1, 100))
+            output_train, torch.rand((7, 100))
         )
         loss_train.backward()
         server_optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
+        client_encoder.trainable_layers_output.backward(
+            server_decoder.trainable_layers_input.grad
+        )
         client_optimizer.step()
 
         # check that AE hasn't changed, but client and server have
@@ -90,7 +111,7 @@ class GetModelsTest(unittest.TestCase):
         self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
-        client, server = self._get_model_from_model_provider_config("resnet110")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet110")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -98,8 +119,8 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet110_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet110-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet110-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
diff --git a/edml/tests/models/model_loading_helpers.py b/edml/tests/models/model_loading_helpers.py
new file mode 100644
index 0000000..18b88a0
--- /dev/null
+++ b/edml/tests/models/model_loading_helpers.py
@@ -0,0 +1,15 @@
+import os
+
+from omegaconf import OmegaConf
+
+from edml.core.start_device import _get_models
+
+
+def _get_model_from_model_provider_config(cfg, config_name):
+    cfg.model_provider = OmegaConf.load(
+        os.path.join(
+            os.path.dirname(__file__),
+            f"../../config/model_provider/{config_name}.yaml",
+        )
+    )
+    return _get_models(cfg)
diff --git a/edml/tests/models/model_provider_test.py b/edml/tests/models/model_provider_test.py
new file mode 100644
index 0000000..bf872d0
--- /dev/null
+++ b/edml/tests/models/model_provider_test.py
@@ -0,0 +1,43 @@
+import os
+import unittest
+
+from omegaconf import OmegaConf
+
+from edml.models.provider.base import has_autoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+class ModelProviderTest(unittest.TestCase):
+    def setUp(self):
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
+        self.cfg = OmegaConf.create({"some_key": "some_value"})
+
+    def test_has_autoencoder_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertFalse(has_autoencoder(client))
+        self.assertFalse(has_autoencoder(server))
+
+    def test_has_autoencoder_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertTrue(has_autoencoder(client))
+        self.assertTrue(has_autoencoder(server))
+
+    def test_get_optimizer_params_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertEqual(list(client.get_optimizer_params()), list(client.parameters()))
+        self.assertEqual(list(server.get_optimizer_params()), list(server.parameters()))
+
+    def test_get_optimizer_params_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertEqual(
+            list(client.get_optimizer_params()), list(client.model.parameters())
+        )
+        self.assertEqual(
+            list(server.get_optimizer_params()), list(server.model.parameters())
+        )
-- 
GitLab