diff --git a/edml/controllers/base_controller.py b/edml/controllers/base_controller.py
index a2adca85736844b23f40a7c8c5c1a5c4574ef280..76e1a6d61a6f8258b70e12804b52b57d45a9e58a 100644
--- a/edml/controllers/base_controller.py
+++ b/edml/controllers/base_controller.py
@@ -5,7 +5,7 @@ from typing import Optional
 import torch
 
 from edml.controllers.early_stopping import create_early_stopping_callback
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.core.start_device import _get_models
 from edml.helpers.logging import SimpleLogger, create_logger
 from edml.helpers.metrics import ModelMetricResultContainer
diff --git a/edml/core/device.py b/edml/core/device.py
index aea1ae95de2735dc4c984a13d49dadd632d66432..873a39b08ac386f807c0448d13824e282f4ff2e5 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -2,76 +2,25 @@ from __future__ import annotations
 
 import threading
 from abc import ABC, abstractmethod
-from typing import Optional, Dict, Any, List, Union, Tuple, cast
+from typing import Optional, Any, List, Tuple, cast
 
-import grpc
-from google.protobuf.message import Message
 from omegaconf import DictConfig
-from torch import Tensor
 from torch.autograd import Variable
 
 from edml.core.battery import Battery
 from edml.core.client import DeviceClient
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.core.server import DeviceServer
-from edml.generated import connection_pb2
-from edml.generated.connection_pb2 import (
-    SetGradientsRequest,
-    SetWeightsRequest,
-    TrainBatchRequest,
-    TrainGlobalResponse,
-    TrainEpochResponse,
-    TrainBatchResponse,
-    EvalGlobalResponse,
-    EvalResponse,
-    EvalBatchResponse,
-    FullModelTrainResponse,
-    BatteryStatusResponse,
-    DatasetModelInfoResponse,
-    EndExperimentResponse,
-    StartExperimentResponse,
-    SingleBatchTrainingResponse,
-    SingleBatchBackwardRequest,
-    TrainGlobalParallelSplitLearningResponse,
-    SingleBatchBackwardResponse,
-)
-from edml.generated.connection_pb2_grpc import DeviceServicer, DeviceStub
-from edml.generated.datastructures_pb2 import (
-    Gradients,
-    Weights,
-    DeviceInfo,
-    Activations,
-    Labels,
-    BatteryStatus,
-    Empty,
-    Metrics,
-)
 from edml.helpers.decorators import (
     log_execution_time,
     update_battery,
     add_time_to_diagnostic_metrics,
 )
-from edml.helpers.interceptors import DeviceClientInterceptor
 from edml.helpers.logging import SimpleLogger
 from edml.helpers.metrics import (
     ModelMetricResultContainer,
     DiagnosticMetricResultContainer,
 )
-from edml.helpers.proto_helpers import (
-    proto_to_tensor,
-    tensor_to_proto,
-    state_dict_to_proto,
-    proto_to_state_dict,
-    proto_to_weights,
-    metrics_to_proto,
-    proto_to_metrics,
-    _proto_size_per_field,
-)
-from edml.helpers.types import (
-    HasMetrics,
-    StateDict,
-    DeviceBatteryStatus,
-    DeviceBatteryStatusReport,
-)
 
 
 class Device(ABC):
@@ -224,6 +173,38 @@ class Device(ABC):
 
 
 class NetworkDevice(Device):
+    """
+    Main implementation of the Device class. Provides actions related to ML tasks which are delegated accordingly.
+    There are two hooks for each action, e.g. set_weights and set_weights_on. Calling set_weights lets the device
+    set the provided weights on itself, while set_weights_on delegates the weights to a certain receiver. This way,
+    the "client" or "server" part of split learning can request certain actions without handling the communication details.
+    Outgoing communication with other devices is delegated to the DeviceRequestDispatcher and incoming requests are handled
+    by the RPCDeviceServicer calling appropriate actions on the device.
+    To set up the communication, set_devices must be called with the other devices and connection information.
+    Other responsibilities of this class are logging and updating the battery.
+
+    Attributes:
+        device_id (str): This device's id.
+        logger (SimpleLogger): The logger instance the device can use.
+        battery (Battery): The device's battery. Certain functions consume energy and drain the battery.
+        client (DeviceClient): The client part of this device. Initialized later by explicitly calling
+            :py:meth:`set_client`.
+        server (DeviceServer): The server part of this device. Initialized later by explicitly calling
+            :py:meth:`set_server`.
+    """
+
+    def __init__(
+        self,
+        device_id: str,
+        logger: SimpleLogger,
+        battery: Battery,
+        stop_event: Optional[threading.Event] = None,
+    ):
+        self.devices: List[DictConfig[str, Any]] = []
+        self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id)
+        self.stop_event = stop_event
+        super().__init__(device_id, logger, battery)
+
     @update_battery
     @log_execution_time("logger", "finalize_gradients")
     def set_gradient_and_finalize_training_on_client_only_on(
@@ -287,18 +268,6 @@ class NetworkDevice(Device):
                 device_id=device_id, batch_index=batch_index, round_no=round_no
             )
 
-    def __init__(
-        self,
-        device_id: str,
-        logger: SimpleLogger,
-        battery: Battery,
-        stop_event: Optional[threading.Event] = None,
-    ):
-        self.devices: List[DictConfig[str, Any]] = []
-        self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id)
-        self.stop_event = stop_event
-        super().__init__(device_id, logger, battery)
-
     @add_time_to_diagnostic_metrics("train_global")
     @update_battery
     @log_execution_time("logger", "train_global_time")
@@ -444,625 +413,3 @@ class NetworkDevice(Device):
     def _log_current_battery_capacity(self):
         """Wrapper for logging the current battery capacity"""
         self.logger.log({"battery": self.battery.remaining_capacity()})
