# Tim Tobias Bauerle authored
# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# device.py 37.29 KiB
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union, Tuple, cast
import grpc
from google.protobuf.message import Message
from omegaconf import DictConfig
from torch import Tensor
from torch.autograd import Variable
from edml.core.battery import Battery
from edml.core.client import DeviceClient
from edml.core.server import DeviceServer
from edml.generated import connection_pb2
from edml.generated.connection_pb2 import (
SetWeightsRequest,
TrainBatchRequest,
TrainGlobalResponse,
TrainEpochResponse,
TrainBatchResponse,
EvalGlobalResponse,
EvalResponse,
EvalBatchResponse,
FullModelTrainResponse,
BatteryStatusResponse,
DatasetModelInfoResponse,
EndExperimentResponse,
StartExperimentResponse,
SingleBatchTrainingResponse,
SingleBatchBackwardRequest,
TrainGlobalParallelSplitLearningResponse,
)
from edml.generated.connection_pb2_grpc import DeviceServicer, DeviceStub
from edml.generated.datastructures_pb2 import (
Gradients,
Weights,
DeviceInfo,
Activations,
Labels,
BatteryStatus,
Empty,
Metrics,
)
from edml.helpers.decorators import (
log_execution_time,
update_battery,
add_time_to_diagnostic_metrics,
)
from edml.helpers.interceptors import DeviceClientInterceptor
from edml.helpers.logging import SimpleLogger
from edml.helpers.metrics import (
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
)
from edml.helpers.proto_helpers import (
proto_to_tensor,
tensor_to_proto,
state_dict_to_proto,
proto_to_state_dict,
proto_to_weights,
metrics_to_proto,
proto_to_metrics,
_proto_size_per_field,
)
from edml.helpers.types import (
HasMetrics,
StateDict,
DeviceBatteryStatus,
DeviceBatteryStatusReport,
)
class Device(ABC):
    """
    Abstract base for a (physical or virtual) device. A device consists of two parts, a client part and a
    server part, which are attached after construction.

    Attributes:
        device_id (str): Identifier of this device.
        logger (SimpleLogger): Logger instance used by the device and, through :py:meth:`log`, by its parts.
        battery (Battery): Battery of the device.
        client (DeviceClient): Client part of the device. ``None`` until :py:meth:`set_client` is called.
        server (DeviceServer): Server part of the device. ``None`` until :py:meth:`set_server` is called.
    """

    def __init__(self, device_id: str, logger: SimpleLogger, battery: Battery):
        self.device_id = device_id
        self.logger = logger
        self.battery = battery
        # The two parts are attached later; the cast only keeps the attribute types non-optional.
        self.client: DeviceClient = cast(DeviceClient, None)
        self.server: DeviceServer = cast(DeviceServer, None)

    def set_client(self, client: DeviceClient):
        """
        Attaches the client part to this device.

        Notes:
            The client is also told about this device via its ``set_device`` method.
        """
        self.client = client
        client.set_device(self)

    def set_server(self, server: DeviceServer):
        """
        Attaches the server part to this device.

        Notes:
            The server is also told about this device via its ``set_device`` method.
        """
        self.server = server
        server.set_device(self)

    def log(self, message: Any):
        """Forwards `message` to the device's logger; meant to be used by the client and server parts."""
        self.logger.log(message)

    def start_experiment(self):
        """
        Lifecycle hook invoked when an experiment starts. Notifies the logger and the battery.

        Notes:
            Overriding methods must call the super method.
        """
        self.logger.start_experiment()
        self.battery.start_experiment()

    def end_experiment(self):
        """
        Lifecycle hook invoked when an experiment ends. Notifies the logger.

        Notes:
            Overriding methods must call the super method.
        """
        self.logger.end_experiment()

    def get_battery_status(self) -> Tuple[int, int]:
        """
        Reports the battery state of this device.

        Returns:
            Tuple[int, int]: The initial battery capacity followed by the remaining capacity.
        """
        initial = self.battery.initial_capacity()
        remaining = self.battery.remaining_capacity()
        return initial, remaining

    def shutdown(self):
        """Shuts the device down by running the end-of-experiment hook."""
        self.end_experiment()

    @abstractmethod
    def train_global(self, epochs: int):
        """Trains globally for the given number of epochs using the device's server."""

    @abstractmethod
    def set_devices(self, devices):
        """Sets the references to all devices in the network."""

    @abstractmethod
    def set_weights(self, state_dict, on_client: bool):
        """Sets the weights of one of this device's networks (client or server part)."""

    @abstractmethod
    def set_weights_on(self, device_id: str, state_dict, on_client: bool):
        """Sets the weights of one of the networks on the device with the given id."""

    @abstractmethod
    def train_epoch(self, server_device: str):
        """Trains one epoch on this device's client."""

    @abstractmethod
    def train_epoch_on(self, device_id: str, server_device: str, round_no: int):
        """Trains one epoch on the client of the device with the given id."""

    @abstractmethod
    def train_batch(self, smashed_data, labels):
        """Trains one batch on this device's server."""

    @abstractmethod
    def train_batch_on(self, device_id: str, smashed_data, labels):
        """Trains one batch on the server of the device with the given id."""

    @abstractmethod
    def evaluate_global(self, val: bool = True, fed: bool = False):
        """
        Starts evaluation on all devices' clients using this device's server.

        `val` selects the validation set (True) or the test set (False).
        """

    @abstractmethod
    def evaluate(self, server_device: str, val: bool = True):
        """
        Starts evaluation on this device's client using the specified server.

        `val` selects the validation set (True) or the test set (False).
        """

    @abstractmethod
    def evaluate_on(self, device_id: str, server_device: str, val: bool):
        """
        Starts evaluation on the client of the device with the given id using the specified server.

        `val` selects the validation set (True) or the test set (False).
        """

    @abstractmethod
    def evaluate_batch(self, smashed_data, labels):
        """Evaluates one batch on this device's server."""

    @abstractmethod
    def evaluate_batch_on(self, device_id: str, smashed_data, labels):
        """Evaluates one batch on the server of the device with the given id."""

    @abstractmethod
    def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
        """Trains the batch with the given index on the client of the device with the given id."""

    @abstractmethod
    def backpropagation_on_client_only_on(self, client_id: str, gradients: Any):
        """Runs backpropagation with the given gradients on the client of the device with the given id."""
