diff --git a/edml/core/client.py b/edml/core/client.py
index 7dc1617a1b0d283211b9be9900214831ab971e08..1a81dfb0d0c661da362e5b6c14eff28c21c935f8 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -229,7 +229,7 @@ class DeviceClient:
         metrics_container = DiagnosticMetricResultContainer([metric])
 
         gradients = []
-        for param in self._model.parameters():
+        for param in self._model.get_optimizer_params():
             if param.grad is not None:
                 gradients.append(param.grad)
             else:
@@ -389,7 +389,7 @@ class DeviceClient:
         return diagnostic_metric_results
 
     def set_gradient_and_finalize_training(self, gradients: Any):
-        for param, grad in zip(self._model.parameters(), gradients):
+        for param, grad in zip(self._model.get_optimizer_params(), gradients):
             param.grad = grad.to(self._device)
 
         self._optimizer.step()
diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py
index 2b28c7cf27953017e140f976b4b9d76f5f2bed00..9fa81abd5a7c58e880f163d2d88eb070c7271e47 100644
--- a/edml/models/provider/base.py
+++ b/edml/models/provider/base.py
@@ -1,4 +1,3 @@
-import torch
 from torch import nn
 
 
@@ -8,16 +7,6 @@ def has_autoencoder(model: nn.Module):
     return False
 
 
-def get_grads(model: nn.Module):
-    gradients = []
-    for param in model.parameters():
-        if param.grad is not None:
-            gradients.append(param.grad)
-        else:
-            gradients.append(torch.zeros_like(param))
-    return gradients
-
-
 def add_optimizer_params_function(model: nn.Module):
     # exclude AE params
     if has_autoencoder(model):
diff --git a/results/result_generation.py b/results/result_generation.py
index caa42642976785d82f0313733b324c9cd63b1166..f7db5c83721b4f51fb6cd08b7f14e6a277a4387b 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -39,6 +39,8 @@ LABEL_MAPPING = {
     "device": "Device",
 }
 
+DPI = 300
+
 
 def scale_parallel_time(run_df, scale_factor=1.0):
     """
@@ -401,7 +403,7 @@ def plot_remaining_devices(devices_per_round, save_path=None):
         save_path: (str) the path to save the plot to
     """
 
-    plt.figure()
+    plt.figure(dpi=DPI)
     num_rounds = [0]
     max_devices = []
     for (strategy, job), (rounds, num_devices) in devices_per_round.items():
@@ -492,7 +494,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
         accuracies_per_round: (dict) the accuracy (list(float)) per round (list(int)) for each group
         save_path: (str) the path to save the plot to
     """
-    plt.figure()
+    plt.figure(dpi=DPI)
     num_rounds = [0]
     for (strategy, job), (rounds, accs) in accuracies_per_round.items():
         plt.plot(rounds, accs, label=f"{STRATEGY_MAPPING[strategy]}")
@@ -517,7 +519,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"
         accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group
         save_path: (str) the path to save the plot to
     """
-    plt.figure()
+    plt.figure(dpi=DPI)
     for (strategy, job), (time, accs) in accuracies_per_time.items():
         plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}")
     plt.xlabel(LABEL_MAPPING["runtime"])
@@ -688,7 +690,7 @@ def plot_batteries_over_time(
         aggregated: (bool) whether the battery is aggregated or not
     """
     if aggregated:
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         for (strategy, job), series in batteries_over_time.items():
             runtime = max_runtimes[(strategy, job)]
@@ -709,7 +711,7 @@ def plot_batteries_over_time(
             plt.close()
     else:
         for (strategy, job), series_dict in batteries_over_time.items():
-            plt.figure()
+            plt.figure(dpi=DPI)
             plt.rcParams.update({"font.size": 13})
             for device_id, series in series_dict.items():
                 runtime = max_runtimes[(strategy, job)]
@@ -739,7 +741,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
         aggregated: (bool) whether the battery is aggregated or not
     """
     if aggregated:
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         num_rounds = [0]
         for (strategy, job), series in batteries_over_epoch.items():
@@ -760,7 +762,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
             plt.close()
     else:
         for (strategy, job), series_dict in batteries_over_epoch.items():
-            plt.figure()
+            plt.figure(dpi=DPI)
             plt.rcParams.update({"font.size": 13})
             num_rounds = [0]
             for device_id, series in series_dict.items():
@@ -899,7 +901,7 @@ def plot_batteries_over_time_with_activity(
             0,
             train_time_end(server_train_times, client_train_times) * 1.05,
         )  # set end 5% after last activity timestamp
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
@@ -976,7 +978,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale(
             ],
             [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))],
         )
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
@@ -1048,7 +1050,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale(
         ).sort_values("start")
 
         # battery plot
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         plt.subplot(2, 1, 1)
         num_rounds = []