-
-
-class RPCDeviceServicer(DeviceServicer):
-    def __init__(self, device: NetworkDevice):
-        self.device = device
-
-    def TrainGlobal(self, request, context):
-        print(f"Called TrainGlobal on device {self.device.device_id}")
-        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.device.train_global(
-                request.epochs,
-                request.round_no,
-                request.adaptive_threshold_value,
-                proto_to_state_dict(request.optimizer_state),
-            )
-        )
-        response = connection_pb2.TrainGlobalResponse(
-            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
-            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
-            metrics=metrics_to_proto(metrics),
-            optimizer_state=state_dict_to_proto(optimizer_state),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-        return response
-
-    def SetWeights(self, request, context):
-        print(f"Called SetWeights on device {self.device.device_id}")
-        weights = proto_to_state_dict(request.weights.weights)
-        self.device.set_weights(weights, request.on_client)
-        return connection_pb2.SetWeightsResponse()
-
-    def TrainEpoch(self, request, context):
-        print(f"Called TrainEpoch on device {self.device.device_id}")
-        device_info = request.server
-        device_id = device_info.device_id
-        round_no = request.round_no
-        weights, diagnostic_metrics = self.device.train_epoch(device_id, round_no)
-        proto_weights = state_dict_to_proto(weights)
-        return connection_pb2.TrainEpochResponse(
-            weights=Weights(weights=proto_weights),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-
-    def TrainBatch(self, request, context):
-        activations = proto_to_tensor(request.smashed_data.activations)
-        labels = proto_to_tensor(request.labels.labels)
-        gradients, loss, diagnostic_metrics = self.device.train_batch(
-            activations, labels
-        )
-        proto_gradients = Gradients(gradients=tensor_to_proto(gradients))
-        return connection_pb2.TrainBatchResponse(
-            gradients=proto_gradients,
-            loss=loss,
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-
-    def EvaluateGlobal(self, request, context):
-        print(f"Called EvaluateGlobal on device {self.device.device_id}")
-        metrics, diagnostic_metrics = self.device.evaluate_global(
-            request.validation, request.federated
-        )
-        return connection_pb2.EvalGlobalResponse(
-            metrics=metrics_to_proto(metrics),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-
-    def Evaluate(self, request, context):
-        print(f"Called Evaluate on device {self.device.device_id}")
-        diagnostic_metrics = self.device.evaluate(
-            request.server.device_id, request.validation
-        )
-        return connection_pb2.EvalResponse(
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
-        )
-
-    def EvaluateBatch(self, request, context):
-        activations = proto_to_tensor(request.smashed_data.activations)
-        labels = proto_to_tensor(request.labels.labels)
-        diagnostic_metrics = self.device.evaluate_batch(activations, labels)
-        return connection_pb2.EvalBatchResponse(
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
-        )
-
-    def FullModelTraining(self, request, context):
-        print(f"Called Full Training on device {self.device.device_id}")
-        client_weights, server_weights, num_samples, metrics, diagnostic_metrics = (
-            self.device.federated_train(request.round_no)
-        )
-        return connection_pb2.FullModelTrainResponse(
-            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
-            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
-            num_samples=num_samples,
-            metrics=metrics_to_proto(metrics),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-
-    def StartExperiment(self, request, context) -> StartExperimentResponse:
-        print(f"Start Experiment on device {self.device.device_id}")
-        self.device.start_experiment()
-        return connection_pb2.StartExperimentResponse()
-
-    def EndExperiment(self, request, context) -> EndExperimentResponse:
-        print(f"End Experiment on device {self.device.device_id}")
-        print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}")
-        self.device.end_experiment()
-        return connection_pb2.EndExperimentResponse()
-
-    def GetBatteryStatus(self, request, context):
-        print(f"Get Battery Status on device {self.device.device_id}")
-        initial_capacity, remaining_capacity = self.device.get_battery_status()
-
-        return connection_pb2.BatteryStatusResponse(
-            status=BatteryStatus(
-                initial_battery_level=initial_capacity,
-                current_battery_level=remaining_capacity,
-            )
-        )
-
-    def GetDatasetModelInfo(self, request, context):
-        print(f"Get Dataset and Model Info on device {self.device.device_id}")
-        return connection_pb2.DatasetModelInfoResponse(
-            train_samples=len(self.device.client._train_data.dataset),
-            validation_samples=len(self.device.client._val_data.dataset),
-            client_fw_flops=int(self.device.client._model_flops["FW"]),
-            server_fw_flops=int(self.device.server._model_flops["FW"]),
-            client_bw_flops=int(self.device.client._model_flops["BW"]),
-            server_bw_flops=int(self.device.server._model_flops["BW"]),
-        )
-
-    def TrainGlobalParallelSplitLearning(self, request, context):
-        print(f"Starting parallel split learning")
-        clients = self.device.__get_device_ids__()
-        round_no = request.round_no
-        adaptive_threshold_value = request.adaptive_threshold_value
-        optimizer_state = proto_to_state_dict(request.optimizer_state)
-
-        cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
-            self.device.train_parallel_split_learning(
-                clients=clients,
-                round_no=round_no,
-                adaptive_threshold_value=adaptive_threshold_value,
-                optimizer_state=optimizer_state,
-            )
-        )
-        response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
-            client_weights=Weights(weights=state_dict_to_proto(cw)),
-            server_weights=Weights(weights=state_dict_to_proto(sw)),
-            metrics=metrics_to_proto(model_metrics),
-            optimizer_state=state_dict_to_proto(optimizer_state),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
-        )
-        return response
-
-    def TrainSingleBatchOnClient(self, request, context):
-        batch_index = request.batch_index
-        round_no = request.round_no
-
-        smashed_data, labels = self.device.client.train_single_batch(
-            batch_index, round_no=round_no
-        )
-
-        smashed_data = Activations(activations=tensor_to_proto(smashed_data))
-        labels = Labels(labels=tensor_to_proto(labels))
-        return connection_pb2.SingleBatchTrainingResponse(
-            smashed_data=smashed_data,
-            labels=labels,
-        )
-
-    def BackwardPropagationSingleBatchOnClient(
-        self, request: SingleBatchBackwardRequest, context
-    ):
-        gradients = proto_to_tensor(request.gradients.gradients)
-
-        metrics, gradients = self.device.client.backward_single_batch(
-            gradients=gradients
-        )
-        return connection_pb2.SingleBatchBackwardResponse(
-            metrics=metrics_to_proto(metrics),
-            gradients=Gradients(gradients=tensor_to_proto(gradients)),
-        )
-
-    def SetGradientsAndFinalizeTrainingStep(
-        self, request: SetGradientsRequest, context
-    ):
-        gradients = proto_to_tensor(request.gradients.gradients)
-        self.device.client.set_gradient_and_finalize_training(gradients=gradients)
-        return Empty()
-
-
-class DeviceRequestDispatcher:
-
-    def __init__(
-        self,
-        devices: List[DictConfig[str, Any]],
-        logger: Optional[SimpleLogger] = None,
-        battery: Optional[Battery] = None,
-        stop_event: Optional[threading.Event] = None,
-        device_id: Optional[str] = None,
-    ):
-        self.devices = devices
-        self.connections: Dict[str, DeviceStub] = {}
-        # optional, interceptor only works if all three are set
-        self.logger = logger
-        self.battery = battery
-        self.stop_event = stop_event
-
-        self._establish_connections()
-        self.connections_lock = threading.Lock()
-        self.device_id = device_id  # used for diagnostic metrics to assign the source device correctly
-
-    def __get_device_address__(self, device_id: str) -> Optional[str]:
-        for device in self.devices:
-            if device.device_id == device_id:
-                return device.address
-        return None
-
-    def train_parallel_on_server(
-        self,
-        server_device_id: str,
-        epochs: int,
-        round_no: int,
-        adaptive_threshold_value: Optional[float] = None,
-        optimizer_state: dict[str, Any] = None,
-    ):
-        print(f"><><><> {adaptive_threshold_value}")
-
-        try:
-            response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
-                server_device_id
-            ).TrainGlobalParallelSplitLearning(
-                connection_pb2.TrainGlobalParallelSplitLearningRequest(
-                    round_no=round_no,
-                    adaptive_threshold_value=adaptive_threshold_value,
-                    optimizer_state=state_dict_to_proto(optimizer_state),
-                )
-            )
-            return (
-                proto_to_weights(response.client_weights),
-                proto_to_weights(response.server_weights),
-                proto_to_metrics(response.metrics),
-                proto_to_state_dict(response.optimizer_state),
-                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(server_device_id)
-        except KeyError:
-            self._handle_unknown_device_id(server_device_id)
-        return False
-
-    def _establish_connections(self):
-        for device in self.devices:
-            channel = grpc.insecure_channel(
-                device.address,
-                options=[
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                ],
-            )
-            if (
-                self.logger is not None
-                and self.battery is not None
-                and self.stop_event is not None
-            ):
-                channel = grpc.intercept_channel(
-                    channel,
-                    DeviceClientInterceptor(self.logger, self.battery, self.stop_event),
-                )
-            stub = DeviceStub(channel)
-            self.connections[device.device_id] = stub
-        print(f"active devices {self.connections.keys()}")
-
-    def _get_connection(self, device_id: str) -> DeviceStub:
-        """Returns the connection to the device with the given id.
-        If the device_id is not a valid dictionary key, the caller has to handle the KeyError.
-        """
-        with self.connections_lock:
-            return self.connections[device_id]
-
-    def active_devices(self) -> List[str]:
-        with self.connections_lock:  # lock connections to prevent remove while iterating
-            return list(self.connections.keys())
-
-    def _handle_rpc_error(self, device_id: str):
-        """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling
-        the request or is not reachable (i.e. most probably shut down).
-        Here, the connection to the device is removed using a lock to avoid race conditions.
-        """
-        with self.connections_lock:
-            del self.connections[device_id]
-
-    def _handle_unknown_device_id(self, device_id: str):
-        """Handler in case the requested device id is not in the connection dictionary.
-        Could happen if the active devices weren't refreshed properly."""
-        if device_id in [device.device_id for device in self.devices]:
-            print(f"Device {device_id} not active")
-        else:
-            print(f"Unknown Device ID {device_id}")
-
-    def _add_byte_size_to_diagnostic_metrics(
-        self, response: Message, device_id: str, request=None
-    ):
-        """
-        Adds the byte size of the response to the diagnostic metrics.
-
-        Args:
-            response: A protobuf message.
-            device_id: The id of the device the request was sent from.
-            request: A protobuf message, if the request is of interest.
-        Returns:
-            A DiagnosticMetricResultContainer including the added byte size and previous metrics from the response (and possibly the request).
-        Raises:
-            None
-        Notes:
-        """
-        if response.HasField("diagnostic_metrics"):
-            response: HasMetrics
-            diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics)
-        else:
-            diagnostic_metrics = DiagnosticMetricResultContainer()
-        diagnostic_metrics.merge(_proto_size_per_field(response, device_id))
-        if request is not None:
-            diagnostic_metrics.merge(_proto_size_per_field(request, device_id))
-        return diagnostic_metrics
-
-    def train_global_on(
-        self,
-        device_id: str,
-        epochs: int,
-        round_no: int = -1,
-        adaptive_threshold_value: Optional[float] = None,
-        optimizer_state: dict[str, Any] = None,
-    ) -> Union[
-        Tuple[
-            Dict[str, Any],
-            Dict[str, Any],
-            ModelMetricResultContainer,
-            Dict[str, Any],
-            DiagnosticMetricResultContainer,
-        ],
-        bool,
-    ]:
-        try:
-            response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal(
-                connection_pb2.TrainGlobalRequest(
-                    epochs=epochs,
-                    round_no=round_no,
-                    adaptive_threshold_value=adaptive_threshold_value,
-                    optimizer_state=state_dict_to_proto(optimizer_state),
-                )
-            )
-            return (
-                proto_to_weights(response.client_weights),
-                proto_to_weights(response.server_weights),
-                proto_to_metrics(response.metrics),
-                proto_to_state_dict(response.optimizer_state),
-                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def set_weights_on(
-        self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False
-    ):
-        try:
-            self._get_connection(device_id).SetWeights(
-                SetWeightsRequest(
-                    weights=Weights(weights=state_dict_to_proto(state_dict)),
-                    on_client=on_client,
-                ),
-                wait_for_ready=wait_for_ready,
-            )
-            return True
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1):
-        try:
-            response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch(
-                connection_pb2.TrainEpochRequest(
-                    server=DeviceInfo(
-                        device_id=server_device,
-                        address=self.__get_device_address__(server_device),
-                    ),
-                    round_no=round_no,
-                )
-            )
-            return proto_to_state_dict(
-                response.weights.weights
-            ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def train_batch_on(self, device_id: str, smashed_data, labels):
-        try:
-            request = TrainBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
-                labels=Labels(labels=tensor_to_proto(labels)),
-            )
-            response: TrainBatchResponse = self._get_connection(device_id).TrainBatch(
-                request
-            )
-            return (
-                proto_to_tensor(response.gradients.gradients),
-                response.loss,
-                self._add_byte_size_to_diagnostic_metrics(
-                    response, self.device_id, request
-                ),
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False):
-        try:
-            response: EvalGlobalResponse = self._get_connection(
-                device_id
-            ).EvaluateGlobal(
-                connection_pb2.EvalGlobalRequest(validation=val, federated=fed)
-            )
-            return proto_to_metrics(
-                response.metrics
-            ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def evaluate_on(self, device_id: str, server_device: str, val: bool):
-        try:
-            response: EvalResponse = self._get_connection(device_id).Evaluate(
-                connection_pb2.EvalRequest(
-                    server=DeviceInfo(
-                        device_id=server_device,
-                        address=self.__get_device_address__(server_device),
-                    ),
-                    validation=val,
-                )
-            )
-            return self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def evaluate_batch_on(self, device_id: str, smashed_data, labels):
-        try:
-            request = connection_pb2.EvalBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
-                labels=Labels(labels=tensor_to_proto(labels)),
-            )
-            response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch(
-                request
-            )
-            return self._add_byte_size_to_diagnostic_metrics(
-                response, self.device_id, request
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[
-        Tuple[
-            StateDict,
-            StateDict,
-            int,
-        ],
-        bool,
-    ]:
-        try:
-            response: FullModelTrainResponse = self._get_connection(
-                device_id
-            ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no))
-            return (
-                proto_to_weights(response.client_weights),
-                proto_to_weights(response.server_weights),
-                response.num_samples,
-                proto_to_metrics(response.metrics),
-                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool:
-        try:
-            self._get_connection(device_id).StartExperiment(
-                connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready
-            )
-            return True
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def end_experiment_on(self, device_id: str) -> bool:
-        try:
-            self._get_connection(device_id).EndExperiment(
-                connection_pb2.EndExperimentRequest()
-            )
-            return True
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def get_battery_status_on(self, device_id: str) -> DeviceBatteryStatusReport:
-        try:
-            response: BatteryStatusResponse = self._get_connection(
-                device_id
-            ).GetBatteryStatus(connection_pb2.BatteryStatusRequest())
-            return DeviceBatteryStatus(
-                current_capacity=response.status.current_battery_level,
-                initial_capacity=response.status.initial_battery_level,
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def train_batch_on_client_only(
-        self, device_id: str, batch_index: int, round_no: int
-    ) -> Tuple[Tensor, Tensor] | None:
-        try:
-            response: SingleBatchTrainingResponse = self._get_connection(
-                device_id
-            ).TrainSingleBatchOnClient(
-                connection_pb2.SingleBatchTrainingRequest(
-                    batch_index=batch_index, round_no=round_no
-                )
-            )
-
-            # The response can only be None if the last batch was smaller than the configured batch size.
-            if response.HasField("smashed_data"):
-                return (
-                    proto_to_tensor(response.smashed_data.activations),
-                    proto_to_tensor(response.labels.labels),
-                )
-
-            return None
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def get_dataset_model_info_on(
-        self, device_id: str
-    ) -> Union[Tuple[int, int, int, int, int, int], bool]:
-        try:
-            response: DatasetModelInfoResponse = self._get_connection(
-                device_id
-            ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest())
-            return (
-                response.train_samples,
-                response.validation_samples,
-                response.client_fw_flops,
-                response.server_fw_flops,
-                response.client_bw_flops,
-                response.server_bw_flops,
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def backpropagation_on_client_only(self, device_id, gradients):
-        try:
-            response: SingleBatchBackwardResponse = self._get_connection(
-                device_id
-            ).BackwardPropagationSingleBatchOnClient(
-                connection_pb2.SingleBatchBackwardRequest(
-                    gradients=Gradients(gradients=tensor_to_proto(gradients))
-                )
-            )
-            return (
-                None,
-                proto_to_tensor(response.gradients.gradients),
-            )
-        except grpc.RpcError:
-            self._handle_rpc_error(device_id)
-        except KeyError:
-            self._handle_unknown_device_id(device_id)
-        return False
-
-    def set_gradient_and_finalize_training_on_client_only(
-        self, client_id: str, gradients: Any
-    ):
-        try:
-            response: Empty = self._get_connection(
-                client_id
-            ).SetGradientsAndFinalizeTrainingStep(
-                connection_pb2.SetGradientsRequest(
-                    gradients=Gradients(gradients=tensor_to_proto(gradients))
-                )
-            )
-            return response
-        except grpc.RpcError:
-            self._handle_rpc_error(client_id)
-        except KeyError:
-            self._handle_unknown_device_id(client_id)
-        return False
diff --git a/edml/core/device_request_dispatcher.py b/edml/core/device_request_dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c7802ffb35d9cb4eea543938047d1e91ac48200
--- /dev/null
+++ b/edml/core/device_request_dispatcher.py
@@ -0,0 +1,468 @@
+from __future__ import annotations
+
+import threading
+from contextlib import contextmanager
+from typing import List, Any, Optional, Dict, Union, Tuple
+
+import grpc
+from google.protobuf.message import Message
+from omegaconf import DictConfig
+from torch import Tensor
+
+from edml.core.battery import Battery
+from edml.generated import connection_pb2
+from edml.generated.connection_pb2 import (
+    TrainGlobalParallelSplitLearningResponse,
+    TrainGlobalResponse,
+    SetWeightsRequest,
+    TrainEpochResponse,
+    TrainBatchRequest,
+    TrainBatchResponse,
+    EvalGlobalResponse,
+    EvalResponse,
+    EvalBatchResponse,
+    FullModelTrainResponse,
+    BatteryStatusResponse,
+    SingleBatchTrainingResponse,
+    DatasetModelInfoResponse,
+    SingleBatchBackwardResponse,
+)
+from edml.generated.connection_pb2_grpc import DeviceStub
+from edml.generated.datastructures_pb2 import (
+    Weights,
+    DeviceInfo,
+    Activations,
+    Labels,
+    Gradients,
+    Empty,
+)
+from edml.helpers.interceptors import DeviceClientInterceptor
+from edml.helpers.logging import SimpleLogger
+from edml.helpers.metrics import (
+    DiagnosticMetricResultContainer,
+    ModelMetricResultContainer,
+)
+from edml.helpers.proto_helpers import (
+    proto_to_metrics,
+    _proto_size_per_field,
+    state_dict_to_proto,
+    proto_to_weights,
+    proto_to_state_dict,
+    tensor_to_proto,
+    proto_to_tensor,
+)
+from edml.helpers.types import (
+    HasMetrics,
+    StateDict,
+    DeviceBatteryStatusReport,
+    DeviceBatteryStatus,
+)
+
+
+class DeviceRequestDispatcher:
+    """
+    Implements functionality to make gRPC calls to call actions on other devices. Serializes python objects to protobufs
+    for sending. Maintains connection information to the other devices.
+    If the device's battery is empty, it triggers a stop_event to terminate the device's process.
+    This way, it is assured that a device cannot call other devices while having no (simulated) energy left.
+    To account for the energy, an interceptor is defined to measure the communication overhead.
+
+    Attributes:
+        device_id (str): This device's id.
+        logger (SimpleLogger): The logger instance the device can use.
+        battery (Battery): The device's battery. Certain functions consume energy and drain the battery. Here it is
+        updated to account for the communication overhead.
+        stop_event (threading.Event): Event to be set when the device is out of battery
+        devices (List[DictConfig[str, Any]]): Connection information to the other devices.
+    """
+
+    def __init__(
+        self,
+        devices: List[DictConfig[str, Any]],
+        logger: Optional[SimpleLogger] = None,
+        battery: Optional[Battery] = None,
+        stop_event: Optional[threading.Event] = None,
+        device_id: Optional[str] = None,
+    ):
+        self.devices = devices
+        self.connections: Dict[str, DeviceStub] = {}
+        # optional, interceptor only works if all three are set
+        self.logger = logger
+        self.battery = battery
+        self.stop_event = stop_event
+
+        self._establish_connections()
+        self.connections_lock = threading.Lock()
+        self.device_id = device_id  # used for diagnostic metrics to assign the source device correctly
+
+    def __get_device_address__(self, device_id: str) -> Optional[str]:
+        for device in self.devices:
+            if device.device_id == device_id:
+                return device.address
+        return None
+
+    def _establish_connections(self):
+        for device in self.devices:
+            channel = grpc.insecure_channel(
+                device.address,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            if (
+                self.logger is not None
+                and self.battery is not None
+                and self.stop_event is not None
+            ):
+                channel = grpc.intercept_channel(
+                    channel,
+                    DeviceClientInterceptor(self.logger, self.battery, self.stop_event),
+                )
+            stub = DeviceStub(channel)
+            self.connections[device.device_id] = stub
+        print(f"active devices {self.connections.keys()}")
+
+    def _get_connection(self, device_id: str) -> DeviceStub:
+        """Returns the connection to the device with the given id.
+        If the device_id is not a valid dictionary key, the caller has to handle the KeyError.
+        """
+        with self.connections_lock:
+            return self.connections[device_id]
+
+    def active_devices(self) -> List[str]:
+        with self.connections_lock:  # lock connections to prevent remove while iterating
+            return list(self.connections.keys())
+
+    def _handle_rpc_error(self, device_id: str):
+        """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling
+        the request or is not reachable (i.e. most probably shut down).
+        The connection is removed under the lock; a connection already removed by another thread is ignored.
+        """
+        with self.connections_lock:
+            self.connections.pop(device_id, None)
+
+    def _handle_unknown_device_id(self, device_id: str):
+        """Handler in case the requested device id is not in the connection dictionary.
+        Could happen if the active devices weren't refreshed properly."""
+        if device_id in [device.device_id for device in self.devices]:
+            print(f"Device {device_id} not active")
+        else:
+            print(f"Unknown Device ID {device_id}")
+
+    def _add_byte_size_to_diagnostic_metrics(
+        self, response: Message, device_id: str, request=None
+    ):
+        """
+        Adds the byte size of the response to the diagnostic metrics.
+
+        Args:
+            response: A protobuf message.
+            device_id: The id of the device the request was sent from.
+            request: A protobuf message, if the request is of interest.
+        Returns:
+            A DiagnosticMetricResultContainer including the added byte size and previous metrics from the response (and possibly the request).
+        Raises:
+            None
+        Notes:
+        """
+        if response.HasField("diagnostic_metrics"):
+            response: HasMetrics
+            diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics)
+        else:
+            diagnostic_metrics = DiagnosticMetricResultContainer()
+        diagnostic_metrics.merge(_proto_size_per_field(response, device_id))
+        if request is not None:
+            diagnostic_metrics.merge(_proto_size_per_field(request, device_id))
+        return diagnostic_metrics
+
+    @contextmanager
+    def handle_connection(self, device_id):
+        """
+        Context manager that wraps RPCs and handles errors appropriately.
+
+        Args:
+            device_id: The id of the device the request is sent to.
+        Returns:
+            None
+        Raises:
+            None
+        Notes:
+            Successful RPCs should return within the with handle_connection() block.
+            To indicate an unsuccessful RPC, the with block should be followed by a return False statement (some linters
+            may count this as unreachable which is not true).
+        """
+        try:
+            # this returns the control to the RPC call
+            yield
+        # check for errors after RPC call
+        except grpc.RpcError:
+            self._handle_rpc_error(device_id)
+        except KeyError:
+            self._handle_unknown_device_id(device_id)
+
+    def train_parallel_on_server(
+        self,
+        server_device_id: str,
+        epochs: int,
+        round_no: int,
+        adaptive_threshold_value: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None,
+    ):
+        with self.handle_connection(server_device_id):
+            response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
+                server_device_id
+            ).TrainGlobalParallelSplitLearning(
+                connection_pb2.TrainGlobalParallelSplitLearningRequest(
+                    round_no=round_no,
+                    adaptive_threshold_value=adaptive_threshold_value,
+                    optimizer_state=state_dict_to_proto(optimizer_state),
+                )
+            )
+            return (
+                proto_to_weights(response.client_weights),
+                proto_to_weights(response.server_weights),
+                proto_to_metrics(response.metrics),
+                proto_to_state_dict(response.optimizer_state),
+                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
+            )
+        return False
+
+    def train_global_on(
+        self,
+        device_id: str,
+        epochs: int,
+        round_no: int = -1,
+        adaptive_threshold_value: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None,
+    ) -> Union[
+        Tuple[
+            Dict[str, Any],
+            Dict[str, Any],
+            ModelMetricResultContainer,
+            Dict[str, Any],
+            DiagnosticMetricResultContainer,
+        ],
+        bool,
+    ]:
+        with self.handle_connection(device_id):
+            response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal(
+                connection_pb2.TrainGlobalRequest(
+                    epochs=epochs,
+                    round_no=round_no,
+                    adaptive_threshold_value=adaptive_threshold_value,
+                    optimizer_state=state_dict_to_proto(optimizer_state),
+                )
+            )
+            return (
+                proto_to_weights(response.client_weights),
+                proto_to_weights(response.server_weights),
+                proto_to_metrics(response.metrics),
+                proto_to_state_dict(response.optimizer_state),
+                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
+            )
+        return False
+
+    def set_weights_on(
+        self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False
+    ):
+        with self.handle_connection(device_id):
+            self._get_connection(device_id).SetWeights(
+                SetWeightsRequest(
+                    weights=Weights(weights=state_dict_to_proto(state_dict)),
+                    on_client=on_client,
+                ),
+                wait_for_ready=wait_for_ready,
+            )
+            return True
+        return False
+
+    def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1):
+        with self.handle_connection(device_id):
+            response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch(
+                connection_pb2.TrainEpochRequest(
+                    server=DeviceInfo(
+                        device_id=server_device,
+                        address=self.__get_device_address__(server_device),
+                    ),
+                    round_no=round_no,
+                )
+            )
+            return proto_to_state_dict(
+                response.weights.weights
+            ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
+        return False
+
+    def train_batch_on(self, device_id: str, smashed_data, labels):
+        with self.handle_connection(device_id):
+            request = TrainBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
+                labels=Labels(labels=tensor_to_proto(labels)),
+            )
+            response: TrainBatchResponse = self._get_connection(device_id).TrainBatch(
+                request
+            )
+            return (
+                proto_to_tensor(response.gradients.gradients),
+                response.loss,
+                self._add_byte_size_to_diagnostic_metrics(
+                    response, self.device_id, request
+                ),
+            )
+        return False
+
+    def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False):
+        with self.handle_connection(device_id):
+            response: EvalGlobalResponse = self._get_connection(
+                device_id
+            ).EvaluateGlobal(
+                connection_pb2.EvalGlobalRequest(validation=val, federated=fed)
+            )
+            return proto_to_metrics(
+                response.metrics
+            ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
+        return False
+
+    def evaluate_on(self, device_id: str, server_device: str, val: bool):
+        with self.handle_connection(device_id):
+            response: EvalResponse = self._get_connection(device_id).Evaluate(
+                connection_pb2.EvalRequest(
+                    server=DeviceInfo(
+                        device_id=server_device,
+                        address=self.__get_device_address__(server_device),
+                    ),
+                    validation=val,
+                )
+            )
+            return self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
+        return False
+
+    def evaluate_batch_on(self, device_id: str, smashed_data, labels):
+        with self.handle_connection(device_id):
+            request = connection_pb2.EvalBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
+                labels=Labels(labels=tensor_to_proto(labels)),
+            )
+            response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch(
+                request
+            )
+            return self._add_byte_size_to_diagnostic_metrics(
+                response, self.device_id, request
+            )
+        return False
+
+    def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[
+        Tuple[
+            StateDict,
+            StateDict,
+            int,
+        ],
+        bool,
+    ]:
+        with self.handle_connection(device_id):
+            response: FullModelTrainResponse = self._get_connection(
+                device_id
+            ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no))
+            return (
+                proto_to_weights(response.client_weights),
+                proto_to_weights(response.server_weights),
+                response.num_samples,
+                proto_to_metrics(response.metrics),
+                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
+            )
+        return False
+
+    def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool:
+        with self.handle_connection(device_id):
+            self._get_connection(device_id).StartExperiment(
+                connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready
+            )
+            return True
+        return False
+
+    def end_experiment_on(self, device_id: str) -> bool:
+        with self.handle_connection(device_id):
+            self._get_connection(device_id).EndExperiment(
+                connection_pb2.EndExperimentRequest()
+            )
+            return True
+        return False
+
+    def get_battery_status_on(self, device_id: str) -> DeviceBatteryStatusReport:
+        with self.handle_connection(device_id):
+            response: BatteryStatusResponse = self._get_connection(
+                device_id
+            ).GetBatteryStatus(connection_pb2.BatteryStatusRequest())
+            return DeviceBatteryStatus(
+                current_capacity=response.status.current_battery_level,
+                initial_capacity=response.status.initial_battery_level,
+            )
+        return False
+
+    def train_batch_on_client_only(
+        self, device_id: str, batch_index: int, round_no: int
+    ) -> Tuple[Tensor, Tensor] | None:
+        with self.handle_connection(device_id):
+            response: SingleBatchTrainingResponse = self._get_connection(
+                device_id
+            ).TrainSingleBatchOnClient(
+                connection_pb2.SingleBatchTrainingRequest(
+                    batch_index=batch_index, round_no=round_no
+                )
+            )
+
+            # The response can only be None if the last batch was smaller than the configured batch size.
+            if response.HasField("smashed_data"):
+                return (
+                    proto_to_tensor(response.smashed_data.activations),
+                    proto_to_tensor(response.labels.labels),
+                )
+
+            return None
+        return False
+
+    def get_dataset_model_info_on(
+        self, device_id: str
+    ) -> Union[Tuple[int, int, int, int, int, int], bool]:
+        with self.handle_connection(device_id):
+            response: DatasetModelInfoResponse = self._get_connection(
+                device_id
+            ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest())
+            return (
+                response.train_samples,
+                response.validation_samples,
+                response.client_fw_flops,
+                response.server_fw_flops,
+                response.client_bw_flops,
+                response.server_bw_flops,
+            )
+        return False
+
+    def backpropagation_on_client_only(self, device_id, gradients):
+        with self.handle_connection(device_id):
+            response: SingleBatchBackwardResponse = self._get_connection(
+                device_id
+            ).BackwardPropagationSingleBatchOnClient(
+                connection_pb2.SingleBatchBackwardRequest(
+                    gradients=Gradients(gradients=tensor_to_proto(gradients))
+                )
+            )
+            return (
+                None,
+                proto_to_tensor(response.gradients.gradients),
+            )
+        return False
+
+    def set_gradient_and_finalize_training_on_client_only(
+        self, client_id: str, gradients: Any
+    ):
+        with self.handle_connection(client_id):
+            response: Empty = self._get_connection(
+                client_id
+            ).SetGradientsAndFinalizeTrainingStep(
+                connection_pb2.SetGradientsRequest(
+                    gradients=Gradients(gradients=tensor_to_proto(gradients))
+                )
+            )
+            return response
+        return False
diff --git a/edml/core/rpc_device_servicer.py b/edml/core/rpc_device_servicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbaddec41b5cd8eec2bead163b3119e548218af6
--- /dev/null
+++ b/edml/core/rpc_device_servicer.py
@@ -0,0 +1,221 @@
+from __future__ import annotations
+
+from edml.core.device import NetworkDevice
+from edml.generated import connection_pb2
+from edml.generated.connection_pb2 import (
+    StartExperimentResponse,
+    EndExperimentResponse,
+    SingleBatchBackwardRequest,
+    SetGradientsRequest,
+)
+from edml.generated.connection_pb2_grpc import DeviceServicer
+from edml.generated.datastructures_pb2 import (
+    Weights,
+    Gradients,
+    BatteryStatus,
+    Activations,
+    Labels,
+    Empty,
+)
+from edml.helpers.proto_helpers import (
+    proto_to_state_dict,
+    state_dict_to_proto,
+    metrics_to_proto,
+    proto_to_tensor,
+    tensor_to_proto,
+)
+
+
+class RPCDeviceServicer(DeviceServicer):
+    """
+    Implements the handling for incoming gRPC calls according to the protobuf definition. Added to the gRPC server upon
+    starting a device. Deserializes protobufs to python objects.
+
+    Attributes:
+        device (NetworkDevice): This device to receive the calls.
+    """
+
+    def __init__(self, device: NetworkDevice):
+        self.device = device
+
+    def TrainGlobal(self, request, context):
+        print(f"Called TrainGlobal on device {self.device.device_id}")
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
+            self.device.train_global(
+                request.epochs,
+                request.round_no,
+                request.adaptive_threshold_value,
+                proto_to_state_dict(request.optimizer_state),
+            )
+        )
+        response = connection_pb2.TrainGlobalResponse(
+            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
+            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
+            metrics=metrics_to_proto(metrics),
+            optimizer_state=state_dict_to_proto(optimizer_state),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+        return response
+
+    def SetWeights(self, request, context):
+        print(f"Called SetWeights on device {self.device.device_id}")
+        weights = proto_to_state_dict(request.weights.weights)
+        self.device.set_weights(weights, request.on_client)
+        return connection_pb2.SetWeightsResponse()
+
+    def TrainEpoch(self, request, context):
+        print(f"Called TrainEpoch on device {self.device.device_id}")
+        device_info = request.server
+        device_id = device_info.device_id
+        round_no = request.round_no
+        weights, diagnostic_metrics = self.device.train_epoch(device_id, round_no)
+        proto_weights = state_dict_to_proto(weights)
+        return connection_pb2.TrainEpochResponse(
+            weights=Weights(weights=proto_weights),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+
+    def TrainBatch(self, request, context):
+        activations = proto_to_tensor(request.smashed_data.activations)
+        labels = proto_to_tensor(request.labels.labels)
+        gradients, loss, diagnostic_metrics = self.device.train_batch(
+            activations, labels
+        )
+        proto_gradients = Gradients(gradients=tensor_to_proto(gradients))
+        return connection_pb2.TrainBatchResponse(
+            gradients=proto_gradients,
+            loss=loss,
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+
+    def EvaluateGlobal(self, request, context):
+        print(f"Called EvaluateGlobal on device {self.device.device_id}")
+        metrics, diagnostic_metrics = self.device.evaluate_global(
+            request.validation, request.federated
+        )
+        return connection_pb2.EvalGlobalResponse(
+            metrics=metrics_to_proto(metrics),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+
+    def Evaluate(self, request, context):
+        print(f"Called Evaluate on device {self.device.device_id}")
+        diagnostic_metrics = self.device.evaluate(
+            request.server.device_id, request.validation
+        )
+        return connection_pb2.EvalResponse(
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
+        )
+
+    def EvaluateBatch(self, request, context):
+        activations = proto_to_tensor(request.smashed_data.activations)
+        labels = proto_to_tensor(request.labels.labels)
+        diagnostic_metrics = self.device.evaluate_batch(activations, labels)
+        return connection_pb2.EvalBatchResponse(
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
+        )
+
+    def FullModelTraining(self, request, context):
+        print(f"Called Full Training on device {self.device.device_id}")
+        client_weights, server_weights, num_samples, metrics, diagnostic_metrics = (
+            self.device.federated_train(request.round_no)
+        )
+        return connection_pb2.FullModelTrainResponse(
+            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
+            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
+            num_samples=num_samples,
+            metrics=metrics_to_proto(metrics),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+
+    def StartExperiment(self, request, context) -> StartExperimentResponse:
+        print(f"Start Experiment on device {self.device.device_id}")
+        self.device.start_experiment()
+        return connection_pb2.StartExperimentResponse()
+
+    def EndExperiment(self, request, context) -> EndExperimentResponse:
+        print(f"End Experiment on device {self.device.device_id}")
+        print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}")
+        self.device.end_experiment()
+        return connection_pb2.EndExperimentResponse()
+
+    def GetBatteryStatus(self, request, context):
+        print(f"Get Battery Status on device {self.device.device_id}")
+        initial_capacity, remaining_capacity = self.device.get_battery_status()
+
+        return connection_pb2.BatteryStatusResponse(
+            status=BatteryStatus(
+                initial_battery_level=initial_capacity,
+                current_battery_level=remaining_capacity,
+            )
+        )
+
+    def GetDatasetModelInfo(self, request, context):
+        print(f"Get Dataset and Model Info on device {self.device.device_id}")
+        return connection_pb2.DatasetModelInfoResponse(
+            train_samples=len(self.device.client._train_data.dataset),
+            validation_samples=len(self.device.client._val_data.dataset),
+            client_fw_flops=int(self.device.client._model_flops["FW"]),
+            server_fw_flops=int(self.device.server._model_flops["FW"]),
+            client_bw_flops=int(self.device.client._model_flops["BW"]),
+            server_bw_flops=int(self.device.server._model_flops["BW"]),
+        )
+
+    def TrainGlobalParallelSplitLearning(self, request, context):
+        print(f"Starting parallel split learning")
+        clients = self.device.__get_device_ids__()
+        round_no = request.round_no
+        adaptive_threshold_value = request.adaptive_threshold_value
+        optimizer_state = proto_to_state_dict(request.optimizer_state)
+
+        cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
+            self.device.train_parallel_split_learning(
+                clients=clients,
+                round_no=round_no,
+                adaptive_threshold_value=adaptive_threshold_value,
+                optimizer_state=optimizer_state,
+            )
+        )
+        response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
+            client_weights=Weights(weights=state_dict_to_proto(cw)),
+            server_weights=Weights(weights=state_dict_to_proto(sw)),
+            metrics=metrics_to_proto(model_metrics),
+            optimizer_state=state_dict_to_proto(optimizer_state),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+        )
+        return response
+
+    def TrainSingleBatchOnClient(self, request, context):
+        batch_index = request.batch_index
+        round_no = request.round_no
+
+        smashed_data, labels = self.device.client.train_single_batch(
+            batch_index, round_no=round_no
+        )
+
+        smashed_data = Activations(activations=tensor_to_proto(smashed_data))
+        labels = Labels(labels=tensor_to_proto(labels))
+        return connection_pb2.SingleBatchTrainingResponse(
+            smashed_data=smashed_data,
+            labels=labels,
+        )
+
+    def BackwardPropagationSingleBatchOnClient(
+        self, request: SingleBatchBackwardRequest, context
+    ):
+        gradients = proto_to_tensor(request.gradients.gradients)
+
+        metrics, gradients = self.device.client.backward_single_batch(
+            gradients=gradients
+        )
+        return connection_pb2.SingleBatchBackwardResponse(
+            metrics=metrics_to_proto(metrics),
+            gradients=Gradients(gradients=tensor_to_proto(gradients)),
+        )
+
+    def SetGradientsAndFinalizeTrainingStep(
+        self, request: SetGradientsRequest, context
+    ):
+        gradients = proto_to_tensor(request.gradients.gradients)
+        self.device.client.set_gradient_and_finalize_training(gradients=gradients)
+        return Empty()
diff --git a/edml/core/start_device.py b/edml/core/start_device.py
index 49df053b5073a334d185ce21bc30c2188a545716..79222554154103605e88a25f3bd21cebec9db7ef 100644
--- a/edml/core/start_device.py
+++ b/edml/core/start_device.py
@@ -8,7 +8,8 @@ from torch import nn
 
 from edml.core.battery import Battery
 from edml.core.client import DeviceClient
-from edml.core.device import NetworkDevice, RPCDeviceServicer
+from edml.core.device import NetworkDevice
+from edml.core.rpc_device_servicer import RPCDeviceServicer
 from edml.core.server import DeviceServer
 from edml.dataset_utils.mnist.mnist import single_batch_dataloaders
 from edml.generated import connection_pb2_grpc
diff --git a/edml/tests/controllers/base_controller_test.py b/edml/tests/controllers/base_controller_test.py
index 5238ae39c2e535ac22443d8727a13b0939821e84..f2b01b4ccfa62382fe7bcbccff97e0a5d40a05df 100644
--- a/edml/tests/controllers/base_controller_test.py
+++ b/edml/tests/controllers/base_controller_test.py
@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import Mock
 
 from edml.controllers.base_controller import BaseController
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.tests.controllers.test_helper import load_sample_config
 
 
diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py
index 45c4ef900c74a02035674ad7cf96be12bd95a5af..df07b2b0f8cc3724de47216a1fba97e4293ba524 100644
--- a/edml/tests/controllers/fed_controller_test.py
+++ b/edml/tests/controllers/fed_controller_test.py
@@ -6,7 +6,7 @@ from omegaconf import DictConfig
 from torch import Tensor
 
 from edml.controllers.fed_controller import FedController, fed_average
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.generated.connection_pb2 import FullModelTrainResponse
 from edml.helpers.metrics import (
     ModelMetricResultContainer,
diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py
index 8531a2d1fe0aa146494d1ffb45317739a1c01dc1..c5533a1e38bb490f00ae714fae87fbbcd5b01aca 100644
--- a/edml/tests/controllers/split_controller_test.py
+++ b/edml/tests/controllers/split_controller_test.py
@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import Mock
 
 from edml.controllers.split_controller import SplitController
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.helpers.metrics import (
     ModelMetricResultContainer,
     DiagnosticMetricResultContainer,
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 89ad8aa1fe830bc397ff2831cd037df49ba3b4b7..9aa29a7d64befd7bf85b016bf809152e43487da7 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -3,12 +3,9 @@ from unittest.mock import Mock, call
 
 from omegaconf import ListConfig
 
-from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
-from edml.controllers.scheduler.random import RandomNextServerScheduler
 from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
-from edml.controllers.scheduler.smart import SmartNextServerScheduler
 from edml.controllers.swarm_controller import SwarmController
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.helpers.metrics import (
     ModelMetricResultContainer,
     DiagnosticMetricResultContainer,
diff --git a/edml/tests/controllers/test_controller_test.py b/edml/tests/controllers/test_controller_test.py
index 848071a068d6418f6ac44303f83ada770755bbe9..f63fc5d21afad4de9ae1403f4834eb5a0fa56f63 100644
--- a/edml/tests/controllers/test_controller_test.py
+++ b/edml/tests/controllers/test_controller_test.py
@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import patch, Mock, call
 
 from edml.controllers.test_controller import TestController
-from edml.core.device import DeviceRequestDispatcher
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
 from edml.tests.controllers.test_helper import load_sample_config
 
 
diff --git a/edml/tests/core/device_request_dispatcher_test.py b/edml/tests/core/device_request_dispatcher_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6de7ec4be7a87b578a01088835c9f4b3fd13763
--- /dev/null
+++ b/edml/tests/core/device_request_dispatcher_test.py
@@ -0,0 +1,538 @@
+import concurrent.futures
+import math
+import threading
+import unittest
+from inspect import signature
+from unittest.mock import Mock, patch
+
+import grpc
+from omegaconf import DictConfig
+from torch import Tensor
+
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
+from edml.generated import connection_pb2
+from edml.generated.connection_pb2 import (
+    SetWeightsResponse,
+    TrainEpochResponse,
+    EvalResponse,
+    EvalBatchResponse,
+)
+from edml.generated.datastructures_pb2 import (
+    DeviceInfo,
+    Activations,
+    Labels,
+    BatteryStatus,
+)
+from edml.helpers.metrics import (
+    ModelMetricResultContainer,
+    ModelMetricResult,
+    DiagnosticMetricResultContainer,
+    DiagnosticMetricResult,
+)
+from edml.helpers.proto_helpers import (
+    weights_to_proto,
+    metrics_to_proto,
+    state_dict_to_proto,
+    gradients_to_proto,
+    tensor_to_proto,
+)
+from edml.helpers.types import DeviceBatteryStatus
+from edml.tests.controllers.test_helper import get_side_effect
+
+
+class RequestDispatcherTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.dispatcher = DeviceRequestDispatcher(
+            []
+        )  # pass no devices to avoid grpc calls
+        # mock connections instead
+        self.dispatcher.connections["1"] = Mock(
+            spec=[
+                "TrainGlobal",
+                "SetWeights",
+                "TrainEpoch",
+                "TrainBatch",
+                "EvaluateGlobal",
+                "Evaluate",
+                "EvaluateBatch",
+                "FullModelTraining",
+                "StartExperiment",
+                "EndExperiment",
+                "GetBatteryStatus",
+                "GetDatasetModelInfo",
+            ]
+        )
+        self.dispatcher.devices = [
+            DictConfig({"device_id": "0", "address": "42"}),  # inactive device
+            DictConfig({"device_id": "1", "address": "43"}),
+        ]
+        self.mock_stub = self.dispatcher.connections["1"]  # for convenience
+        self.weights = {"weights": Tensor([42])}
+        self.gradients = Tensor([42])
+        self.activations = Tensor([1])
+        self.labels = Tensor([1])
+        self.metrics = ModelMetricResultContainer(
+            [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)]
+        )
+        self.diagnostic_metrics = DiagnosticMetricResultContainer(
+            [DiagnosticMetricResult("d1", "comp_time", "train", 42)]
+        )
+
+    def test_train_global_on_without_error(self):
+        self.mock_stub.TrainGlobal.return_value = connection_pb2.TrainGlobalResponse(
+            client_weights=weights_to_proto(self.weights),
+            server_weights=weights_to_proto(self.weights),
+            metrics=metrics_to_proto(self.metrics),
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
+        )
+
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
+            self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44})
+        )
+
+        self.assertEqual(client_weights, self.weights)
+        self.assertEqual(server_weights, self.weights)
+        self.assertEqual(metrics, self.metrics)
+        self.assertEqual(optimizer_state, {"optimizer_state": 42})
+
+        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
+        self.mock_stub.TrainGlobal.assert_called_once_with(
+            connection_pb2.TrainGlobalRequest(
+                epochs=42,
+                round_no=43,
+                adaptive_threshold_value=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
+            )
+        )
+
+    def test_train_global_on_with_error(self):
+        self.mock_stub.TrainGlobal.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.train_global_on(
+            "1",
+            42,
+            round_no=43,
+            adaptive_threshold_value=3,
+            optimizer_state={"optimizer_state": 44},
+        )
+
+        self.assertEqual(response, False)
+        self.mock_stub.TrainGlobal.assert_called_once_with(
+            connection_pb2.TrainGlobalRequest(
+                epochs=42,
+                round_no=43,
+                adaptive_threshold_value=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
+            )
+        )
+
+    def test_set_weights_on_without_error(self):
+        self.mock_stub.SetWeights.return_value = SetWeightsResponse()
+
+        self.dispatcher.set_weights_on("1", self.weights, True)
+
+        self.mock_stub.SetWeights.assert_called_once_with(
+            connection_pb2.SetWeightsRequest(
+                weights=weights_to_proto(self.weights), on_client=True
+            ),
+            wait_for_ready=False,
+        )
+
+    def test_set_weights_on_with_error(self):
+        self.mock_stub.SetWeights.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.set_weights_on("1", self.weights, True)
+
+        self.mock_stub.SetWeights.assert_called_once_with(
+            connection_pb2.SetWeightsRequest(
+                weights=weights_to_proto(self.weights), on_client=True
+            ),
+            wait_for_ready=False,
+        )
+        self.assertEqual(response, False)
+
+    def test_train_epoch_on_without_error(self):
+        self.mock_stub.TrainEpoch.return_value = TrainEpochResponse(
+            weights=weights_to_proto(self.weights),
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+        )
+
+        weights, diagnostic_metrics = self.dispatcher.train_epoch_on("1", "0", 42)
+
+        self.assertEqual(weights, self.weights)
+        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
+        self.mock_stub.TrainEpoch.assert_called_once_with(
+            connection_pb2.TrainEpochRequest(
+                server=DeviceInfo(device_id="0", address="42"), round_no=42
+            )
+        )
+
+    def test_train_epoch_on_with_error(self):
+        self.mock_stub.TrainEpoch.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.train_epoch_on("1", "0", 42)
+
+        self.assertEqual(response, False)
+        self.mock_stub.TrainEpoch.assert_called_once_with(
+            connection_pb2.TrainEpochRequest(
+                server=DeviceInfo(device_id="0", address="42"), round_no=42
+            )
+        )
+
+    def test_train_batch_on_without_error(self):
+        self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse(
+            gradients=gradients_to_proto(self.gradients),
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+            loss=42.0,
+        )
+
+        gradients, loss, diagnostic_metrics = self.dispatcher.train_batch_on(
+            "1", self.activations, self.labels
+        )
+
+        self.assertEqual(gradients, self.gradients)
+        self.assertEqual(loss, 42.0)
+        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
+        self.mock_stub.TrainBatch.assert_called_once_with(
+            connection_pb2.TrainBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
+                labels=Labels(labels=tensor_to_proto(self.labels)),
+            )
+        )
+
+    def test_train_batch_on_with_error(self):
+        self.mock_stub.TrainBatch.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.train_batch_on("1", self.activations, self.labels)
+
+        self.assertEqual(response, False)
+        self.mock_stub.TrainBatch.assert_called_once_with(
+            connection_pb2.TrainBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
+                labels=Labels(labels=tensor_to_proto(self.labels)),
+            )
+        )
+
+    def test_evaluate_global_on_without_error(self):
+        self.mock_stub.EvaluateGlobal.return_value = connection_pb2.EvalGlobalResponse(
+            metrics=metrics_to_proto(self.metrics),
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+        )
+
+        metrics, diagnostic_metrics = self.dispatcher.evaluate_global_on(
+            "1", True, True
+        )
+
+        self.assertEqual(metrics, self.metrics)
+        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
+        self.mock_stub.EvaluateGlobal.assert_called_once_with(
+            connection_pb2.EvalGlobalRequest(validation=True, federated=True)
+        )
+
+    def test_evaluate_global_on_with_error(self):
+        self.mock_stub.EvaluateGlobal.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.evaluate_global_on("1", True, True)
+
+        self.assertEqual(response, False)
+        self.mock_stub.EvaluateGlobal.assert_called_once_with(
+            connection_pb2.EvalGlobalRequest(validation=True, federated=True)
+        )
+
+    def test_evaluate_on_without_error(self):
+        self.mock_stub.Evaluate.return_value = EvalResponse(
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics)
+        )
+
+        diagnostic_metrics = self.dispatcher.evaluate_on("1", "0", True)
+
+        self.assertEqual(diagnostic_metrics, self.diagnostic_metrics)
+        self.mock_stub.Evaluate.assert_called_once_with(
+            connection_pb2.EvalRequest(
+                server=DeviceInfo(device_id="0", address="42"), validation=True
+            )
+        )
+
+    def test_evaluate_on_with_error(self):
+        self.mock_stub.Evaluate.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.evaluate_on("1", "0", True)
+
+        self.assertEqual(response, False)
+        self.mock_stub.Evaluate.assert_called_once_with(
+            connection_pb2.EvalRequest(
+                server=DeviceInfo(device_id="0", address="42"), validation=True
+            )
+        )
+
+    def test_evaluate_batch_on_without_error(self):
+        self.mock_stub.EvaluateBatch.return_value = EvalBatchResponse(
+            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics)
+        )
+
+        diagnostic_metrics = self.dispatcher.evaluate_batch_on(
+            "1", self.activations, self.labels
+        )
+
+        self._assert_field_size_added_to_diagnostic_metrics(
+            diagnostic_metrics
+        )  # metric field present in response, but not used in practice. Thus, a field size is added
+        self.mock_stub.EvaluateBatch.assert_called_once_with(
+            connection_pb2.EvalBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
+                labels=Labels(labels=tensor_to_proto(self.labels)),
+            )
+        )
+
+    def test_evaluate_batch_on_with_error(self):
+        self.mock_stub.EvaluateBatch.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.evaluate_batch_on("1", self.activations, self.labels)
+
+        self.assertEqual(response, False)
+        self.mock_stub.EvaluateBatch.assert_called_once_with(
+            connection_pb2.EvalBatchRequest(
+                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
+                labels=Labels(labels=tensor_to_proto(self.labels)),
+            )
+        )
+
+    def test_full_model_training_without_error(self):
+        self.mock_stub.FullModelTraining.return_value = (
+            connection_pb2.FullModelTrainResponse(
+                client_weights=weights_to_proto(self.weights),
+                server_weights=weights_to_proto(self.weights),
+                num_samples=42,
+                metrics=metrics_to_proto(self.metrics),
+                diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+            )
+        )
+
+        client_weights, server_weights, num_samples, metrics, diagnostic_metrics = (
+            self.dispatcher.federated_train_on("1", 42)
+        )
+
+        self.assertEqual(client_weights, self.weights)
+        self.assertEqual(server_weights, self.weights)
+        self.assertEqual(num_samples, 42)
+        self.assertEqual(metrics, self.metrics)
+        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
+        self.mock_stub.FullModelTraining.assert_called_once_with(
+            connection_pb2.FullModelTrainRequest(round_no=42)
+        )
+
+    def test_full_model_training_with_error(self):
+        self.mock_stub.FullModelTraining.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.federated_train_on("1", 42)
+
+        self.assertEqual(response, False)
+        self.mock_stub.FullModelTraining.assert_called_once_with(
+            connection_pb2.FullModelTrainRequest(round_no=42)
+        )
+
+    def test_start_experiment_on_without_error(self):
+        self.mock_stub.StartExperiment.return_value = (
+            connection_pb2.StartExperimentResponse()
+        )
+
+        response = self.dispatcher.start_experiment_on("1", True)
+
+        self.assertEqual(response, True)
+        self.mock_stub.StartExperiment.assert_called_once_with(
+            connection_pb2.StartExperimentRequest(), wait_for_ready=True
+        )
+
+    def test_start_experiment_on_with_error(self):
+        self.mock_stub.StartExperiment.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.start_experiment_on("1", True)
+
+        self.assertEqual(response, False)
+        self.mock_stub.StartExperiment.assert_called_once_with(
+            connection_pb2.StartExperimentRequest(), wait_for_ready=True
+        )
+
+    def test_end_experiment_on_without_error(self):
+        self.mock_stub.EndExperiment.return_value = (
+            connection_pb2.EndExperimentResponse()
+        )
+
+        response = self.dispatcher.end_experiment_on("1")
+
+        self.assertEqual(response, True)
+        self.mock_stub.EndExperiment.assert_called_once_with(
+            connection_pb2.EndExperimentRequest()
+        )
+
+    def test_end_experiment_on_with_error(self):
+        self.mock_stub.EndExperiment.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.end_experiment_on("1")
+
+        self.assertEqual(response, False)
+        self.mock_stub.EndExperiment.assert_called_once_with(
+            connection_pb2.EndExperimentRequest()
+        )
+
+    def test_get_battery_status_without_error(self):
+        self.mock_stub.GetBatteryStatus.return_value = (
+            connection_pb2.BatteryStatusResponse(
+                status=BatteryStatus(initial_battery_level=42, current_battery_level=21)
+            )
+        )
+
+        response = self.dispatcher.get_battery_status_on("1")
+
+        self.assertEqual(
+            response, DeviceBatteryStatus(initial_capacity=42, current_capacity=21)
+        )
+        self.mock_stub.GetBatteryStatus.assert_called_once_with(
+            connection_pb2.BatteryStatusRequest()
+        )
+
+    def test_get_battery_status_with_error(self):
+        self.mock_stub.GetBatteryStatus.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.get_battery_status_on("1")
+
+        self.assertEqual(response, False)
+        self.mock_stub.GetBatteryStatus.assert_called_once_with(
+            connection_pb2.BatteryStatusRequest()
+        )
+
+    def test_get_dataset_model_info_without_error(self):
+        self.mock_stub.GetDatasetModelInfo.return_value = (
+            connection_pb2.DatasetModelInfoResponse(
+                train_samples=42,
+                validation_samples=21,
+                client_fw_flops=1,
+                server_fw_flops=2,
+                client_bw_flops=3,
+                server_bw_flops=4,
+            )
+        )
+
+        response = self.dispatcher.get_dataset_model_info_on("1")
+
+        self.assertEqual(response, (42, 21, 1, 2, 3, 4))
+        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
+            connection_pb2.DatasetModelInfoRequest()
+        )
+
+    def test_get_dataset_model_info_with_error(self):
+        self.mock_stub.GetDatasetModelInfo.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.get_dataset_model_info_on("1")
+
+        self.assertEqual(response, False)
+        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
+            connection_pb2.DatasetModelInfoRequest()
+        )
+
+    def test_handle_calls_to_inactive_device(self):
+        """Test that each dispatcher method that makes RPC calls handles calls to inactive devices.
+        One case where the device is known but inactive, and one where the device is unknown.
+        """
+        methods_names = [
+            "train_global_on",
+            "set_weights_on",
+            "train_epoch_on",
+            "train_batch_on",
+            "evaluate_global_on",
+            "evaluate_on",
+            "evaluate_batch_on",
+            "federated_train_on",
+            "start_experiment_on",
+            "end_experiment_on",
+            "get_battery_status_on",
+            "get_dataset_model_info_on",
+        ]
+        for device in [("0", True), ("2", False)]:  # 0 is inactive, 2 is unknown
+            for method_name in methods_names:
+                with self.subTest(method_name=f"{method_name}_device{device[0]}"):
+                    with patch("builtins.print") as print_patch:
+                        method = getattr(self.dispatcher, method_name)
+                        params = list(signature(method).parameters)
+
+                        # make method call with device id and all other params as None
+                        response = method(
+                            device[0],
+                            *[None for _ in params if _ not in ["self", "device_id"]],
+                        )
+
+                        self.assertEqual(response, False)
+                        self.mock_stub.assert_not_called()
+                        if device[1]:
+                            print_patch.assert_called_once_with(
+                                f"Device {device[0]} not active"
+                            )
+                        else:
+                            print_patch.assert_called_once_with(
+                                f"Unknown Device ID {device[0]}"
+                            )
+
+    def _assert_field_size_added_to_diagnostic_metrics(self, diagnostic_metrics):
+        """Do not use when the response only includes diagnostic metrics, as these are ignored for the field size."""
+        self.assertEqual(
+            diagnostic_metrics.get_as_list()[0],
+            self.diagnostic_metrics.get_as_list()[0],
+        )  # check that previous diagnostic metrics still exist
+        self.assertEqual(
+            diagnostic_metrics.get_as_list()[1].name, "size"
+        )  # check that at least one new diagnostic metric for size was added
+
+
+class RequestDispatcherThreadingTest(unittest.TestCase):
+
+    def test_handle_rpc_error_thread_safety(self):
+
+        # run the procedure with 2 to 20 threads
+        for n in range(2, 21):
+            with self.subTest(method_name=f"Test with {n} threads"):
+                print(f"Test with {n} threads")
+                self.request_dispatcher = DeviceRequestDispatcher([])
+
+                # mock n connections with half of them returning errors and the other half valid responses
+                self.request_dispatcher.connections = {
+                    f"d{i}": Mock(spec=["GetBatteryStatus"]) for i in range(n)
+                }
+                response = connection_pb2.BatteryStatusResponse(
+                    status=BatteryStatus(
+                        initial_battery_level=42, current_battery_level=21
+                    )
+                )
+                for i, mock in enumerate(self.request_dispatcher.connections.values()):
+                    mock.GetBatteryStatus.side_effect = get_side_effect(i, response)
+
+                # make n concurrent calls to the dispatcher
+                responses = []
+                response_lock = threading.Lock()
+                with concurrent.futures.ThreadPoolExecutor(max_workers=n) as executor:
+                    futures = [
+                        executor.submit(
+                            self.request_dispatcher.get_battery_status_on, device_id
+                        )
+                        for device_id in self.request_dispatcher.active_devices()
+                    ]
+                    for future in concurrent.futures.as_completed(futures):
+                        with response_lock:
+                            responses.append(future.result())
+
+                # check that only error connections were removed
+                self.assertEqual(
+                    list(self.request_dispatcher.connections.keys()),
+                    [f"d{i}" for i in range(n) if i % 2 == 0],
+                )
+
+                response_battery_level = DeviceBatteryStatus(
+                    initial_capacity=42, current_capacity=21
+                )
+
+                # check that there are n responses in total: floor(n/2) of them False (failed calls) and ceil(n/2) of them valid
+                self.assertEqual(responses.count(False), math.floor(n / 2))
+                self.assertEqual(
+                    responses.count(response_battery_level), math.ceil(n / 2)
+                )
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 85b0e1f7f7204efa6a66b3c95f89a5c392098855..809c69587676f3024fff22af46e17e1347469e4b 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -1,53 +1,24 @@
-import concurrent.futures
-import math
-import threading
 import unittest
