from __future__ import annotations

import concurrent.futures
from typing import List, Optional, Tuple, Any, TYPE_CHECKING

import torch
from omegaconf import DictConfig
from colorama import Fore
from torch import nn
from torch.autograd import Variable

from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import check_device_set, simulate_latency_decorator, Timer
from edml.helpers.executor import create_executor_with_threads
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import (
    create_metrics,
    ModelMetricResultContainer,
    DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, LossFn

if TYPE_CHECKING:
    from edml.core.device import Device


class DeviceServer:
    """Split learning server that runs on a device and communicates with clients on (potentially) other devices
    through the provided interface by its device."""
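
    # A minimal usage sketch (illustrative only; the model, loss_fn, cfg, device and client
    # IDs below are assumptions, not values defined in this module):
    #
    #     server = DeviceServer(model=server_model, loss_fn=loss_fn, cfg=cfg)
    #     server.set_device(device)  # must be set before training or evaluation
    #     client_w, server_w, metrics, opt_state, diag = server.train(
    #         devices=["d0", "d1"], epochs=1
    #     )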

    def __init__(
        self,
        model: nn.Module,
        loss_fn: LossFn,
        cfg: DictConfig,
        latency_factor: float = 0.0,
    ):
        """Initializes the server with the given model, loss function, configuration and reference to its device."""
        self._device = torch.device(get_torch_device_id(cfg))
        self._model = model.to(self._device)
        self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
            cfg, self._model.parameters()
        )
        self._model_flops = 0  # estimated lazily in _set_model_flops
        self._metrics = create_metrics(
            cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
        )
        self._loss_fn = loss_fn
        self._cfg = cfg
        self.node_device: Optional[Device] = None
        self.latency_factor = latency_factor

    def set_device(self, node_device: Device):
        """Sets the device reference for the server."""
        self.node_device = node_device

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def set_weights(self, state_dict: StateDict):
        """Sets the weights of the server's model"""
        if state_dict is not None:
            self._model.load_state_dict(state_dict=state_dict)

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def get_weights(self):
        """Returns the weights of the server's model"""
        return self._model.state_dict()

    @check_device_set()
    def train(
        self,
        devices: List[str],
        epochs: int = 1,
        round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
    ) -> Tuple[
        Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
    ]:
        """Train the model on the given devices for the given number of epochs.
        Shares the weights among clients and saves the final weights to the configured paths.
        Args:
            devices: The devices to train on
            epochs: Optionally, the number of epochs to train.
            round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
            optimizer_state: Optionally, the optimizer_state to proceed from
        """
        client_weights = None
        metrics = ModelMetricResultContainer()
        diagnostic_metric_container = DiagnosticMetricResultContainer()
        if optimizer_state is not None:
            self._optimizer.load_state_dict(optimizer_state)
        for epoch in range(epochs):
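            # Step the LR scheduler: with a global round number, step to the absolute
            # epoch index (round_no + epoch); otherwise advance it by a single step.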
            if self._lr_scheduler is not None:
                if round_no != -1:
                    self._lr_scheduler.step(round_no + epoch)
                else:
                    self._lr_scheduler.step()
            for device_id in devices:
                print(
                    f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
                )
                if client_weights is not None:
                    self.node_device.set_weights_on(
                        device_id=device_id,
                        state_dict=client_weights,
                        on_client=True,  # we want to set client weights
                    )
                train_epoch_response = self.node_device.train_epoch_on(
                    device_id, self.node_device.device_id, round_no=round_no + epoch
                )
                if (
                    train_epoch_response is not False
                    and train_epoch_response is not None
                ):
                    client_weights, diagnostic_metrics = train_epoch_response
                    train_metrics = self.finalize_metrics(str(device_id), "train")
                    diagnostic_metric_container.merge(diagnostic_metrics)

                    diagnostic_metrics = self.node_device.evaluate_on(
                        device_id, server_device=self.node_device.device_id, val=True
                    )
                    if diagnostic_metrics is not None:
                        diagnostic_metric_container.merge(diagnostic_metrics)
                    val_metrics = self.finalize_metrics(str(device_id), "val")

                    metrics.add_results(train_metrics)
                    metrics.add_results(val_metrics)
        return (
            client_weights,
            self.get_weights(),
            metrics,
            self._optimizer.state_dict(),
            diagnostic_metric_container,
        )

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
        """Train the model on the given batch of data and labels.
        Returns the gradients of the model's parameters."""
        smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)

        self._set_model_flops(smashed_data)

        self._optimizer.zero_grad()

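        # Account the forward pass against the battery: per-sample FLOPs (estimated above) times the batch size.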
        self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
        smashed_data = Variable(smashed_data, requires_grad=True)
        output_train = self._model(smashed_data)

        loss_train = self._loss_fn(output_train, labels)

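        # The backward pass is accounted for as roughly twice the forward-pass FLOPs.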
        self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
        loss_train.backward()
        self._optimizer.step()

        # Capturing training metrics for the current batch.
        self.node_device.log({"loss": loss_train.item()})
        self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())

        return smashed_data.grad, loss_train.item()

    def _set_model_flops(self, smashed_data):
        """Helper to determine the model flops when smashed data are available for the first time."""
        if self._model_flops == 0:
            self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
                smashed_data
            )

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def finalize_metrics(self, device_id: str, phase: str):
        """Computes the total results of the metrics. Logs the results clears the cached predictions.
        Returns a list of results."""
        metric_result_list = self._metrics.compute_metrics(
            phase=phase, device_id=device_id
        )
        for metric_result in metric_result_list:
            self.node_device.log(metric_result.as_loggable_dict())
        self._metrics.reset_metrics()
        return metric_result_list

    @check_device_set()
    def evaluate_global(
        self, devices: List[str], val: bool
    ) -> Tuple[ModelMetricResultContainer, DiagnosticMetricResultContainer]:
        """Evaluates on the given devices using the own server model. Returns the gathered metrics."""
        result_metrics = ModelMetricResultContainer()
        diagnostic_metric_results = DiagnosticMetricResultContainer()
        for device_id in devices:
            phase = "val" if val else "test"
            print(
                f"Evaluate with {phase} data on client {device_id} with server {self.node_device.device_id}"
            )

            diagnostic_metrics = self.node_device.evaluate_on(
                device_id, server_device=self.node_device.device_id, val=val
            )

            metrics = self.finalize_metrics(str(device_id), phase)
            result_metrics.add_results(metrics)
            if diagnostic_metrics is not None:
                diagnostic_metric_results.merge(diagnostic_metrics)
        return result_metrics, diagnostic_metric_results

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def evaluate_batch(self, smashed_data, labels):
        """Evaluates the model on the given batch of data and labels"""
        with torch.no_grad():
            smashed_data = smashed_data.to(self._device)
            self._set_model_flops(smashed_data)
            self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
            pred = self._model(smashed_data)
        self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def train_parallel_split_learning(
        self,
        clients: List[str],
        round_no: int,
        adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
    ):
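        """Runs one round of parallel split learning with the given clients.

        For every batch, all clients perform their forward pass (in parallel, or sequentially
        timed when parallelism is simulated), the server trains on the concatenated smashed
        data, and the resulting gradients are chunked per client and sent back so the clients
        can finish their backward pass. If `adaptive_learning_threshold` is set, batches whose
        server loss falls below the threshold skip the client backpropagation.

        Returns a tuple of (client weights, server weights, model metrics, optimizer state,
        diagnostic metrics).
        """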
        def client_training_job(client_id: str, batch_index: int):
            result = self.node_device.train_batch_on_client_only_on(
                device_id=client_id,
                batch_index=batch_index,
                round_no=round_no,
                # round_no is taken from outer method arg
            )
            return (client_id, result)

        def client_backpropagation_job(client_id: str, gradients: Any):
            return self.node_device.backpropagation_on_client_only_on(
                client_id=client_id, gradients=gradients
            )

        def client_set_gradient_and_finalize_training_job(
            client_id: str, gradients: Any
        ):
            return (
                self.node_device.set_gradient_and_finalize_training_on_client_only_on(
                    client_id=client_id, gradients=gradients
                )
            )

        if optimizer_state is not None:
            self._optimizer.load_state_dict(optimizer_state)

        if self._lr_scheduler is not None:
            if round_no != -1:
                self._lr_scheduler.step(round_no + 1)  # epoch=1
            else:
                self._lr_scheduler.step()

        num_threads = len(clients)
        executor = create_executor_with_threads(num_threads)
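        # One worker thread per client so that all client-side jobs for a batch can run concurrently.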

        # batches = []
        model_metrics = ModelMetricResultContainer()
        diagnostic_metrics = DiagnosticMetricResultContainer()

        # We iterate over each batch, initializing all client training at once and processing the results afterward.
        num_batches = self.node_device.client.get_approximated_num_batches()
        print(f":: BATCHES :: {num_batches}")

        if self._cfg.simulate_parallelism:
            for batch_index in range(num_batches):
                client_forward_pass_responses = []
                parallel_times = []
                with Timer() as elapsed_time:
                    for client_id in clients:
                        with Timer() as individual_time:
                            (client_id, result) = client_training_job(
                                client_id, batch_index
                            )
                            if result is not None and result is not False:
                                client_forward_pass_responses.append(
                                    (client_id, result)
                                )
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_train_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
                # We want to split up the responses into a list of client IDs and batches again.
                client_ids = [b[0] for b in client_forward_pass_responses]
                client_batches = [b[1] for b in client_forward_pass_responses]

                server_batch = _concat_smashed_data(
                    [b[0].to(self._device) for b in client_batches]
                )
                server_labels = _concat_smashed_data(
                    [b[1].to(self._device) for b in client_batches]
                )
                # Train the server-side part on the concatenated batch, then send gradients back
                # to each client to continue the backward pass. The server gradients are split
                # into batch-sized chunks again, one per client.
                server_gradients, server_loss, server_metrics = (
                    self.node_device.train_batch(server_batch, server_labels)
                )  # DiagnosticMetricResultContainer
                # If an adaptive learning threshold is configured and the current server loss
                # falls below it, we skip the client backpropagation for this batch.
                print(
                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
                )
                if (
                    adaptive_learning_threshold
                    and server_loss < adaptive_learning_threshold
                ):
                    print(
                        f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                    )
                    self.node_device.log(
                        {
                            "adaptive_learning_threshold_applied": server_gradients.size(
                                0
                            )
                        }
                    )
                    continue

                num_client_gradients = len(client_forward_pass_responses)
                print(
                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                )
                # clone single client gradients so that client_gradients is not a list of views of server_gradients
                # if we just use torch.chunk, each client will receive the whole server_gradients
                client_gradients = [
                    t.clone().detach()
                    for t in torch.chunk(server_gradients, num_client_gradients)
                ]
                client_backpropagation_gradients = []
                parallel_times = []
                with Timer() as elapsed_time:
                    for idx, client_id in enumerate(client_ids):
                        with Timer() as individual_time:
                            _, grads = client_backpropagation_job(
                                client_id, client_gradients[idx]
                            )
                            if grads is not None and grads is not False:
                                client_backpropagation_gradients.append(grads)
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_backprop_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
                # We average the clients' backpropagation gradients and send them over again to
                # finalize the current training step.
                averaged_gradient = _calculate_gradient_mean(
                    client_backpropagation_gradients, self._device
                )
                parallel_times = []
                with Timer() as elapsed_time:
                    for client_id in clients:
                        with Timer() as individual_time:
                            client_set_gradient_and_finalize_training_job(
                                client_id, averaged_gradient
                            )
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_model_update_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
        else:
            for batch_index in range(num_batches):
                client_forward_pass_responses = []
                futures = [
                    executor.submit(client_training_job, client_id, batch_index)
                    for client_id in clients
                ]
                for future in concurrent.futures.as_completed(futures):
                    (client_id, result) = future.result()
                    if result is not None and result is not False:
                        client_forward_pass_responses.append((client_id, result))

                # We want to split up the responses into a list of client IDs and batches again.
                client_ids = [b[0] for b in client_forward_pass_responses]
                client_batches = [b[1] for b in client_forward_pass_responses]

                server_batch = _concat_smashed_data(
                    [b[0].to(self._device) for b in client_batches]
                )
                server_labels = _concat_smashed_data(
                    [b[1].to(self._device) for b in client_batches]
                )
                # Train the server-side part on the concatenated batch, then send gradients back
                # to each client to continue the backward pass. The server gradients are split
                # into batch-sized chunks again, one per client.
                server_gradients, server_loss, server_metrics = (
                    self.node_device.train_batch(server_batch, server_labels)
                )  # DiagnosticMetricResultContainer
                # If an adaptive learning threshold is configured and the current server loss
                # falls below it, we skip the client backpropagation for this batch.
                print(
                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
                )
                if (
                    adaptive_learning_threshold
                    and server_loss < adaptive_learning_threshold
                ):
                    print(
                        f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                    )
                    self.node_device.log(
                        {
                            "adaptive_learning_threshold_applied": server_gradients.size(
                                0
                            )
                        }
                    )
                    continue

                num_client_gradients = len(client_forward_pass_responses)
                print(
                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                )
                # clone single client gradients so that client_gradients is not a list of views of server_gradients
                # if we just use torch.chunk, each client will receive the whole server_gradients
                client_gradients = [
                    t.clone().detach()
                    for t in torch.chunk(server_gradients, num_client_gradients)
                ]
                futures = [
                    executor.submit(
                        client_backpropagation_job, client_id, client_gradients[idx]
                    )
                    for (idx, client_id) in enumerate(client_ids)
                ]
                client_backpropagation_gradients = []
                for future in concurrent.futures.as_completed(futures):
                    _, grads = future.result()
                    if grads is not None and grads is not False:
                        client_backpropagation_gradients.append(grads)
                # We average the clients' backpropagation gradients and send them over again to
                # finalize the current training step.
                averaged_gradient = _calculate_gradient_mean(
                    client_backpropagation_gradients, self._device
                )
                futures = [
                    executor.submit(
                        client_set_gradient_and_finalize_training_job,
                        client_id,
                        averaged_gradient,
                    )
                    for client_id in clients
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()

        # Now we have to determine the model metrics for each client.

        if self._cfg.simulate_parallelism:
            parallel_times = []
            with Timer() as elapsed_time:
                for client_id in clients:
                    with Timer() as individual_time:
                        train_metrics = self.finalize_metrics(str(client_id), "train")

                        evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                            device_id=client_id,
                            server_device=self.node_device.device_id,
                            val=True,
                        )
                        # if evaluation_diagnostics_metrics:
                        #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
                        val_metrics = self.finalize_metrics(str(client_id), "val")

                        model_metrics.add_results(train_metrics)
                        model_metrics.add_results(val_metrics)
                    parallel_times.append(individual_time.execution_time)
            self.node_device.log(
                {
                    "parallel_client_eval_time": {
                        "elapsed_time": elapsed_time.execution_time,
                        "parallel_time": max(parallel_times),
                    }
                }
            )
        else:
            for client_id in clients:
                train_metrics = self.finalize_metrics(str(client_id), "train")

                evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                    device_id=client_id,
                    server_device=self.node_device.device_id,
                    val=True,
                )
                # if evaluation_diagnostics_metrics:
                #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
                val_metrics = self.finalize_metrics(str(client_id), "val")

                model_metrics.add_results(train_metrics)
                model_metrics.add_results(val_metrics)

        optimizer_state = self._optimizer.state_dict()
        # Drop references to the large intermediate tensors and free GPU memory manually.
        server_batch = None
        server_labels = None
        server_gradients = None
        client_gradients = None
        if self._device.type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.set_device(self._device)
        return (
            self.node_device.client.get_weights(),
            self.get_weights(),
            model_metrics,
            optimizer_state,
            diagnostic_metrics,
        )


def _calculate_gradient_mean(
    gradients: List[List[Variable]], device: str = "cpu"
) -> List[Variable]:
    """Computes the per-tensor mean of the clients' gradients with equal weights.

    Each entry of `gradients` is one client's list of gradient tensors; the result is a
    single list of averaged tensors on the given device."""
    num_devices = len(gradients)
    weights = [1 / num_devices] * num_devices

    # We need to move all tensors to the same device to do calculations.
    for i, client_gradients in enumerate(gradients):
        for j, grad in enumerate(client_gradients):
            gradients[i][j] = grad.to(device)

    return [
        sum(gradients[i][j] * weights[i] for i in range(num_devices))
        for j in range(len(gradients[0]))
    ]
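
# Minimal illustration (assumed inputs, not part of the production call path): given two
# clients that each return a list of per-tensor gradients, the helper yields the
# element-wise mean per tensor:
#
#     g_a = [torch.ones(2, 2), torch.zeros(3)]
#     g_b = [torch.zeros(2, 2), torch.ones(3)]
#     _calculate_gradient_mean([g_a, g_b])  # -> [2x2 tensor of 0.5s, length-3 tensor of 0.5s]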


def _concat_smashed_data(data: List[Any]) -> Any:
    """Creates a single batch tensor from a list of tensors."""
    return torch.cat(data, dim=0)