class NetworkDevice(Device):
    """
    Device that can address other devices of the network.

    Every ``*_on`` method takes a target device id. If the id equals this device's own id, the work is done
    locally through the attached client/server part; otherwise the call is forwarded through
    ``self.request_dispatcher``.

    Attributes:
        devices (List[DictConfig]): Entries describing all devices in the network, see :py:meth:`set_devices`.
        request_dispatcher (DeviceRequestDispatcher): Forwards calls to the other devices.
        stop_event (Optional[threading.Event]): Passed on to the request dispatcher.
    """

    def __init__(
        self,
        device_id: str,
        logger: SimpleLogger,
        battery: Battery,
        stop_event: Optional[threading.Event] = None,
    ):
        self.devices: List[DictConfig[str, Any]] = []
        # Placeholder dispatcher without any connections; replaced once set_devices is called.
        self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id)
        self.stop_event = stop_event
        super().__init__(device_id, logger, battery)

    @update_battery
    @log_execution_time("logger", "train_parallel_split_learning")
    def train_parallel_split_learning(
        self,
        clients: list[str],
        round_no: int,
        adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
    ):
        """Runs parallel split learning on this device's server and returns the server's result unchanged."""
        return self.server.train_parallel_split_learning(
            clients=clients,
            round_no=round_no,
            adaptive_learning_threshold=adaptive_learning_threshold,
            optimizer_state=optimizer_state,
        )

    @update_battery
    @log_execution_time("logger", "client_only_backpropagation_train")
    def backpropagation_on_client_only(self, gradients: Any):
        """Runs the backward pass for a single batch on this device's client."""
        return self.client.backward_single_batch(gradients)

    @update_battery
    def backpropagation_on_client_only_on(self, client_id: str, gradients: Any):
        """Runs the single-batch backward pass on the client with the given id (locally or remotely)."""
        if client_id == self.device_id:
            return self.backpropagation_on_client_only(gradients=gradients)
        return self.request_dispatcher.backpropagation_on_client_only(
            device_id=client_id, gradients=gradients
        )

    @update_battery
    @log_execution_time("logger", "client_only_batch_train")
    def train_batch_on_client_only(self, batch_index: int):
        """
        Trains the batch with the given index on this device's client only.

        Returns:
            The pair ``(smashed_data, labels)`` of the client's result, or None if the client returned None.
        """
        result = self.client.train_single_batch(batch_index=batch_index)
        if result is None:
            return None
        return result.smashed_data, result.labels

    @update_battery
    def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
        """Trains the batch with the given index on the client with the given id (locally or remotely)."""
        if self.device_id == device_id:
            return self.train_batch_on_client_only(batch_index=batch_index)
        return self.request_dispatcher.train_batch_on_client_only(
            device_id=device_id, batch_index=batch_index
        )

    @add_time_to_diagnostic_metrics("train_global")
    @update_battery
    @log_execution_time("logger", "train_global_time")
    def train_global(
        self,
        epochs: int,
        round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
    ) -> Tuple[
        Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
    ]:
        """Trains for `epochs` epochs on this device's server, involving every device set via set_devices."""
        return self.server.train(
            devices=self.__get_device_ids__(),
            epochs=epochs,
            round_no=round_no,
            optimizer_state=optimizer_state,
        )

    def __get_device_ids__(self) -> List[str]:
        """Returns the ids of all devices set via :py:meth:`set_devices`."""
        return [d.device_id for d in self.devices]

    def set_devices(self, devices: List[DictConfig[str, Any]]):
        """
        Stores the entries of all devices in the network and builds a new request dispatcher for them.

        Expects a list of entries that provide the device_id and address of each device.
        """
        self.devices = devices
        self.request_dispatcher = DeviceRequestDispatcher(
            devices,
            self.logger,
            self.battery,
            self.stop_event,
            device_id=self.device_id,
        )

    @update_battery
    def set_weights(self, state_dict, on_client: bool = True):
        """Sets the weights on this device's client (`on_client` True) or server (`on_client` False)."""
        target = self.client if on_client else self.server
        target.set_weights(state_dict)

    @update_battery
    def set_weights_on(self, device_id: str, state_dict, on_client: bool = True):
        """Sets the weights on the device with the given id (locally or remotely)."""
        if device_id == self.device_id:
            self.set_weights(state_dict, on_client)
        else:
            self.request_dispatcher.set_weights_on(device_id, state_dict, on_client)

    @update_battery
    @log_execution_time("logger", "client_train_epoch_time")
    def train_epoch(self, server_device: str, round_no: int = -1):
        """Trains one epoch on this device's client against the given server device."""
        # the execution time is measured in the client in order to deduct the time for the server
        return self.client.train_epoch(server_device, round_no=round_no)

    @update_battery
    def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1):
        """Trains one epoch on the client with the given id (locally or remotely)."""
        if device_id == self.device_id:
            return self.train_epoch(server_device, round_no)
        return self.request_dispatcher.train_epoch_on(
            device_id, server_device, round_no
        )

    @add_time_to_diagnostic_metrics("train_batch")
    @update_battery
    def train_batch(self, smashed_data, labels) -> Variable:
        """Trains one batch on this device's server and logs the remaining battery capacity afterwards."""
        # NOTE(review): RPCDeviceServicer.TrainBatch unpacks three values from this call, so the
        # `Variable` return annotation looks outdated - confirm against DeviceServer.train_batch.
        result = self.server.train_batch(smashed_data, labels)
        self._log_current_battery_capacity()
        return result

    @update_battery
    def train_batch_on(self, device_id: str, smashed_data, labels):
        """Trains one batch on the server with the given id (locally or remotely)."""
        if device_id == self.device_id:
            return self.train_batch(smashed_data, labels)
        result = self.request_dispatcher.train_batch_on(device_id, smashed_data, labels)
        self._log_current_battery_capacity()
        return result

    @add_time_to_diagnostic_metrics("evaluate_global")
    @update_battery
    @log_execution_time("logger", "evaluate_global_time")
    def evaluate_global(
        self, val: bool = True, fed: bool = False
    ) -> ModelMetricResultContainer:
        """
        Starts evaluation using this device's server.

        Args:
            val: Use the validation set (True) or the test set (False).
            fed: If True only this device is evaluated, otherwise all devices set via set_devices.
        """
        devices = [self.device_id] if fed else self.__get_device_ids__()
        return self.server.evaluate_global(devices=devices, val=val)

    @update_battery
    @log_execution_time("logger", "client_evaluate_time")
    def evaluate(self, server_device: str, val=True) -> DiagnosticMetricResultContainer:
        """Starts evaluation on this device's client against the given server device."""
        # the execution time is measured in the client in order to deduct the time for the server
        return self.client.evaluate(server_device, val)

    @update_battery
    def evaluate_on(
        self, device_id, server_device, val
    ) -> DiagnosticMetricResultContainer:
        """Starts evaluation on the client with the given id (locally or remotely)."""
        if device_id == self.device_id:
            return self.evaluate(server_device, val)
        return self.request_dispatcher.evaluate_on(device_id, server_device, val)

    @add_time_to_diagnostic_metrics("evaluate_batch")
    @update_battery
    def evaluate_batch(self, smashed_data, labels):
        """Evaluates one batch on this device's server and logs the remaining battery capacity afterwards."""
        result = self.server.evaluate_batch(smashed_data, labels)
        self._log_current_battery_capacity()
        return result

    @update_battery
    def evaluate_batch_on(self, device_id, smashed_data, labels):
        """Evaluates one batch on the server with the given id (locally or remotely)."""
        if device_id == self.device_id:
            return self.evaluate_batch(smashed_data, labels)
        result = self.request_dispatcher.evaluate_batch_on(
            device_id, smashed_data, labels
        )
        # Log after the remote call has finished, consistent with train_batch_on.
        self._log_current_battery_capacity()
        return result

    @add_time_to_diagnostic_metrics("federated_train")
    @update_battery
    @log_execution_time("logger", "fed_train_time")
    def federated_train(
        self, round_no: int = -1
    ) -> Tuple[
        Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer
    ]:
        """Returns client and server weights, the number of samples used for training and metrics"""
        # One epoch with this device as the only participant; the fourth element of the server's
        # result is not needed here.
        client_weights, server_weights, metrics, _, diagnostic_metrics = (
            self.server.train(devices=[self.device_id], epochs=1, round_no=round_no)
        )
        num_samples = self.client.get_num_samples()
        return client_weights, server_weights, num_samples, metrics, diagnostic_metrics

    def _log_current_battery_capacity(self):
        """Wrapper for logging the current battery capacity"""
        self.logger.log({"battery": self.battery.remaining_capacity()})
