diff --git a/edml/core/client.py b/edml/core/client.py
index d7ff6723297bc45a9f2b5fd3e3bb002bb3de0111..fe41e2cdb3c0be6847b37df8ae3c7cae8c4eba7d 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -135,7 +135,9 @@ class DeviceClient:
         return len(self._train_data.dataset)
 
     @check_device_set()
-    def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
+    def train_single_batch(
+        self, batch_index: int
+    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
         torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
@@ -180,11 +182,7 @@ class DeviceClient:
             "start_time": start_time,
             "end_time": end_time,
         }
-
-        return SLTrainBatchResult(
-            smashed_data=smashed_data,
-            labels=batch_labels,
-        )
+        return smashed_data, batch_labels
 
     @check_device_set()
     def backward_single_batch(
@@ -378,7 +376,7 @@ class DeviceClient:
         )
         return diagnostic_metric_results
 
-    def set_gradient_and_finalize_training(self, gradients: Any):
+    def set_gradient_and_finalize_training(self, gradients: Any):
         for param, grad in zip(self._model.parameters(), gradients):
             param.grad = grad.to(self._device)
diff --git a/edml/core/device.py b/edml/core/device.py
index 2ba6833226b537118f8c02399a0c6c4a4967e47f..51278532bfb7f1888f13671a24a289fc10e868b2 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -266,10 +266,8 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "client_only_batch_train")
     def train_batch_on_client_only(self, batch_index: int):
-        result = self.client.train_single_batch(batch_index=batch_index)
-        if result is None:
-            return result
-        return result.smashed_data, result.labels
+        smashed_data, labels = self.client.train_single_batch(batch_index=batch_index)
+        return smashed_data, labels
 
     @update_battery
     def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
@@ -579,14 +577,10 @@
         batch_index = request.batch_index
         print(f"Starting single batch@{batch_index}")
 
-        result = self.device.client.train_single_batch(batch_index)
+        smashed_data, labels = self.device.client.train_single_batch(batch_index)
 
-        # If the last batch's size was smaller than the configured batch size, the train call returns None
-        if result is None:
-            return connection_pb2.SingleBatchTrainingResponse()
-
-        smashed_data = Activations(activations=tensor_to_proto(result.smashed_data))
-        labels = Labels(labels=tensor_to_proto(result.labels))
+        smashed_data = Activations(activations=tensor_to_proto(smashed_data))
+        labels = Labels(labels=tensor_to_proto(labels))
         return connection_pb2.SingleBatchTrainingResponse(
             smashed_data=smashed_data,
             labels=labels,
@@ -597,14 +591,16 @@
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
 
-        metrics, gradients = self.device.client.backward_single_batch(gradients=gradients)
+        metrics, gradients = self.device.client.backward_single_batch(
+            gradients=gradients
+        )
         return connection_pb2.SingleBatchBackwardResponse(
             metrics=metrics_to_proto(metrics),
             gradients=Gradients(gradients=tensor_to_proto(gradients)),
         )
-
+
     def SetGradientsAndFinalizeTrainingStep(
-        self, request: SetGradientsRequest, context
+        self, request: SetGradientsRequest, context
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
         self.device.client.set_gradient_and_finalize_training(gradients=gradients)
diff --git a/edml/core/server.py b/edml/core/server.py
index b78793678f6f77bc84354c93f9d32064eee657c1..8853bd92bc3d5dc438e7ada06cb4dce5e62967c5 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -260,19 +260,15 @@ class DeviceServer:
                 for client_id in clients
             ]
             for future in concurrent.futures.as_completed(futures):
-                client_forward_pass_responses.append(future.result())
+                (client_id, result) = future.result()
+                if result is not None and result is not False:
+                    client_forward_pass_responses.append((client_id, result))
 
             # We want to split up the responses into a list of client IDs and batches again.
             client_ids = [b[0] for b in client_forward_pass_responses]
             client_batches = [b[1] for b in client_forward_pass_responses]
 
-            if _empty_batches(client_batches):
-                # Only the last batch anyway.
-                break
-            print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
-            # batches2 = [b for b in batches if b is not None]
-            # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
             server_batch = _concat_smashed_data(
                 [b[0].to(self._device) for b in client_batches]
             )
@@ -307,10 +303,6 @@ class DeviceServer:
             )
             client_gradients = torch.chunk(server_gradients, num_client_gradients)
 
-            # print(f"::: result shape: {client_gradients[1].shape}")
-            # concatenated_client_gradients = torch.stack(client_gradients, dim=0)
-            # mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
-            # print(f"::: -> {mean_tensor.shape}")
             futures = [
                 executor.submit(
@@ -323,13 +315,15 @@ class DeviceServer:
                 client_backpropagation_results.append(future.result())
 
             client_backpropagation_gradients = [
-                result[1] for result in client_backpropagation_results
+                result[1]
+                for result in client_backpropagation_results
+                if result is not None and result is not False
             ]
             # We want to average the client's backpropagation gradients and send them over again to finalize the
             # current training step.
             averaged_gradient = _calculate_gradient_mean(
-                client_backpropagation_gradients
+                client_backpropagation_gradients, self._device
            )
             futures = [
                 executor.submit(
@@ -383,9 +377,11 @@ class DeviceServer:
     )
 
 
-def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") -> Variable:
+def _calculate_gradient_mean(
+    gradients: List[Variable], device: str = "cpu"
+) -> Variable:
     num_devices = len(gradients)
-    weights = [1] * num_devices
+    weights = [1 / num_devices] * num_devices
 
     # We need to move all tensors to the same device to do calculations.
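+    # With equal weights of 1 / num_devices, the weighted sum computed below is the
+    # arithmetic mean of the per-client gradients.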
     for i, client_gradients in enumerate(gradients):
@@ -401,8 +397,3 @@ def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") ->
 def _concat_smashed_data(data: List[Any]) -> Any:
     """Creates a single batch tensor from a list of tensors."""
     return torch.cat(data, dim=0)
-
-
-def _empty_batches(batches):
-    """Checks if all the list entries are `None`."""
-    return batches.count(None) == len(batches)
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1331f8aa7bad195eb02044bcf25d381fe2ee2a13
--- /dev/null
+++ b/edml/tests/core/server_test.py
@@ -0,0 +1,192 @@
+import unittest
+from collections import OrderedDict
+from copy import copy
+from typing import Tuple, Any
+from unittest.mock import Mock
+
+import torch.utils.data
+from omegaconf import DictConfig
+from torch import nn, tensor
+from torch.utils.data import DataLoader, Dataset, TensorDataset
+
+from edml.core.battery import Battery
+from edml.core.client import DeviceClient
+from edml.core.device import Device
+from edml.core.server import DeviceServer
+
+
+class ClientModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self.output = nn.ReLU()
+
+    def forward(self, x):
+        return self.output(self.layer(x))
+
+
+class ServerModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 2)
+        # self.layer.weight = tensor([[-0.6857]])
+        self.output = nn.Softmax(dim=1)
+
+    def forward(self, x):
+        return self.output(self.layer(x))
+
+
+class ToyDataset(Dataset):
+    def __init__(self, data: list, labels: list):
+        self.length = len(data)
+        self.data = torch.Tensor(data)
+        self.labels = torch.tensor(labels, dtype=torch.long)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        return self.data[index], self.labels[index]
+
+    def __len__(self):
+        return self.length
+
+
+class PSLTest(unittest.TestCase):
+    def setUp(self):
+        cfg = DictConfig(
+            {
+                "optimizer": {
+                    "_target_": "torch.optim.SGD",
+                    "lr": 1,
+                    "momentum": 0,
+                    "weight_decay": 0,
+                },
+                "experiment": {"metrics": ["accuracy"]},
+                "dataset": {"num_classes": 2, "average_setting": "micro"},
+                "topology": {
+                    "devices": [
+                        {
+                            "device_id": "d0",
+                            "address": "localhost:50051",
+                            "torch_device": "cuda:0",
+                        },
+                        {
+                            "device_id": "d1",
+                            "address": "localhost:50052",
+                            "torch_device": "cuda:0",
+                        },
+                    ]
+                },
+                "own_device_id": "d0",
+            }
+        )
+        # Initialize the models with fixed weights for repeatability.
+        server_state_dict = OrderedDict(
+            [
+                ("layer.weight", tensor([[-0.5], [-1.0]])),
+                ("layer.bias", tensor([-0.5, 0.25])),
+            ]
+        )
+        client_state_dict = OrderedDict(
+            [("layer.weight", tensor([[-1.0]])), ("layer.bias", tensor([0.5]))]
+        )
+        server_model = ServerModel()
+        server_model.load_state_dict(server_state_dict)
+        client_model1 = ClientModel()
+        client_model1.load_state_dict(client_state_dict)
+        client_model2 = ClientModel()
+        client_model2.load_state_dict(client_state_dict)
+        self.server = DeviceServer(
+            model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
+        )
+        self.client1 = DeviceClient(
+            model=client_model1,
+            cfg=cfg,
+            train_dl=DataLoader(TensorDataset(tensor([[0.9]]), tensor([[0.0, 1.0]]))),
+            val_dl=DataLoader(TensorDataset(tensor([[0.75]]), tensor([[0.0, 1.0]]))),
+            test_dl=None,
+        )
+        client2_cfg = cfg.copy()
+        client2_cfg["own_device_id"] = "d1"
+        self.client2 = DeviceClient(
+            model=client_model2,
+            cfg=client2_cfg,
+            train_dl=DataLoader(TensorDataset(tensor([[0.1]]), tensor([[1.0, 0.0]]))),
+            val_dl=DataLoader(TensorDataset(tensor([[0.25]]), tensor([[1.0, 0.0]]))),
+            test_dl=None,
+        )
+
+        def get_client_side_effect(fn):
+            """
+            Creates side effects for methods of the form METHOD_on(device_id), bypassing the request dispatcher.
+            """
+
+            def side_effect(*args, **kwargs):
+                # Get the device id, which is passed either positionally or as a keyword arg,
+                # and strip it so it is not forwarded to the client method.
+                if len(args) > 0:
+                    device_id, *args = args
+                elif "client_id" in kwargs:
+                    device_id = kwargs.pop("client_id")
+                elif "device_id" in kwargs:
+                    device_id = kwargs.pop("device_id")
+                else:
+                    raise KeyError(
+                        f"Could not find device_id in args or kwargs for function {fn}"
+                    )
+                # Delegate to the correct client using the given method name.
+                if device_id == "d0":
+                    return self.client1.__class__.__dict__[fn](
+                        self.client1, *args, **kwargs
+                    )
+                elif device_id == "d1":
+                    return self.client2.__class__.__dict__[fn](
+                        self.client2, *args, **kwargs
+                    )
+                else:
+                    raise KeyError(f"Unknown device_id {device_id}")
+
+            return side_effect
+
+        def get_server_side_effect(fn):
+            def side_effect(*args, **kwargs):
+                # Append (1,) as a placeholder for the DiagnosticMetricsContainer.
+                return self.server.__class__.__dict__[fn](
+                    self.server, *args, **kwargs
+                ) + (1,)
+
+            return side_effect
+
+        node_device = Mock(Device)
+        node_device.battery = Mock(Battery)
+        node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect(
+            "train_single_batch"
+        )
+        node_device.backpropagation_on_client_only_on.side_effect = (
+            get_client_side_effect("backward_single_batch")
+        )
+        node_device.set_gradient_and_finalize_training_on_client_only_on.side_effect = (
+            get_client_side_effect("set_gradient_and_finalize_training")
+        )
+        node_device.train_batch.side_effect = get_server_side_effect("train_batch")
+        node_device.evaluate_on.side_effect = get_client_side_effect("evaluate")
+        node_device.evaluate_batch.side_effect = get_server_side_effect(
+            "evaluate_batch"
+        )
+
+        self.node_device1 = copy(node_device)
+        self.node_device2 = copy(node_device)
+        self.node_device1.client = self.client1
+        self.node_device1.device_id = "d0"
+        self.node_device2.device_id = "d1"
+        self.client1.set_device(self.node_device1)
+        self.client2.set_device(self.node_device2)
+        self.server.set_device(self.node_device1)
+
+    def test_train_parallel_sl(self):
+        (
+            client_weights,
+            server_weights,
+            model_metrics,
+            optimizer_state,
+            diagnostic_metrics,
+        ) = self.server.train_parallel_split_learning(["d0", "d1"], round_no=0)
+        self.assertDictEqual(self.client1.get_weights(), self.client2.get_weights())
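For reference, a minimal standalone sketch (not part of the patch) of the averaging that
_calculate_gradient_mean now performs: with equal weights of 1 / num_devices, the weighted
sum of the per-client gradients reduces to their element-wise mean. The list-of-per-parameter
gradient layout and the .to(device) moves are assumptions inferred from the hunks above.

    import torch

    # Two clients, each reporting gradients for a single model parameter.
    client_gradients = [
        [torch.tensor([2.0, 4.0])],  # gradients from client d0
        [torch.tensor([4.0, 8.0])],  # gradients from client d1
    ]
    num_devices = len(client_gradients)
    weights = [1 / num_devices] * num_devices  # equal weights, 0.5 each

    # Weighted sum per parameter, with every tensor moved to one device ("cpu" here).
    averaged = [
        sum(w * grads[i].to("cpu") for w, grads in zip(weights, client_gradients))
        for i in range(len(client_gradients[0]))
    ]
    print(averaged)  # [tensor([3., 6.])], the element-wise mean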