-from inspect import signature
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
 
-import grpc
-from grpc import StatusCode
-from grpc_testing import server_from_dictionary, strict_real_time
 from omegaconf import DictConfig
 from torch import Tensor
 
-from edml.core.battery import Battery, BatteryEmptyException
+from edml.core.battery import Battery
 from edml.core.client import DeviceClient
-from edml.core.device import RPCDeviceServicer, NetworkDevice, DeviceRequestDispatcher
+from edml.core.device import NetworkDevice
 from edml.core.server import DeviceServer
-from edml.generated import connection_pb2, datastructures_pb2
+from edml.generated import connection_pb2
 from edml.generated.connection_pb2 import (
-    EvalResponse,
-    EvalBatchResponse,
-    SetWeightsResponse,
     TrainEpochResponse,
 )
 from edml.generated.datastructures_pb2 import (
-    Activations,
-    Labels,
     Weights,
     DeviceInfo,
-    BatteryStatus,
-)
-from edml.helpers.metrics import (
-    ModelMetricResult,
-    ModelMetricResultContainer,
-    DiagnosticMetricResultContainer,
-    DiagnosticMetricResult,
 )
 from edml.helpers.proto_helpers import (
-    tensor_to_proto,
-    proto_to_tensor,
     state_dict_to_proto,
-    proto_to_state_dict,
-    proto_to_weights,
-    weights_to_proto,
-    gradients_to_proto,
-    metrics_to_proto,
-    proto_to_metrics,
 )