class RPCDeviceServicer(DeviceServicer):
    """
    gRPC servicer that exposes a :py:class:`NetworkDevice` to other devices.

    Each handler decodes the protobuf request, calls the matching method on ``self.device`` and encodes
    the result into the corresponding response message.
    """

    def __init__(self, device: NetworkDevice):
        self.device = device

    @staticmethod
    def _optimizer_state_from(request):
        """Decodes the optimizer state carried by `request`, or returns None if the field is not set."""
        if request.HasField("optimizer_state"):
            return proto_to_state_dict(request.optimizer_state)
        return None

    def TrainGlobal(self, request, context):
        """Runs global training on the device and returns weights, metrics and optimizer state."""
        print(f"Called TrainGlobal on device {self.device.device_id}")
        # The optimizer state sent by the caller has to be handed on to the device; otherwise it is lost.
        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
            self.device.train_global(
                request.epochs,
                request.round_no,
                self._optimizer_state_from(request),
            )
        )
        response = connection_pb2.TrainGlobalResponse(
            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
            metrics=metrics_to_proto(metrics),
            optimizer_state=state_dict_to_proto(optimizer_state),
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )
        return response

    def SetWeights(self, request, context):
        """Sets the received weights on the device's client or server, depending on `request.on_client`."""
        print(f"Called SetWeights on device {self.device.device_id}")
        weights = proto_to_state_dict(request.weights.weights)
        self.device.set_weights(weights, request.on_client)
        return connection_pb2.SetWeightsResponse()

    def TrainEpoch(self, request, context):
        """Trains one epoch on the device's client against the server device named in the request."""
        print(f"Called TrainEpoch on device {self.device.device_id}")
        server_device_id = request.server.device_id
        weights, diagnostic_metrics = self.device.train_epoch(
            server_device_id, request.round_no
        )
        return connection_pb2.TrainEpochResponse(
            weights=Weights(weights=state_dict_to_proto(weights)),
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )

    def TrainBatch(self, request, context):
        """Trains one batch on the device's server and returns gradients, loss and diagnostic metrics."""
        activations = proto_to_tensor(request.smashed_data.activations)
        labels = proto_to_tensor(request.labels.labels)
        gradients, loss, diagnostic_metrics = self.device.train_batch(
            activations, labels
        )
        return connection_pb2.TrainBatchResponse(
            gradients=Gradients(gradients=tensor_to_proto(gradients)),
            loss=loss,
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )

    def EvaluateGlobal(self, request, context):
        """Runs global evaluation on the device and returns model and diagnostic metrics."""
        print(f"Called EvaluateGlobal on device {self.device.device_id}")
        metrics, diagnostic_metrics = self.device.evaluate_global(
            request.validation, request.federated
        )
        return connection_pb2.EvalGlobalResponse(
            metrics=metrics_to_proto(metrics),
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )

    def Evaluate(self, request, context):
        """Evaluates on the device's client against the server device named in the request."""
        print(f"Called Evaluate on device {self.device.device_id}")
        diagnostic_metrics = self.device.evaluate(
            request.server.device_id, request.validation
        )
        return connection_pb2.EvalResponse(
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
        )

    def EvaluateBatch(self, request, context):
        """Evaluates one batch on the device's server and returns the diagnostic metrics."""
        activations = proto_to_tensor(request.smashed_data.activations)
        labels = proto_to_tensor(request.labels.labels)
        diagnostic_metrics = self.device.evaluate_batch(activations, labels)
        return connection_pb2.EvalBatchResponse(
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics)
        )

    def FullModelTraining(self, request, context):
        """Runs federated training on the device and returns weights, sample count and metrics."""
        print(f"Called Full Training on device {self.device.device_id}")
        client_weights, server_weights, num_samples, metrics, diagnostic_metrics = (
            self.device.federated_train(request.round_no)
        )
        return connection_pb2.FullModelTrainResponse(
            client_weights=Weights(weights=state_dict_to_proto(client_weights)),
            server_weights=Weights(weights=state_dict_to_proto(server_weights)),
            num_samples=num_samples,
            metrics=metrics_to_proto(metrics),
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )

    def StartExperiment(self, request, context) -> StartExperimentResponse:
        """Triggers the device's start-of-experiment hook."""
        print(f"Start Experiment on device {self.device.device_id}")
        self.device.start_experiment()
        return connection_pb2.StartExperimentResponse()

    def EndExperiment(self, request, context) -> EndExperimentResponse:
        """Prints the remaining battery capacity and triggers the device's end-of-experiment hook."""
        print(f"End Experiment on device {self.device.device_id}")
        print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}")
        self.device.end_experiment()
        return connection_pb2.EndExperimentResponse()

    def GetBatteryStatus(self, request, context):
        """Returns the initial and the current battery level of the device."""
        print(f"Get Battery Status on device {self.device.device_id}")
        initial_capacity, remaining_capacity = self.device.get_battery_status()
        return connection_pb2.BatteryStatusResponse(
            status=BatteryStatus(
                initial_battery_level=initial_capacity,
                current_battery_level=remaining_capacity,
            )
        )

    def GetDatasetModelInfo(self, request, context):
        """Returns the dataset sizes of the client and the FLOPs of the client and server models."""
        print(f"Get Dataset and Model Info on device {self.device.device_id}")
        client = self.device.client
        server = self.device.server
        # NOTE(review): reads private attributes of the client/server parts directly.
        return connection_pb2.DatasetModelInfoResponse(
            train_samples=len(client._train_data.dataset),
            validation_samples=len(client._val_data.dataset),
            client_model_flops=int(client._model_flops),
            server_model_flops=int(server._model_flops),
        )

    def TrainGlobalParallelSplitLearning(self, request, context):
        """Runs parallel split learning with all known devices as clients."""
        print("Starting parallel split learning")
        # The optimizer state sent by the caller has to be handed on to the device; otherwise it is lost.
        cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
            self.device.train_parallel_split_learning(
                clients=self.device.__get_device_ids__(),
                round_no=request.round_no,
                adaptive_learning_threshold=request.adaptive_learning_threshold,
                optimizer_state=self._optimizer_state_from(request),
            )
        )
        response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
            client_weights=Weights(weights=state_dict_to_proto(cw)),
            server_weights=Weights(weights=state_dict_to_proto(sw)),
            metrics=metrics_to_proto(model_metrics),
            optimizer_state=state_dict_to_proto(optimizer_state),
            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
        )
        return response

    def TrainSingleBatchOnClient(self, request, context):
        """Trains a single batch on the device's client and returns the smashed data and labels."""
        batch_index = request.batch_index
        print(f"Starting single batch@{batch_index}")
        result = self.device.client.train_single_batch(batch_index)
        # If the last batch's size was smaller than the configured batch size, the train call returns None
        if result is None:
            return connection_pb2.SingleBatchTrainingResponse()
        return connection_pb2.SingleBatchTrainingResponse(
            smashed_data=Activations(activations=tensor_to_proto(result.smashed_data)),
            labels=Labels(labels=tensor_to_proto(result.labels)),
        )

    def BackwardPropagationSingleBatchOnClient(
        self, request: SingleBatchBackwardRequest, context
    ):
        """Runs the single-batch backward pass on the device's client and returns its metrics."""
        gradients = proto_to_tensor(request.gradients.gradients)
        metrics = self.device.client.backward_single_batch(gradients=gradients)
        return connection_pb2.SingleBatchBackwardResponse(
            metrics=metrics_to_proto(metrics)
        )
