diff --git a/edml/core/client.py b/edml/core/client.py
index d7ff6723297bc45a9f2b5fd3e3bb002bb3de0111..fe41e2cdb3c0be6847b37df8ae3c7cae8c4eba7d 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -135,7 +135,9 @@ class DeviceClient:
         return len(self._train_data.dataset)
 
     @check_device_set()
-    def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
+    def train_single_batch(
+        self, batch_index: int
+    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
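+        """Runs the client-side forward pass for one batch and returns the smashed data together
+        with the labels, or None if the trailing batch is smaller than the configured batch size."""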
         torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
@@ -180,11 +182,7 @@ class DeviceClient:
                 "start_time": start_time,
                 "end_time": end_time,
             }
-
-            return SLTrainBatchResult(
-                smashed_data=smashed_data,
-                labels=batch_labels,
-            )
+            return smashed_data, batch_labels
 
     @check_device_set()
     def backward_single_batch(
@@ -378,7 +376,7 @@ class DeviceClient:
         )
         return diagnostic_metric_results
 
-    def set_gradient_and_finalize_training(self, gradients: Any):        
+    def set_gradient_and_finalize_training(self, gradients: Any):
         for param, grad in zip(self._model.parameters(), gradients):
             param.grad = grad.to(self._device)
 
diff --git a/edml/core/device.py b/edml/core/device.py
index 2ba6833226b537118f8c02399a0c6c4a4967e47f..51278532bfb7f1888f13671a24a289fc10e868b2 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -266,10 +266,8 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "client_only_batch_train")
     def train_batch_on_client_only(self, batch_index: int):
-        result = self.client.train_single_batch(batch_index=batch_index)
-        if result is None:
-            return result
-        return result.smashed_data, result.labels
+        # train_single_batch() returns None when the trailing batch is smaller than the batch size.
+        return self.client.train_single_batch(batch_index=batch_index)
 
     @update_battery
     def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
@@ -579,14 +577,10 @@ class RPCDeviceServicer(DeviceServicer):
         batch_index = request.batch_index
         print(f"Starting single batch@{batch_index}")
 
-        result = self.device.client.train_single_batch(batch_index)
+        result = self.device.client.train_single_batch(batch_index)
 
-        # If the last batch's size was smaller than the configured batch size, the train call returns None
-        if result is None:
-            return connection_pb2.SingleBatchTrainingResponse()
-
-        smashed_data = Activations(activations=tensor_to_proto(result.smashed_data))
-        labels = Labels(labels=tensor_to_proto(result.labels))
+        # If the last batch's size was smaller than the configured batch size, the train call returns None.
+        if result is None:
+            return connection_pb2.SingleBatchTrainingResponse()
+        smashed_data, labels = result
+        smashed_data = Activations(activations=tensor_to_proto(smashed_data))
+        labels = Labels(labels=tensor_to_proto(labels))
         return connection_pb2.SingleBatchTrainingResponse(
             smashed_data=smashed_data,
             labels=labels,
@@ -597,14 +591,16 @@ class RPCDeviceServicer(DeviceServicer):
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
 
-        metrics, gradients = self.device.client.backward_single_batch(gradients=gradients)
+        metrics, gradients = self.device.client.backward_single_batch(
+            gradients=gradients
+        )
         return connection_pb2.SingleBatchBackwardResponse(
             metrics=metrics_to_proto(metrics),
             gradients=Gradients(gradients=tensor_to_proto(gradients)),
         )
-    
+
     def SetGradientsAndFinalizeTrainingStep(
-            self, request: SetGradientsRequest, context
+        self, request: SetGradientsRequest, context
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
         self.device.client.set_gradient_and_finalize_training(gradients=gradients)
diff --git a/edml/core/server.py b/edml/core/server.py
index b78793678f6f77bc84354c93f9d32064eee657c1..8853bd92bc3d5dc438e7ada06cb4dce5e62967c5 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -260,19 +260,15 @@ class DeviceServer:
                 for client_id in clients
             ]
             for future in concurrent.futures.as_completed(futures):
-                client_forward_pass_responses.append(future.result())
+                (client_id, result) = future.result()
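+                # Drop results that are None or False so only usable (client_id, batch) pairs are kept.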
+                if result is not None and result is not False:
+                    client_forward_pass_responses.append((client_id, result))
 
             # We want to split up the responses into a list of client IDs and batches again.
             client_ids = [b[0] for b in client_forward_pass_responses]
             client_batches = [b[1] for b in client_forward_pass_responses]
 
-            if _empty_batches(client_batches):
-                # Only the last batch anyway.
-                break
-
             print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
-            # batches2 = [b for b in batches if b is not None]
-            # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
             server_batch = _concat_smashed_data(
                 [b[0].to(self._device) for b in client_batches]
             )
@@ -307,10 +303,6 @@ class DeviceServer:
             )
 
             client_gradients = torch.chunk(server_gradients, num_client_gradients)
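+            # torch.chunk() splits the server-side gradients back into one slice per client, in the
+            # same order in which the client batches were concatenated above.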
-            # print(f"::: result shape: {client_gradients[1].shape}")
-            # concatenated_client_gradients = torch.stack(client_gradients, dim=0)
-            # mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
-            # print(f"::: -> {mean_tensor.shape}")
 
             futures = [
                 executor.submit(
@@ -323,13 +315,15 @@ class DeviceServer:
                 client_backpropagation_results.append(future.result())
 
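+            # Again filter out None/False entries so the averaging only sees valid gradient tensors.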
             client_backpropagation_gradients = [
-                result[1] for result in client_backpropagation_results
+                result[1]
+                for result in client_backpropagation_results
+                if result is not None and result is not False
             ]
 
             # We want to average the client's backpropagation gradients and send them over again to finalize the
             # current training step.
             averaged_gradient = _calculate_gradient_mean(
-                client_backpropagation_gradients
+                client_backpropagation_gradients, self._device
             )
             futures = [
                 executor.submit(
@@ -383,9 +377,11 @@ class DeviceServer:
         )
 
 
-def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") -> Variable:
+def _calculate_gradient_mean(
+    gradients: List[Variable], device: str = "cpu"
+) -> Variable:
     num_devices = len(gradients)
-    weights = [1] * num_devices
+    weights = [1 / num_devices] * num_devices
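+    # Equal weights of 1/n make the weighted combination an arithmetic mean; for two clients this
+    # is 0.5 * g0 + 0.5 * g1 = (g0 + g1) / 2.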
 
     # We need to move all tensors to the same device to do calculations.
     for i, client_gradients in enumerate(gradients):
@@ -401,8 +397,3 @@ def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") ->
 def _concat_smashed_data(data: List[Any]) -> Any:
     """Creates a single batch tensor from a list of tensors."""
     return torch.cat(data, dim=0)
-
-
-def _empty_batches(batches):
-    """Checks if all the list entries are `None`."""
-    return batches.count(None) == len(batches)
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1331f8aa7bad195eb02044bcf25d381fe2ee2a13
--- /dev/null
+++ b/edml/tests/core/server_test.py
@@ -0,0 +1,192 @@
+import unittest
+from collections import OrderedDict
+from copy import copy
+from typing import Tuple, Any
+from unittest.mock import Mock
+
+import torch.utils.data
+from omegaconf import DictConfig
+from torch import nn, tensor
+from torch.utils.data import DataLoader, Dataset, TensorDataset
+
+from edml.core.battery import Battery
+from edml.core.client import DeviceClient
+from edml.core.device import Device
+from edml.core.server import DeviceServer
+
+
+class ClientModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self.output = nn.ReLU()
+
+    def forward(self, x):
+        return self.output(self.layer(x))
+
+
+class ServerModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 2)
+        self.output = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        return self.output(self.layer(x))
+
+
+class ToyDataset(Dataset):
+    def __init__(self, data: list, labels: list):
+        self.length = len(data)
+        self.data = torch.Tensor(data)
+        self.labels = torch.tensor(labels, dtype=torch.long)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        return self.data[index], self.labels[index]
+
+    def __len__(self):
+        return self.length
+
+
+class PSLTest(unittest.TestCase):
+    def setUp(self):
+        cfg = DictConfig(
+            {
+                "optimizer": {
+                    "_target_": "torch.optim.SGD",
+                    "lr": 1,
+                    "momentum": 0,
+                    "weight_decay": 0,
+                },
+                "experiment": {"metrics": ["accuracy"]},
+                "dataset": {"num_classes": 2, "average_setting": "micro"},
+                "topology": {
+                    "devices": [
+                        {
+                            "device_id": "d0",
+                            "address": "localhost:50051",
+                            "torch_device": "cuda:0",
+                        },
+                        {
+                            "device_id": "d1",
+                            "address": "localhost:50052",
+                            "torch_device": "cuda:0",
+                        },
+                    ]
+                },
+                "own_device_id": "d0",
+            }
+        )
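+        # With lr=1 and zero momentum/weight decay, an SGD step subtracts exactly the raw gradient,
+        # which keeps the expected parameter updates easy to verify by hand.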
+        # init models with fixed weights for repeatability
+        server_state_dict = OrderedDict(
+            [
+                ("layer.weight", tensor([[-0.5], [-1.0]])),
+                ("layer.bias", tensor([-0.5, 0.25])),
+            ]
+        )
+        client_state_dict = OrderedDict(
+            [("layer.weight", tensor([[-1.0]])), ("layer.bias", tensor([0.5]))]
+        )
+        server_model = ServerModel()
+        server_model.load_state_dict(server_state_dict)
+        client_model1 = ClientModel()
+        client_model1.load_state_dict(client_state_dict)
+        client_model2 = ClientModel()
+        client_model2.load_state_dict(client_state_dict)
+        self.server = DeviceServer(
+            model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
+        )
+        self.client1 = DeviceClient(
+            model=client_model1,
+            cfg=cfg,
+            train_dl=DataLoader(TensorDataset(tensor([[0.9]]), tensor([[0.0, 1.0]]))),
+            val_dl=DataLoader(TensorDataset(tensor([[0.75]]), tensor([[0.0, 1.0]]))),
+            test_dl=None,
+        )
+        client2_cfg = cfg.copy()
+        client2_cfg["own_device_id"] = "d1"
+        self.client2 = DeviceClient(
+            model=client_model2,
+            cfg=client2_cfg,
+            train_dl=DataLoader(TensorDataset(tensor([[0.1]]), tensor([[1.0, 0.0]]))),
+            val_dl=DataLoader(TensorDataset(tensor([[0.25]]), tensor([[1.0, 0.0]]))),
+            test_dl=None,
+        )
+
+        def get_client_side_effect(fn):
+            """
+            Creates side effects for methods of the form METHOD_on(device_id) skipping the request dispatcher.
+            """
+
+            def side_effect(*args, **kwargs):
+                # get the device id which is either a positional or keyword arg
+                if len(args) > 0:
+                    device_id, args = args[0], args[1:]
+                elif "client_id" in kwargs:
+                    device_id = kwargs.pop("client_id")
+                elif "device_id" in kwargs:
+                    device_id = kwargs.pop("device_id")
+                else:
+                    raise KeyError(
+                        f"Could not find device_id in args or kwargs for function {fn}"
+                    )
+                # delegate to correct client then using the given method name
+                if device_id == "d0":
+                    return self.client1.__class__.__dict__[fn](
+                        self.client1, *args, **kwargs
+                    )
+                elif device_id == "d1":
+                    return self.client2.__class__.__dict__[fn](
+                        self.client2, *args, **kwargs
+                    )
+                else:
+                    raise KeyError(f"Unknown device_id {device_id}")
+
+            return side_effect
+
+        def get_server_side_effect(fn):
+            def side_effect(*args, **kwargs):
+                return self.server.__class__.__dict__[fn](
+                    self.server, *args, **kwargs
+                ) + (
+                    1,
+                )  # Add (1,) as placeholder for DiagnosticMetricsContainer
+
+            return side_effect
+
+        node_device = Mock(Device)
+        node_device.battery = Mock(Battery)
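+        # Route the *_on(device_id, ...) calls directly to the in-process clients/server so that a
+        # full parallel split learning round runs without the gRPC request dispatcher.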
+        node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect(
+            "train_single_batch"
+        )
+        node_device.backpropagation_on_client_only_on.side_effect = (
+            get_client_side_effect("backward_single_batch")
+        )
+        node_device.set_gradient_and_finalize_training_on_client_only_on.side_effect = (
+            get_client_side_effect("set_gradient_and_finalize_training")
+        )
+        node_device.train_batch.side_effect = get_server_side_effect("train_batch")
+        node_device.evaluate_on.side_effect = get_client_side_effect("evaluate")
+        node_device.evaluate_batch.side_effect = get_server_side_effect(
+            "evaluate_batch"
+        )
+
+        self.node_device1 = copy(node_device)
+        self.node_device2 = copy(node_device)
+        self.node_device1.client = self.client1
+        self.node_device1.device_id = "d0"
+        self.node_device2.device_id = "d1"
+        self.client1.set_device(self.node_device1)
+        self.client2.set_device(self.node_device2)
+        self.server.set_device(self.node_device1)
+
+    def test_train_parallel_sl(self):
+        (
+            client_weights,
+            server_weights,
+            model_metrics,
+            optimizer_state,
+            diagnostic_metrics,
+        ) = self.server.train_parallel_split_learning(["d0", "d1"], round_no=0)
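+        # Both clients start from identical weights and receive the same averaged gradient, so
+        # their parameters must still match after the round.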
+        self.assertDictEqual(self.client1.get_weights(), self.client2.get_weights())