-from edml.helpers.types import DeviceBatteryStatus
-from edml.tests.controllers.test_helper import get_side_effect
 
 
 class NetworkDeviceTest(unittest.TestCase):
@@ -93,365 +64,6 @@ class NetworkDeviceTest(unittest.TestCase):
         )
 
 
-class RPCDeviceServicerTest(unittest.TestCase):
-
-    def setUp(self) -> None:
-        # instantiating NetworkDevice to mock Device(ABC)'s properties
-        # store mock_device reference for later function call assertions
-        self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery)))
-        self.mock_device.device_id = 42
-
-        my_servicer = RPCDeviceServicer(device=self.mock_device)
-        servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer}
-        self.test_server = server_from_dictionary(servicers, strict_real_time())
-
-        self.metrics = ModelMetricResultContainer(
-            [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)]
-        )
-        self.diagnostic_metrics = DiagnosticMetricResultContainer(
-            [DiagnosticMetricResult("d1", "comp_time", "train", 42)]
-        )
-
-    def make_call(self, method_name, request):
-        method = self.test_server.invoke_unary_unary(
-            method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[
-                "Device"
-            ].methods_by_name[method_name],
-            invocation_metadata={},
-            request=request,
-            timeout=None,
-        )
-        return method.termination()
-
-    def test_train_global(self):
-        self.mock_device.train_global.return_value = (
-            {"weights": Tensor([42])},
-            {"weights": Tensor([43])},
-            self.metrics,
-            {"optimizer_state": 44},
-            self.diagnostic_metrics,
-        )
-        request = connection_pb2.TrainGlobalRequest(
-            epochs=42,
-            round_no=1,
-            adaptive_threshold_value=3,
-            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
-        )
-
-        response, metadata, code, details = self.make_call("TrainGlobal", request)
-
-        self.assertEqual(
-            (
-                proto_to_weights(response.client_weights),
-                proto_to_weights(response.server_weights),
-            ),
-            ({"weights": Tensor([42])}, {"weights": Tensor([43])}),
-        )
-        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
-        self.assertEqual(
-            proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
-        )
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_global.assert_called_once_with(
-            42, 1, 3, {"optimizer_state": 42}
-        )
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_set_weights(self):
-        request = connection_pb2.SetWeightsRequest(
-            weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})),
-            on_client=True,
-        )
-
-        response, metadata, code, details = self.make_call("SetWeights", request)
-
-        self.assertEqual(type(response), type(SetWeightsResponse()))
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.set_weights.assert_called_once_with(
-            {"weights": Tensor([42])}, True
-        )
-
-    def test_train_epoch(self):
-        self.mock_device.train_epoch.return_value = {
-            "weights": Tensor([42])
-        }, self.diagnostic_metrics
-        request = connection_pb2.TrainEpochRequest(
-            server=datastructures_pb2.DeviceInfo(device_id="42", address="")
-        )
-
-        response, metadata, code, details = self.make_call("TrainEpoch", request)
-
-        self.assertEqual(
-            proto_to_state_dict(response.weights.weights), {"weights": Tensor([42])}
-        )
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_epoch.assert_called_once_with("42", 0)
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_train_batch(self):
-        self.mock_device.train_batch.return_value = (
-            Tensor([42]),
-            42.0,
-            self.diagnostic_metrics,
-        )
-        request = connection_pb2.TrainBatchRequest(
-            smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))),
-            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
-        )
-
-        response, metadata, code, details = self.make_call("TrainBatch", request)
-
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1]))
-        self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42]))
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_evaluate_global(self):
-        self.mock_device.evaluate_global.return_value = (
-            self.metrics,
-            self.diagnostic_metrics,
-        )
-        request = connection_pb2.EvalGlobalRequest(validation=False, federated=False)
-
-        response, metadata, code, details = self.make_call("EvaluateGlobal", request)
-
-        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.evaluate_global.assert_called_once_with(False, False)
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_evaluate(self):
-        self.mock_device.evaluate.return_value = self.diagnostic_metrics
-        request = connection_pb2.EvalRequest(
-            server=DeviceInfo(device_id="42", address=""), validation=True
-        )
-
-        response, metadata, code, details = self.make_call("Evaluate", request)
-
-        self.assertEqual(type(response), type(EvalResponse()))
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.evaluate.assert_called_once_with("42", True)
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_evaluate_batch(self):
-        self.mock_device.evaluate_batch.return_value = self.diagnostic_metrics
-        request = connection_pb2.EvalBatchRequest(
-            smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))),
-            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
-        )
-
-        response, metadata, code, details = self.make_call("EvaluateBatch", request)
-
-        self.assertEqual(type(response), type(EvalBatchResponse()))
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.evaluate_batch.assert_called_once_with(
-            Tensor([1.0]), Tensor([1])
-        )
-        self.assertEqual(
-            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
-        )
-
-    def test_full_model_training(self):
-        self.mock_device.federated_train.return_value = (
-            {"weights": Tensor([42])},
-            {"weights": Tensor([43])},
-            44,
-            self.metrics,
-            self.diagnostic_metrics,
-        )
-        request = connection_pb2.FullModelTrainRequest()
-
-        response, metadata, code, details = self.make_call("FullModelTraining", request)
-
-        self.assertEqual(
-            proto_to_state_dict(response.client_weights.weights),
-            {"weights": Tensor([42])},
-        )
-        self.assertEqual(
-            proto_to_state_dict(response.server_weights.weights),
-            {"weights": Tensor([43])},
-        )
-        self.assertEqual(response.num_samples, 44)
-        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.federated_train.assert_called_once()
-
-    def test_start_experiment(self):
-        self.mock_device.start_experiment.return_value = None
-        request = connection_pb2.StartExperimentRequest()
-
-        response, metadata, code, details = self.make_call("StartExperiment", request)
-
-        self.assertEqual(type(response), type(connection_pb2.StartExperimentResponse()))
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.start_experiment.assert_called_once()
-
-    def test_end_experiment(self):
-        self.mock_device.end_experiment.return_value = None
-        request = connection_pb2.EndExperimentRequest()
-
-        response, metadata, code, details = self.make_call("EndExperiment", request)
-
-        self.assertEqual(type(response), type(connection_pb2.EndExperimentResponse()))
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.end_experiment.assert_called_once()
-
-    def test_battery_status(self):
-        self.mock_device.get_battery_status.return_value = (42, 21)
-        request = connection_pb2.BatteryStatusRequest()
-
-        response, metadata, code, details = self.make_call("GetBatteryStatus", request)
-
-        self.assertEqual(
-            response.status,
-            BatteryStatus(initial_battery_level=42, current_battery_level=21),
-        )
-        self.assertEqual(code, StatusCode.OK)
-        self.mock_device.get_battery_status.assert_called_once()
-
-    def test_dataset_model_info(self):
-        self.mock_device.client._train_data.dataset = [1]
-        self.mock_device.client._val_data.dataset = [2]
-        self.mock_device.client._model_flops = {"FW": 3, "BW": 6}
-        self.mock_device.server._model_flops = {"FW": 4, "BW": 8}
-        request = connection_pb2.DatasetModelInfoRequest()
-
-        response, metadata, code, details = self.make_call(
-            "GetDatasetModelInfo", request
-        )
-
-        self.assertEqual(code, StatusCode.OK)
-        self.assertEqual(response.train_samples, 1)
-        self.assertEqual(response.validation_samples, 1)
-        self.assertEqual(response.client_fw_flops, 3)
-        self.assertEqual(response.server_fw_flops, 4)
-        self.assertEqual(response.client_bw_flops, 6)
-        self.assertEqual(response.server_bw_flops, 8)
-
-
-class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
-
-    def setUp(self) -> None:
-        # instantiating NetworkDevice to mock Device(ABC)'s properties
-        # store mock_device reference for later function call assertions
-        self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery)))
-        my_servicer = RPCDeviceServicer(device=self.mock_device)
-        servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer}
-        self.test_server = server_from_dictionary(servicers, strict_real_time())
-
-    def make_call(self, method_name, request):
-        method = self.test_server.invoke_unary_unary(
-            method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[
-                "Device"
-            ].methods_by_name[method_name],
-            invocation_metadata={},
-            request=request,
-            timeout=None,
-        )
-        return method.termination()
-
-    def test_stop_at_train_global(self):
-        self.mock_device.train_global.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.TrainGlobalRequest(
-            optimizer_state=state_dict_to_proto(None)
-        )
-        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal")
-
-    def test_stop_at_set_weights(self):
-        self.mock_device.set_weights.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.SetWeightsRequest(weights=weights_to_proto({}))
-        self._test_device_out_of_battery_lets_rpc_fail(request, "SetWeights")
-
-    def test_stop_at_train_epoch(self):
-        self.mock_device.train_epoch.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.TrainEpochRequest(server=None)
-        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainEpoch")
-
-    def test_stop_at_train_batch(self):
-        self.mock_device.train_batch.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.TrainBatchRequest(
-            smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))),
-            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
-        )
-        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainBatch")
-
-    def test_stop_at_evaluate_global(self):
-        self.mock_device.evaluate_global.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.EvalGlobalRequest()
-        self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateGlobal")
-
-    def test_stop_at_evaluate(self):
-        self.mock_device.evaluate.side_effect = BatteryEmptyException("Battery empty")
-        request = connection_pb2.EvalRequest(server=None)
-        self._test_device_out_of_battery_lets_rpc_fail(request, "Evaluate")
-
-    def test_stop_at_evaluate_batch(self):
-        self.mock_device.evaluate_batch.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.EvalBatchRequest(
-            smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))),
-            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
-        )
-        self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateBatch")
-
-    def test_stop_at_full_model_training(self):
-        self.mock_device.federated_train.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.FullModelTrainRequest()
-        self._test_device_out_of_battery_lets_rpc_fail(request, "FullModelTraining")
-
-    def test_stop_at_start_experiment(self):
-        self.mock_device.start_experiment.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.StartExperimentRequest()
-        self._test_device_out_of_battery_lets_rpc_fail(request, "StartExperiment")
-
-    def test_stop_at_end_experiment(self):
-        self.mock_device.end_experiment.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.EndExperimentRequest()
-        self._test_device_out_of_battery_lets_rpc_fail(request, "EndExperiment")
-
-    def test_stop_at_get_battery_status(self):
-        self.mock_device.get_battery_status.side_effect = BatteryEmptyException(
-            "Battery empty"
-        )
-        request = connection_pb2.BatteryStatusRequest()
-        self._test_device_out_of_battery_lets_rpc_fail(request, "GetBatteryStatus")
-
-    def _test_device_out_of_battery_lets_rpc_fail(self, request, servicer_method_name):
-        response, metadata, code, details = self.make_call(
-            servicer_method_name, request
-        )
-        self.assertIsNone(response)
-        self.assertEqual(code, StatusCode.UNKNOWN)
-        self.assertEqual(details, "Exception calling application: Battery empty")
-
-
 class BatteryUpdateTest(unittest.TestCase):
 
     def setUp(self):