class DeviceRequestDispatcher:
    """
    Sends requests to the other devices of the network via gRPC.

    One stub per configured device is created on construction. All ``*_on`` style methods share the same
    error handling: a ``grpc.RpcError`` removes the connection to the target device, an unknown device id
    is reported on stdout, and in both cases the method returns False.

    Attributes:
        devices: Entries providing ``device_id`` and ``address`` of every device.
        connections (Dict[str, DeviceStub]): Open stubs by device id; guarded by ``connections_lock``.
        logger, battery, stop_event: Optional; the client interceptor is only installed if all three are set.
        device_id (Optional[str]): Id of the device that owns this dispatcher.
    """

    def __init__(
        self,
        devices: List[DictConfig[str, Any]],
        logger: Optional[SimpleLogger] = None,
        battery: Optional[Battery] = None,
        stop_event: Optional[threading.Event] = None,
        device_id: Optional[str] = None,
    ):
        self.devices = devices
        self.connections: Dict[str, DeviceStub] = {}
        # Create the lock before any connection exists so it is always available alongside them.
        self.connections_lock = threading.Lock()
        # optional, interceptor only works if all three are set
        self.logger = logger
        self.battery = battery
        self.stop_event = stop_event
        self.device_id = device_id  # used for diagnostic metrics to assign the source device correctly
        self._establish_connections()

    def __get_device_address__(self, device_id: str) -> Optional[str]:
        """Returns the address of the device with the given id, or None if no such device is configured."""
        for device in self.devices:
            if device.device_id == device_id:
                return device.address
        return None

    def train_parallel_on_server(
        self,
        server_device_id: str,
        epochs: int,
        round_no: int,
        adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
    ):
        """
        Starts parallel split learning on the given server device.

        Notes:
            `epochs` is accepted for interface compatibility but is not part of the request.

        Returns:
            Client weights, server weights, model metrics, optimizer state and diagnostic metrics,
            or False if the request failed.
        """
        print(f"><><><> {adaptive_learning_threshold}")
        try:
            response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
                server_device_id
            ).TrainGlobalParallelSplitLearning(
                connection_pb2.TrainGlobalParallelSplitLearningRequest(
                    round_no=round_no,
                    adaptive_learning_threshold=adaptive_learning_threshold,
                    optimizer_state=state_dict_to_proto(optimizer_state),
                )
            )
            return (
                proto_to_weights(response.client_weights),
                proto_to_weights(response.server_weights),
                proto_to_metrics(response.metrics),
                proto_to_state_dict(response.optimizer_state),
                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
            )
        except grpc.RpcError:
            self._handle_rpc_error(server_device_id)
        except KeyError:
            self._handle_unknown_device_id(server_device_id)
        return False

    def _establish_connections(self):
        """Opens an insecure channel without message size limits to every device and stores its stub."""
        use_interceptor = (
            self.logger is not None
            and self.battery is not None
            and self.stop_event is not None
        )
        for device in self.devices:
            channel = grpc.insecure_channel(
                device.address,
                options=[
                    ("grpc.max_send_message_length", -1),
                    ("grpc.max_receive_message_length", -1),
                ],
            )
            if use_interceptor:
                channel = grpc.intercept_channel(
                    channel,
                    DeviceClientInterceptor(self.logger, self.battery, self.stop_event),
                )
            self.connections[device.device_id] = DeviceStub(channel)
        print(f"active devices {self.connections.keys()}")

    def _get_connection(self, device_id: str) -> DeviceStub:
        """Returns the connection to the device with the given id.

        If the device_id is not a valid dictionary key, the caller has to handle the KeyError.
        """
        with self.connections_lock:
            return self.connections[device_id]

    def active_devices(self) -> List[str]:
        """Returns the ids of all devices that currently have a connection."""
        with self.connections_lock:  # lock connections to prevent remove while iterating
            return list(self.connections.keys())

    def _handle_rpc_error(self, device_id: str):
        """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling
        the request or is not reachable (i.e. most probably shut down).

        The connection to the device is removed under the lock to avoid race conditions. Removing a
        connection that is already gone is a no-op: this hook runs inside an ``except`` block, so a
        KeyError raised here would escape the callers' error handling.
        """
        with self.connections_lock:
            self.connections.pop(device_id, None)

    def _handle_unknown_device_id(self, device_id: str):
        """Handler in case the requested device id is not in the connection dictionary.

        Could happen if the active devices weren't refreshed properly."""
        if any(device.device_id == device_id for device in self.devices):
            print(f"Device {device_id} not active")
        else:
            print(f"Unknown Device ID {device_id}")

    def _add_byte_size_to_diagnostic_metrics(
        self, response: Message, device_id: str, request=None
    ):
        """
        Adds the byte size of the response to the diagnostic metrics.

        Args:
            response: A protobuf message.
            device_id: The id of the device the request was sent from.
            request: A protobuf message, if the request is of interest.

        Returns:
            A DiagnosticMetricResultContainer including the added byte size and previous metrics from the
            response (and possibly the request).
        """
        if response.HasField("diagnostic_metrics"):
            diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics)
        else:
            diagnostic_metrics = DiagnosticMetricResultContainer()
        diagnostic_metrics.merge(_proto_size_per_field(response, device_id))
        if request is not None:
            diagnostic_metrics.merge(_proto_size_per_field(request, device_id))
        return diagnostic_metrics

    def train_global_on(
        self,
        device_id: str,
        epochs: int,
        round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
    ) -> Union[
        Tuple[
            Dict[str, Any],
            Dict[str, Any],
            ModelMetricResultContainer,
            Dict[str, Any],
            DiagnosticMetricResultContainer,
        ],
        bool,
    ]:
        """
        Starts global training on the given device.

        Returns:
            Client weights, server weights, model metrics, optimizer state and diagnostic metrics,
            or False if the request failed.
        """
        try:
            response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal(
                connection_pb2.TrainGlobalRequest(
                    epochs=epochs,
                    round_no=round_no,
                    optimizer_state=state_dict_to_proto(optimizer_state),
                )
            )
            return (
                proto_to_weights(response.client_weights),
                proto_to_weights(response.server_weights),
                proto_to_metrics(response.metrics),
                proto_to_state_dict(response.optimizer_state),
                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def set_weights_on(
        self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False
    ):
        """Sets the weights on the given device. Returns True on success, False if the request failed."""
        try:
            self._get_connection(device_id).SetWeights(
                SetWeightsRequest(
                    weights=Weights(weights=state_dict_to_proto(state_dict)),
                    on_client=on_client,
                ),
                wait_for_ready=wait_for_ready,
            )
            return True
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1):
        """
        Trains one epoch on the client of the given device against `server_device`.

        Returns:
            The pair of weights and diagnostic metrics, or False if the request failed.
        """
        try:
            response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch(
                connection_pb2.TrainEpochRequest(
                    server=DeviceInfo(
                        device_id=server_device,
                        address=self.__get_device_address__(server_device),
                    ),
                    round_no=round_no,
                )
            )
            weights = proto_to_state_dict(response.weights.weights)
            diagnostic_metrics = self._add_byte_size_to_diagnostic_metrics(
                response, self.device_id
            )
            return weights, diagnostic_metrics
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def train_batch_on(self, device_id: str, smashed_data, labels):
        """
        Trains one batch on the server of the given device.

        Returns:
            Gradients, loss and diagnostic metrics (covering response and request size),
            or False if the request failed.
        """
        try:
            request = TrainBatchRequest(
                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
                labels=Labels(labels=tensor_to_proto(labels)),
            )
            response: TrainBatchResponse = self._get_connection(device_id).TrainBatch(
                request
            )
            return (
                proto_to_tensor(response.gradients.gradients),
                response.loss,
                self._add_byte_size_to_diagnostic_metrics(
                    response, self.device_id, request
                ),
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False):
        """
        Starts global evaluation on the given device.

        Returns:
            The pair of model metrics and diagnostic metrics, or False if the request failed.
        """
        try:
            response: EvalGlobalResponse = self._get_connection(
                device_id
            ).EvaluateGlobal(
                connection_pb2.EvalGlobalRequest(validation=val, federated=fed)
            )
            metrics = proto_to_metrics(response.metrics)
            diagnostic_metrics = self._add_byte_size_to_diagnostic_metrics(
                response, self.device_id
            )
            return metrics, diagnostic_metrics
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def evaluate_on(self, device_id: str, server_device: str, val: bool):
        """
        Starts evaluation on the client of the given device against `server_device`.

        Returns:
            The diagnostic metrics, or False if the request failed.
        """
        try:
            response: EvalResponse = self._get_connection(device_id).Evaluate(
                connection_pb2.EvalRequest(
                    server=DeviceInfo(
                        device_id=server_device,
                        address=self.__get_device_address__(server_device),
                    ),
                    validation=val,
                )
            )
            return self._add_byte_size_to_diagnostic_metrics(response, self.device_id)
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def evaluate_batch_on(self, device_id: str, smashed_data, labels):
        """
        Evaluates one batch on the server of the given device.

        Returns:
            The diagnostic metrics (covering response and request size), or False if the request failed.
        """
        try:
            request = connection_pb2.EvalBatchRequest(
                smashed_data=Activations(activations=tensor_to_proto(smashed_data)),
                labels=Labels(labels=tensor_to_proto(labels)),
            )
            response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch(
                request
            )
            return self._add_byte_size_to_diagnostic_metrics(
                response, self.device_id, request
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[
        Tuple[
            StateDict,
            StateDict,
            int,
            ModelMetricResultContainer,
            DiagnosticMetricResultContainer,
        ],
        bool,
    ]:
        """
        Starts federated (full model) training on the given device.

        Returns:
            Client weights, server weights, number of samples, model metrics and diagnostic metrics,
            or False if the request failed.
        """
        try:
            response: FullModelTrainResponse = self._get_connection(
                device_id
            ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no))
            return (
                proto_to_weights(response.client_weights),
                proto_to_weights(response.server_weights),
                response.num_samples,
                proto_to_metrics(response.metrics),
                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool:
        """Signals the start of the experiment to the given device. Returns whether the request succeeded."""
        try:
            self._get_connection(device_id).StartExperiment(
                connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready
            )
            return True
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def end_experiment_on(self, device_id: str) -> bool:
        """Signals the end of the experiment to the given device. Returns whether the request succeeded."""
        try:
            self._get_connection(device_id).EndExperiment(
                connection_pb2.EndExperimentRequest()
            )
            return True
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def get_battery_status_on(self, device_id: str) -> DeviceBatteryStatusReport:
        """Requests the battery status of the given device. Returns False if the request failed."""
        try:
            response: BatteryStatusResponse = self._get_connection(
                device_id
            ).GetBatteryStatus(connection_pb2.BatteryStatusRequest())
            return DeviceBatteryStatus(
                current_capacity=response.status.current_battery_level,
                initial_capacity=response.status.initial_battery_level,
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def train_batch_on_client_only(
        self, device_id: str, batch_index: int
    ) -> Union[Tuple[Tensor, Tensor], None, bool]:
        """
        Trains the batch with the given index on the client of the given device.

        Returns:
            The pair ``(smashed_data, labels)``; None if the response carries no smashed data;
            False if the request failed.
        """
        try:
            response: SingleBatchTrainingResponse = self._get_connection(
                device_id
            ).TrainSingleBatchOnClient(
                connection_pb2.SingleBatchTrainingRequest(batch_index=batch_index)
            )
            # The response can only be None if the last batch was smaller than the configured batch size.
            if not response.HasField("smashed_data"):
                return None
            return (
                proto_to_tensor(response.smashed_data.activations),
                proto_to_tensor(response.labels.labels),
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def get_dataset_model_info_on(
        self, device_id: str
    ) -> Union[Tuple[int, int, float, float], bool]:
        """
        Requests dataset and model information from the given device.

        Returns:
            Train samples, validation samples, client model FLOPs and server model FLOPs,
            or False if the request failed.
        """
        try:
            response: DatasetModelInfoResponse = self._get_connection(
                device_id
            ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest())
            return (
                response.train_samples,
                response.validation_samples,
                response.client_model_flops,
                response.server_model_flops,
            )
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False

    def backpropagation_on_client_only(self, device_id, gradients):
        """
        Runs the single-batch backward pass on the client of the given device.

        Returns:
            The raw protobuf response of the call, or False if the request failed.
        """
        # NOTE(review): the local code path returns the client's metrics object while this returns the
        # undecoded response message - confirm what callers expect before unifying the two.
        try:
            response: connection_pb2.SingleBatchBackwardResponse = self._get_connection(
                device_id
            ).BackwardPropagationSingleBatchOnClient(
                connection_pb2.SingleBatchBackwardRequest(
                    gradients=Gradients(gradients=tensor_to_proto(gradients))
                )
            )
            return response
        except grpc.RpcError:
            self._handle_rpc_error(device_id)
        except KeyError:
            self._handle_unknown_device_id(device_id)
        return False