diff --git a/edml/core/client.py b/edml/core/client.py
index fe41e2cdb3c0be6847b37df8ae3c7cae8c4eba7d..b3c3e7491e718cf98c9ccde598e00fd4d2054996 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -220,7 +220,7 @@ class DeviceClient:
 
         gradients = []
         for param in self._model.parameters():
-            if param is not None:
+            if param.grad is not None:
                 gradients.append(param.grad)
             else:
                 gradients.append(torch.zeros_like(param))
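For context, a minimal standalone sketch (not taken from the repo) of why the guard has to be on `param.grad`: a parameter that never takes part in the backward pass keeps `grad == None`, so the old `if param is not None` check always passed and could append `None` instead of a zero tensor.

```python
import torch
from torch import nn


class TwoHeads(nn.Module):
    """Toy module: the second layer is never used in forward(), so its .grad stays None."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # not part of the graph -> receives no gradient

    def forward(self, x):
        return self.used(x)


model = TwoHeads()
model(torch.randn(1, 4)).sum().backward()

gradients = []
for param in model.parameters():
    if param.grad is not None:  # the corrected check
        gradients.append(param.grad)
    else:
        gradients.append(torch.zeros_like(param))  # zero placeholder keeps the list aligned

print([g.shape for g in gradients])  # every parameter is represented, no None entries
```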
diff --git a/edml/core/device.py b/edml/core/device.py
index 51278532bfb7f1888f13671a24a289fc10e868b2..1348bea7e4a4fadb44a6bd304bfc466f44bbba41 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -575,7 +575,6 @@ class RPCDeviceServicer(DeviceServicer):
 
     def TrainSingleBatchOnClient(self, request, context):
         batch_index = request.batch_index
-        print(f"Starting single batch@{batch_index}")
 
         smashed_data, labels = self.device.client.train_single_batch(batch_index)
 
diff --git a/edml/core/server.py b/edml/core/server.py
index 8853bd92bc3d5dc438e7ada06cb4dce5e62967c5..5e2f8235de476c5a998f9c8c040b462ac9d3b198 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
 import concurrent.futures
-import time
 from typing import List, Optional, Tuple, Any, TYPE_CHECKING
 
 import torch
 from omegaconf import DictConfig
-from colorama import init, Fore
+from colorama import Fore
 from torch import nn
 from torch.autograd import Variable
 
@@ -20,7 +19,7 @@ from edml.helpers.metrics import (
     ModelMetricResultContainer,
     DiagnosticMetricResultContainer,
 )
-from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn
+from edml.helpers.types import StateDict, LossFn
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -251,10 +250,9 @@ class DeviceServer:
 
         # We iterate over each batch, initializing all client training at once and processing the results afterward.
         num_batches = self.node_device.client.get_approximated_num_batches()
-        print(f"\n\n:: BATCHES :: {num_batches}\n\n")
+        print(f":: BATCHES :: {num_batches}")
         for batch_index in range(num_batches):
             client_forward_pass_responses = []
-
             futures = [
                 executor.submit(client_training_job, client_id, batch_index)
                 for client_id in clients
@@ -268,20 +266,17 @@ class DeviceServer:
             client_ids = [b[0] for b in client_forward_pass_responses]
             client_batches = [b[1] for b in client_forward_pass_responses]
 
-            print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
             server_batch = _concat_smashed_data(
                 [b[0].to(self._device) for b in client_batches]
             )
             server_labels = _concat_smashed_data(
                 [b[1].to(self._device) for b in client_batches]
             )
-
             # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
             # to split the gradients back into batch-sized tensors to average them before sending them to the client.
             server_gradients, server_loss, server_metrics = (
                 self.node_device.train_batch(server_batch, server_labels)
             )  # DiagnosticMetricResultContainer
-
             # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
             # do the client propagation once the current loss value is larger than the threshold.
             print(
@@ -301,25 +296,23 @@ class DeviceServer:
             print(
                 f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
             )
-
-            client_gradients = torch.chunk(server_gradients, num_client_gradients)
-
+            # Clone each chunk: torch.chunk returns views of server_gradients, and serializing
+            # a view would send the whole server_gradients storage to every client.
+            client_gradients = [
+                t.clone().detach()
+                for t in torch.chunk(server_gradients, num_client_gradients)
+            ]
             futures = [
                 executor.submit(
                     client_backpropagation_job, client_id, client_gradients[idx]
                 )
                 for (idx, client_id) in enumerate(client_ids)
             ]
-            client_backpropagation_results = []
+            client_backpropagation_gradients = []
             for future in concurrent.futures.as_completed(futures):
-                client_backpropagation_results.append(future.result())
-
-            client_backpropagation_gradients = [
-                result[1]
-                for result in client_backpropagation_results
-                if result is not None and result is not False
-            ]
-
+                result = future.result()
+                if result is not None and result is not False:
+                    client_backpropagation_gradients.append(result[1])
             # We want to average the client's backpropagation gradients and send them over again to finalize the
             # current training step.
             averaged_gradient = _calculate_gradient_mean(
@@ -340,7 +333,6 @@ class DeviceServer:
         for client_id in clients:
             train_metrics = self.finalize_metrics(str(client_id), "train")
 
-            print(f"::: evaluating on {client_id}")
             evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                 device_id=client_id,
                 server_device=self.node_device.device_id,
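A minimal standalone sketch (tensor shapes and names are illustrative, not from the repo) of why the chunks are cloned before being sent back: `torch.chunk` returns views that share storage with `server_gradients`, and serializing a view keeps the full backing storage, so every client would otherwise receive the whole tensor.

```python
import io

import torch

server_gradients = torch.randn(8, 32)
chunks = torch.chunk(server_gradients, 2)

# The chunks are views into server_gradients' storage, not copies.
assert chunks[0].data_ptr() == server_gradients.data_ptr()


def serialized_size(t: torch.Tensor) -> int:
    buf = io.BytesIO()
    torch.save(t, buf)
    return buf.getbuffer().nbytes


# Serializing a view saves the whole backing storage; cloning first ships only the slice.
print(serialized_size(chunks[0]))                   # roughly the size of the full (8, 32) tensor
print(serialized_size(chunks[0].clone().detach()))  # roughly half of that
```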
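And a small sketch of the reworked gather loop (the job function here is a placeholder that mimics a failing client, not the real RPC call): guarding on the whole result before indexing keeps the old behaviour of skipping clients that returned `None` or `False`.

```python
import concurrent.futures


def client_backpropagation_job(client_id, gradients):
    """Stand-in for the real job: returns None for a client that dropped out."""
    if client_id == "offline-client":
        return None
    return client_id, f"gradients-from-{client_id}"


clients = ["client-a", "offline-client", "client-b"]
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(client_backpropagation_job, c, None) for c in clients]

    client_backpropagation_gradients = []
    for future in concurrent.futures.as_completed(futures):
        result = future.result()
        if result is not None and result is not False:  # skip failed clients
            client_backpropagation_gradients.append(result[1])

print(client_backpropagation_gradients)  # results from the healthy clients, in completion order
```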