@@ -462,501 +74,3 @@ class BatteryUpdateTest(unittest.TestCase):
         device.set_client(Mock(DeviceClient))
         device.train_epoch(None)
         self.battery.update_time.assert_called()
-
-
-class RequestDispatcherTest(unittest.TestCase):
-
-    def setUp(self) -> None:
-        self.dispatcher = DeviceRequestDispatcher(
-            []
-        )  # pass no devices to avoid grpc calls
-        # mock connections instead
-        self.dispatcher.connections["1"] = Mock(
-            spec=[
-                "TrainGlobal",
-                "SetWeights",
-                "TrainEpoch",
-                "TrainBatch",
-                "EvaluateGlobal",
-                "Evaluate",
-                "EvaluateBatch",
-                "FullModelTraining",
-                "StartExperiment",
-                "EndExperiment",
-                "GetBatteryStatus",
-                "GetDatasetModelInfo",
-            ]
-        )
-        self.dispatcher.devices = [
-            DictConfig({"device_id": "0", "address": "42"}),  # inactive device
-            DictConfig({"device_id": "1", "address": "43"}),
-        ]
-        self.mock_stub = self.dispatcher.connections["1"]  # for convenience
-        self.weights = {"weights": Tensor([42])}
-        self.gradients = Tensor([42])
-        self.activations = Tensor([1])
-        self.labels = Tensor([1])
-        self.metrics = ModelMetricResultContainer(
-            [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)]
-        )
-        self.diagnostic_metrics = DiagnosticMetricResultContainer(
-            [DiagnosticMetricResult("d1", "comp_time", "train", 42)]
-        )
-
-    def test_train_global_on_without_error(self):
-        self.mock_stub.TrainGlobal.return_value = connection_pb2.TrainGlobalResponse(
-            client_weights=weights_to_proto(self.weights),
-            server_weights=weights_to_proto(self.weights),
-            metrics=metrics_to_proto(self.metrics),
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
-            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
-        )
-
-        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44})
-        )
-
-        self.assertEqual(client_weights, self.weights)
-        self.assertEqual(server_weights, self.weights)
-        self.assertEqual(metrics, self.metrics)
-        self.assertEqual(optimizer_state, {"optimizer_state": 42})
-
-        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
-        self.mock_stub.TrainGlobal.assert_called_once_with(
-            connection_pb2.TrainGlobalRequest(
-                epochs=42,
-                round_no=43,
-                adaptive_threshold_value=3,
-                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
-            )
-        )
-
-    def test_train_global_on_with_error(self):
-        self.mock_stub.TrainGlobal.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.train_global_on(
-            "1",
-            42,
-            round_no=43,
-            adaptive_threshold_value=3,
-            optimizer_state={"optimizer_state": 44},
-        )
-
-        self.assertEqual(response, False)
-        self.mock_stub.TrainGlobal.assert_called_once_with(
-            connection_pb2.TrainGlobalRequest(
-                epochs=42,
-                round_no=43,
-                adaptive_threshold_value=3,
-                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
-            )
-        )
-
-    def test_set_weights_on_without_error(self):
-        self.mock_stub.SetWeights.return_value = SetWeightsResponse()
-
-        self.dispatcher.set_weights_on("1", self.weights, True)
-
-        self.mock_stub.SetWeights.assert_called_once_with(
-            connection_pb2.SetWeightsRequest(
-                weights=weights_to_proto(self.weights), on_client=True
-            ),
-            wait_for_ready=False,
-        )
-
-    def test_set_weights_on_with_error(self):
-        self.mock_stub.SetWeights.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.set_weights_on("1", self.weights, True)
-
-        self.mock_stub.SetWeights.assert_called_once_with(
-            connection_pb2.SetWeightsRequest(
-                weights=weights_to_proto(self.weights), on_client=True
-            ),
-            wait_for_ready=False,
-        )
-        self.assertEqual(response, False)
-
-    def test_train_epoch_on_without_error(self):
-        self.mock_stub.TrainEpoch.return_value = TrainEpochResponse(
-            weights=weights_to_proto(self.weights),
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
-        )
-
-        weights, diagnostic_metrics = self.dispatcher.train_epoch_on("1", "0", 42)
-
-        self.assertEqual(weights, self.weights)
-        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
-        self.mock_stub.TrainEpoch.assert_called_once_with(
-            connection_pb2.TrainEpochRequest(
-                server=DeviceInfo(device_id="0", address="42"), round_no=42
-            )
-        )
-
-    def test_train_epoch_on_with_error(self):
-        self.mock_stub.TrainEpoch.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.train_epoch_on("1", "0", 42)
-
-        self.assertEqual(response, False)
-        self.mock_stub.TrainEpoch.assert_called_once_with(
-            connection_pb2.TrainEpochRequest(
-                server=DeviceInfo(device_id="0", address="42"), round_no=42
-            )
-        )
-
-    def test_train_batch_on_without_error(self):
-        self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse(
-            gradients=gradients_to_proto(self.gradients),
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
-            loss=42.0,
-        )
-
-        gradients, loss, diagnostic_metrics = self.dispatcher.train_batch_on(
-            "1", self.activations, self.labels
-        )
-
-        self.assertEqual(gradients, self.gradients)
-        self.assertEqual(loss, 42.0)
-        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
-        self.mock_stub.TrainBatch.assert_called_once_with(
-            connection_pb2.TrainBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
-                labels=Labels(labels=tensor_to_proto(self.labels)),
-            )
-        )
-
-    def test_train_batch_on_with_error(self):
-        self.mock_stub.TrainBatch.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.train_batch_on("1", self.activations, self.labels)
-
-        self.assertEqual(response, False)
-        self.mock_stub.TrainBatch.assert_called_once_with(
-            connection_pb2.TrainBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
-                labels=Labels(labels=tensor_to_proto(self.labels)),
-            )
-        )
-
-    def test_evaluate_global_on_without_error(self):
-        self.mock_stub.EvaluateGlobal.return_value = connection_pb2.EvalGlobalResponse(
-            metrics=metrics_to_proto(self.metrics),
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
-        )
-
-        metrics, diagnostic_metrics = self.dispatcher.evaluate_global_on(
-            "1", True, True
-        )
-
-        self.assertEqual(metrics, self.metrics)
-        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
-        self.mock_stub.EvaluateGlobal.assert_called_once_with(
-            connection_pb2.EvalGlobalRequest(validation=True, federated=True)
-        )
-
-    def test_evaluate_global_on_with_error(self):
-        self.mock_stub.EvaluateGlobal.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.evaluate_global_on("1", True, True)
-
-        self.assertEqual(response, False)
-        self.mock_stub.EvaluateGlobal.assert_called_once_with(
-            connection_pb2.EvalGlobalRequest(validation=True, federated=True)
-        )
-
-    def test_evaluate_on_without_error(self):
-        self.mock_stub.Evaluate.return_value = EvalResponse(
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics)
-        )
-
-        diagnostic_metrics = self.dispatcher.evaluate_on("1", "0", True)
-
-        self.assertEqual(diagnostic_metrics, self.diagnostic_metrics)
-        self.mock_stub.Evaluate.assert_called_once_with(
-            connection_pb2.EvalRequest(
-                server=DeviceInfo(device_id="0", address="42"), validation=True
-            )
-        )
-
-    def test_evaluate_on_with_error(self):
-        self.mock_stub.Evaluate.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.evaluate_on("1", "0", True)
-
-        self.assertEqual(response, False)
-        self.mock_stub.Evaluate.assert_called_once_with(
-            connection_pb2.EvalRequest(
-                server=DeviceInfo(device_id="0", address="42"), validation=True
-            )
-        )
-
-    def test_evaluate_batch_on_without_error(self):
-        self.mock_stub.EvaluateBatch.return_value = EvalBatchResponse(
-            diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics)
-        )
-
-        diagnostic_metrics = self.dispatcher.evaluate_batch_on(
-            "1", self.activations, self.labels
-        )
-
-        self._assert_field_size_added_to_diagnostic_metrics(
-            diagnostic_metrics
-        )  # metric field present in response, but not used in practice. Thus, a field size is added
-        self.mock_stub.EvaluateBatch.assert_called_once_with(
-            connection_pb2.EvalBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
-                labels=Labels(labels=tensor_to_proto(self.labels)),
-            )
-        )
-
-    def test_evaluate_batch_on_with_error(self):
-        self.mock_stub.EvaluateBatch.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.evaluate_batch_on("1", self.activations, self.labels)
-
-        self.assertEqual(response, False)
-        self.mock_stub.EvaluateBatch.assert_called_once_with(
-            connection_pb2.EvalBatchRequest(
-                smashed_data=Activations(activations=tensor_to_proto(self.activations)),
-                labels=Labels(labels=tensor_to_proto(self.labels)),
-            )
-        )
-
-    def test_full_model_training_without_error(self):
-        self.mock_stub.FullModelTraining.return_value = (
-            connection_pb2.FullModelTrainResponse(
-                client_weights=weights_to_proto(self.weights),
-                server_weights=weights_to_proto(self.weights),
-                num_samples=42,
-                metrics=metrics_to_proto(self.metrics),
-                diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
-            )
-        )
-
-        client_weights, server_weights, num_samples, metrics, diagnostic_metrics = (
-            self.dispatcher.federated_train_on("1", 42)
-        )
-
-        self.assertEqual(client_weights, self.weights)
-        self.assertEqual(server_weights, self.weights)
-        self.assertEqual(num_samples, 42)
-        self.assertEqual(metrics, self.metrics)
-        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
-        self.mock_stub.FullModelTraining.assert_called_once_with(
-            connection_pb2.FullModelTrainRequest(round_no=42)
-        )
-
-    def test_full_model_training_with_error(self):
-        self.mock_stub.FullModelTraining.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.federated_train_on("1", 42)
-
-        self.assertEqual(response, False)
-        self.mock_stub.FullModelTraining.assert_called_once_with(
-            connection_pb2.FullModelTrainRequest(round_no=42)
-        )
-
-    def test_start_experiment_on_without_error(self):
-        self.mock_stub.StartExperiment.return_value = (
-            connection_pb2.StartExperimentResponse()
-        )
-
-        response = self.dispatcher.start_experiment_on("1", True)
-
-        self.assertEqual(response, True)
-        self.mock_stub.StartExperiment.assert_called_once_with(
-            connection_pb2.StartExperimentRequest(), wait_for_ready=True
-        )
-
-    def test_start_experiment_on_with_error(self):
-        self.mock_stub.StartExperiment.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.start_experiment_on("1", True)
-
-        self.assertEqual(response, False)
-        self.mock_stub.StartExperiment.assert_called_once_with(
-            connection_pb2.StartExperimentRequest(), wait_for_ready=True
-        )
-
-    def test_end_experiment_on_without_error(self):
-        self.mock_stub.EndExperiment.return_value = (
-            connection_pb2.EndExperimentResponse()
-        )
-
-        response = self.dispatcher.end_experiment_on("1")
-
-        self.assertEqual(response, True)
-        self.mock_stub.EndExperiment.assert_called_once_with(
-            connection_pb2.EndExperimentRequest()
-        )
-
-    def test_end_experiment_on_with_error(self):
-        self.mock_stub.EndExperiment.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.end_experiment_on("1")
-
-        self.assertEqual(response, False)
-        self.mock_stub.EndExperiment.assert_called_once_with(
-            connection_pb2.EndExperimentRequest()
-        )
-
-    def test_get_battery_status_without_error(self):
-        self.mock_stub.GetBatteryStatus.return_value = (
-            connection_pb2.BatteryStatusResponse(
-                status=BatteryStatus(initial_battery_level=42, current_battery_level=21)
-            )
-        )
-
-        response = self.dispatcher.get_battery_status_on("1")
-
-        self.assertEqual(
-            response, DeviceBatteryStatus(initial_capacity=42, current_capacity=21)
-        )
-        self.mock_stub.GetBatteryStatus.assert_called_once_with(
-            connection_pb2.BatteryStatusRequest()
-        )
-
-    def test_get_battery_status_with_error(self):
-        self.mock_stub.GetBatteryStatus.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.get_battery_status_on("1")
-
-        self.assertEqual(response, False)
-        self.mock_stub.GetBatteryStatus.assert_called_once_with(
-            connection_pb2.BatteryStatusRequest()
-        )
-
-    def test_get_dataset_model_info_without_error(self):
-        self.mock_stub.GetDatasetModelInfo.return_value = (
-            connection_pb2.DatasetModelInfoResponse(
-                train_samples=42,
-                validation_samples=21,
-                client_fw_flops=1,
-                server_fw_flops=2,
-                client_bw_flops=3,
-                server_bw_flops=4,
-            )
-        )
-
-        response = self.dispatcher.get_dataset_model_info_on("1")
-
-        self.assertEqual(response, (42, 21, 1, 2, 3, 4))
-        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
-            connection_pb2.DatasetModelInfoRequest()
-        )
-
-    def test_get_dataset_model_info_with_error(self):
-        self.mock_stub.GetDatasetModelInfo.side_effect = grpc.RpcError()
-
-        response = self.dispatcher.get_dataset_model_info_on("1")
-
-        self.assertEqual(response, False)
-        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
-            connection_pb2.DatasetModelInfoRequest()
-        )
-
-    def test_handle_calls_to_inactive_device(self):
-        """Test each method of the dispatcher that does RPC calls to handle calls to inactive devices.
-        One test where the device is known, but inactive and one where the device is unknown.
-        """
-        methods_names = [
-            "train_global_on",
-            "set_weights_on",
-            "train_epoch_on",
-            "train_batch_on",
-            "evaluate_global_on",
-            "evaluate_on",
-            "evaluate_batch_on",
-            "federated_train_on",
-            "start_experiment_on",
-            "end_experiment_on",
-            "get_battery_status_on",
-            "get_dataset_model_info_on",
-        ]
-        for device in [("0", True), ("2", False)]:  # 0 is inactive, 2 is unknown
-            for method_name in methods_names:
-                with self.subTest(method_name=f"{method_name}_device{device[0]}"):
-                    with patch("builtins.print") as print_patch:
-                        method = getattr(self.dispatcher, method_name)
-                        params = list(signature(method).parameters)
-
-                        # make method call with device id and all other params as None
-                        response = method(
-                            device[0],
-                            *[None for _ in params if _ not in ["self", "device_id"]],
-                        )
-
-                        self.assertEqual(response, False)
-                        self.mock_stub.assert_not_called()
-                        if device[1]:
-                            print_patch.assert_called_once_with(
-                                f"Device {device[0]} not active"
-                            )
-                        else:
-                            print_patch.assert_called_once_with(
-                                f"Unknown Device ID {device[0]}"
-                            )
-
-    def _assert_field_size_added_to_diagnostic_metrics(self, diagnostic_metrics):
-        """Not to use when response only includes diagnostic metrics, as these are ignored for the field size"""
-        self.assertEqual(
-            diagnostic_metrics.get_as_list()[0],
-            self.diagnostic_metrics.get_as_list()[0],
-        )  # check that previous diagnostic metrics still exist
-        self.assertEqual(
-            diagnostic_metrics.get_as_list()[1].name, "size"
-        )  # check that at least one new diagnostic metric for size was added
-
-
-class RequestDispatcherThreadingTest(unittest.TestCase):
-
-    def test_handle_rpc_error_thread_safety(self):
-
-        # run the procedure with 2 to 20 threads
-        for n in range(2, 21):
-            with self.subTest(method_name=f"Test with {n} threads"):
-                print(f"Test with {n} threads")
-                self.request_dispatcher = DeviceRequestDispatcher([])
-
-                # mock n connections with half of them returning errors and the other half valid responses
-                self.request_dispatcher.connections = {
-                    f"d{i}": Mock(spec=["GetBatteryStatus"]) for i in range(n)
-                }
-                response = connection_pb2.BatteryStatusResponse(
-                    status=BatteryStatus(
-                        initial_battery_level=42, current_battery_level=21
-                    )
-                )
-                for i, mock in enumerate(self.request_dispatcher.connections.values()):
-                    mock.GetBatteryStatus.side_effect = get_side_effect(i, response)
-
-                # make n concurrent calls to the dispatcher
-                responses = []
-                response_lock = threading.Lock()
-                with concurrent.futures.ThreadPoolExecutor(max_workers=n) as executor:
-                    futures = [
-                        executor.submit(
-                            self.request_dispatcher.get_battery_status_on, device_id
-                        )
-                        for device_id in self.request_dispatcher.active_devices()
-                    ]
-                    for future in concurrent.futures.as_completed(futures):
-                        with response_lock:
-                            responses.append(future.result())
-
-                # check that only error connections were removed
-                self.assertEqual(
-                    list(self.request_dispatcher.connections.keys()),
-                    [f"d{i}" for i in range(n) if i % 2 == 0],
-                )
-
-                response_battery_level = DeviceBatteryStatus(
-                    initial_capacity=42, current_capacity=21
-                )
-
-                # check that there are n responses in total, half of them valid and half of them None
-                self.assertEqual(responses.count(False), math.floor(n / 2))
-                self.assertEqual(
-                    responses.count(response_battery_level), math.ceil(n / 2)
-                )
diff --git a/edml/tests/core/rpc_device_servicer_test.py b/edml/tests/core/rpc_device_servicer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d392f3e8318e7417b1185405dda793a9fda3247d
--- /dev/null
+++ b/edml/tests/core/rpc_device_servicer_test.py
@@ -0,0 +1,397 @@
+import unittest
+from unittest.mock import Mock
+
+from grpc import StatusCode
+from grpc_testing import server_from_dictionary, strict_real_time
+from torch import Tensor
+
+from edml.core.battery import Battery, BatteryEmptyException
+from edml.core.device import NetworkDevice
+from edml.core.rpc_device_servicer import RPCDeviceServicer
+from edml.generated import connection_pb2, datastructures_pb2
+from edml.generated.connection_pb2 import (
+    SetWeightsResponse,
+    EvalResponse,
+    EvalBatchResponse,
+)
+from edml.generated.datastructures_pb2 import (
+    Weights,
+    Activations,
+    Labels,
+    DeviceInfo,
+    BatteryStatus,
+)
+from edml.helpers.metrics import (
+    ModelMetricResultContainer,
+    ModelMetricResult,
+    DiagnosticMetricResultContainer,
+    DiagnosticMetricResult,
+)
+from edml.helpers.proto_helpers import (
+    state_dict_to_proto,
+    proto_to_weights,
+    proto_to_metrics,
+    proto_to_state_dict,
+    tensor_to_proto,
+    proto_to_tensor,
+    weights_to_proto,
+)
+
+
+class RPCDeviceServicerTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        # Instantiate a NetworkDevice as the mock's spec so that Device(ABC)'s properties are mocked too.
+        # Keep a reference to mock_device so the tests can assert on its method calls later.
+        self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery)))
+        self.mock_device.device_id = 42
+
+        my_servicer = RPCDeviceServicer(device=self.mock_device)
+        servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer}
+        self.test_server = server_from_dictionary(servicers, strict_real_time())
+
+        self.metrics = ModelMetricResultContainer(
+            [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)]
+        )
+        self.diagnostic_metrics = DiagnosticMetricResultContainer(
+            [DiagnosticMetricResult("d1", "comp_time", "train", 42)]
+        )
+
+    def make_call(self, method_name, request):
+        method = self.test_server.invoke_unary_unary(
+            method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[
+                "Device"
+            ].methods_by_name[method_name],
+            invocation_metadata={},
+            request=request,
+            timeout=None,
+        )
+        return method.termination()
+
+    def test_train_global(self):
+        self.mock_device.train_global.return_value = (
+            {"weights": Tensor([42])},
+            {"weights": Tensor([43])},
+            self.metrics,
+            {"optimizer_state": 44},
+            self.diagnostic_metrics,
+        )
+        request = connection_pb2.TrainGlobalRequest(
+            epochs=42,
+            round_no=1,
+            adaptive_threshold_value=3,
+            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
+        )
+
+        response, metadata, code, details = self.make_call("TrainGlobal", request)
+
+        self.assertEqual(
+            (
+                proto_to_weights(response.client_weights),
+                proto_to_weights(response.server_weights),
+            ),
+            ({"weights": Tensor([42])}, {"weights": Tensor([43])}),
+        )
+        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
+        self.assertEqual(
+            proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
+        )
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.train_global.assert_called_once_with(
+            42, 1, 3, {"optimizer_state": 42}
+        )
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_set_weights(self):
+        request = connection_pb2.SetWeightsRequest(
+            weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})),
+            on_client=True,
+        )
+
+        response, metadata, code, details = self.make_call("SetWeights", request)
+
+        self.assertEqual(type(response), type(SetWeightsResponse()))
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.set_weights.assert_called_once_with(
+            {"weights": Tensor([42])}, True
+        )
+
+    def test_train_epoch(self):
+        self.mock_device.train_epoch.return_value = {
+            "weights": Tensor([42])
+        }, self.diagnostic_metrics
+        request = connection_pb2.TrainEpochRequest(
+            server=datastructures_pb2.DeviceInfo(device_id="42", address="")
+        )
+
+        response, metadata, code, details = self.make_call("TrainEpoch", request)
+
+        self.assertEqual(
+            proto_to_state_dict(response.weights.weights), {"weights": Tensor([42])}
+        )
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.train_epoch.assert_called_once_with("42", 0)
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_train_batch(self):
+        self.mock_device.train_batch.return_value = (
+            Tensor([42]),
+            42.0,
+            self.diagnostic_metrics,
+        )
+        request = connection_pb2.TrainBatchRequest(
+            smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))),
+            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
+        )
+
+        response, metadata, code, details = self.make_call("TrainBatch", request)
+
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1]))
+        self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42]))
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_evaluate_global(self):
+        self.mock_device.evaluate_global.return_value = (
+            self.metrics,
+            self.diagnostic_metrics,
+        )
+        request = connection_pb2.EvalGlobalRequest(validation=False, federated=False)
+
+        response, metadata, code, details = self.make_call("EvaluateGlobal", request)
+
+        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.evaluate_global.assert_called_once_with(False, False)
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_evaluate(self):
+        self.mock_device.evaluate.return_value = self.diagnostic_metrics
+        request = connection_pb2.EvalRequest(
+            server=DeviceInfo(device_id="42", address=""), validation=True
+        )
+
+        response, metadata, code, details = self.make_call("Evaluate", request)
+
+        self.assertEqual(type(response), type(EvalResponse()))
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.evaluate.assert_called_once_with("42", True)
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_evaluate_batch(self):
+        self.mock_device.evaluate_batch.return_value = self.diagnostic_metrics
+        request = connection_pb2.EvalBatchRequest(
+            smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))),
+            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
+        )
+
+        response, metadata, code, details = self.make_call("EvaluateBatch", request)
+
+        self.assertEqual(type(response), type(EvalBatchResponse()))
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.evaluate_batch.assert_called_once_with(
+            Tensor([1.0]), Tensor([1])
+        )
+        self.assertEqual(
+            proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
+        )
+
+    def test_full_model_training(self):
+        self.mock_device.federated_train.return_value = (
+            {"weights": Tensor([42])},
+            {"weights": Tensor([43])},
+            44,
+            self.metrics,
+            self.diagnostic_metrics,
+        )
+        request = connection_pb2.FullModelTrainRequest()
+
+        response, metadata, code, details = self.make_call("FullModelTraining", request)
+
+        self.assertEqual(
+            proto_to_state_dict(response.client_weights.weights),
+            {"weights": Tensor([42])},
+        )
+        self.assertEqual(
+            proto_to_state_dict(response.server_weights.weights),
+            {"weights": Tensor([43])},
+        )
+        self.assertEqual(response.num_samples, 44)
+        self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.federated_train.assert_called_once()
+
+    def test_start_experiment(self):
+        self.mock_device.start_experiment.return_value = None
+        request = connection_pb2.StartExperimentRequest()
+
+        response, metadata, code, details = self.make_call("StartExperiment", request)
+
+        self.assertEqual(type(response), type(connection_pb2.StartExperimentResponse()))
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.start_experiment.assert_called_once()
+
+    def test_end_experiment(self):
+        self.mock_device.end_experiment.return_value = None
+        request = connection_pb2.EndExperimentRequest()
+
+        response, metadata, code, details = self.make_call("EndExperiment", request)
+
+        self.assertEqual(type(response), type(connection_pb2.EndExperimentResponse()))
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.end_experiment.assert_called_once()
+
+    def test_battery_status(self):
+        self.mock_device.get_battery_status.return_value = (42, 21)
+        request = connection_pb2.BatteryStatusRequest()
+
+        response, metadata, code, details = self.make_call("GetBatteryStatus", request)
+
+        self.assertEqual(
+            response.status,
+            BatteryStatus(initial_battery_level=42, current_battery_level=21),
+        )
+        self.assertEqual(code, StatusCode.OK)
+        self.mock_device.get_battery_status.assert_called_once()
+
+    def test_dataset_model_info(self):
+        self.mock_device.client._train_data.dataset = [1]
+        self.mock_device.client._val_data.dataset = [2]
+        self.mock_device.client._model_flops = {"FW": 3, "BW": 6}
+        self.mock_device.server._model_flops = {"FW": 4, "BW": 8}
+        request = connection_pb2.DatasetModelInfoRequest()
+
+        response, metadata, code, details = self.make_call(
+            "GetDatasetModelInfo", request
+        )
+
+        self.assertEqual(code, StatusCode.OK)
+        self.assertEqual(response.train_samples, 1)
+        self.assertEqual(response.validation_samples, 1)
+        self.assertEqual(response.client_fw_flops, 3)
+        self.assertEqual(response.server_fw_flops, 4)
+        self.assertEqual(response.client_bw_flops, 6)
+        self.assertEqual(response.server_bw_flops, 8)
+
+
+class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        # Instantiate a NetworkDevice as the mock's spec so that Device(ABC)'s properties are mocked too.
+        # Keep a reference to mock_device so each test can set a side effect on the method under test.
+        self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery)))
+        my_servicer = RPCDeviceServicer(device=self.mock_device)
+        servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer}
+        self.test_server = server_from_dictionary(servicers, strict_real_time())
+
+    def make_call(self, method_name, request):
+        method = self.test_server.invoke_unary_unary(
+            method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[
+                "Device"
+            ].methods_by_name[method_name],
+            invocation_metadata={},
+            request=request,
+            timeout=None,
+        )
+        return method.termination()
+
+    def test_stop_at_train_global(self):
+        self.mock_device.train_global.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.TrainGlobalRequest(
+            optimizer_state=state_dict_to_proto(None)
+        )
+        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal")
+
+    def test_stop_at_set_weights(self):
+        self.mock_device.set_weights.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.SetWeightsRequest(weights=weights_to_proto({}))
+        self._test_device_out_of_battery_lets_rpc_fail(request, "SetWeights")
+
+    def test_stop_at_train_epoch(self):
+        self.mock_device.train_epoch.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.TrainEpochRequest(server=None)
+        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainEpoch")
+
+    def test_stop_at_train_batch(self):
+        self.mock_device.train_batch.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.TrainBatchRequest(
+            smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))),
+            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
+        )
+        self._test_device_out_of_battery_lets_rpc_fail(request, "TrainBatch")
+
+    def test_stop_at_evaluate_global(self):
+        self.mock_device.evaluate_global.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.EvalGlobalRequest()
+        self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateGlobal")
+
+    def test_stop_at_evaluate(self):
+        self.mock_device.evaluate.side_effect = BatteryEmptyException("Battery empty")
+        request = connection_pb2.EvalRequest(server=None)
+        self._test_device_out_of_battery_lets_rpc_fail(request, "Evaluate")
+
+    def test_stop_at_evaluate_batch(self):
+        self.mock_device.evaluate_batch.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.EvalBatchRequest(
+            smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))),
+            labels=Labels(labels=tensor_to_proto(Tensor([1]))),
+        )
+        self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateBatch")
+
+    def test_stop_at_full_model_training(self):
+        self.mock_device.federated_train.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.FullModelTrainRequest()
+        self._test_device_out_of_battery_lets_rpc_fail(request, "FullModelTraining")
+
+    def test_stop_at_start_experiment(self):
+        self.mock_device.start_experiment.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.StartExperimentRequest()
+        self._test_device_out_of_battery_lets_rpc_fail(request, "StartExperiment")
+
+    def test_stop_at_end_experiment(self):
+        self.mock_device.end_experiment.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.EndExperimentRequest()
+        self._test_device_out_of_battery_lets_rpc_fail(request, "EndExperiment")
+
+    def test_stop_at_get_battery_status(self):
+        self.mock_device.get_battery_status.side_effect = BatteryEmptyException(
+            "Battery empty"
+        )
+        request = connection_pb2.BatteryStatusRequest()
+        self._test_device_out_of_battery_lets_rpc_fail(request, "GetBatteryStatus")
+
+    def _test_device_out_of_battery_lets_rpc_fail(self, request, servicer_method_name):
+        response, metadata, code, details = self.make_call(
+            servicer_method_name, request
+        )
+        self.assertIsNone(response)
+        self.assertEqual(code, StatusCode.UNKNOWN)
+        self.assertEqual(details, "Exception calling application: Battery empty")
diff --git a/edml/tests/integration/rpc_server_test.py b/edml/tests/integration/rpc_server_test.py
index faafd7b9b94300c9509d777453694381dfd1529c..0314d1094ffb392d7702f5d7660d2abda8f6f68c 100644
--- a/edml/tests/integration/rpc_server_test.py
+++ b/edml/tests/integration/rpc_server_test.py
@@ -7,7 +7,9 @@ import grpc
 from omegaconf import DictConfig
 
 from edml.core.battery import Battery, BatteryEmptyException
-from edml.core.device import NetworkDevice, RPCDeviceServicer, DeviceRequestDispatcher
+from edml.core.device import NetworkDevice
+from edml.core.device_request_dispatcher import DeviceRequestDispatcher
+from edml.core.rpc_device_servicer import RPCDeviceServicer
 from edml.generated import connection_pb2_grpc
 from edml.helpers.interceptors import DeviceServerInterceptor
 from edml.helpers.logging import SimpleLogger