from __future__ import annotations

import concurrent.futures
import time
from typing import List, Optional, Tuple, Any, TYPE_CHECKING

import torch
from omegaconf import DictConfig
from colorama import init, Fore
from torch import nn
from torch.autograd import Variable

from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import check_device_set, simulate_latency_decorator
from edml.helpers.executor import create_executor_with_threads
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import (
    create_metrics,
    ModelMetricResultContainer,
    DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn

if TYPE_CHECKING:
    from edml.core.device import Device


class DeviceServer:
    """Split learning server that runs on a device and communicates with clients on (potentially) other devices
    through the provided interface by its device."""

    def __init__(
        self,
        model: nn.Module,
        loss_fn: LossFn,
        cfg: DictConfig,
        latency_factor: float = 0.0,
    ):
        """Initializes the server with the given model, loss function, configuration and reference to its device."""
        self._device = torch.device(get_torch_device_id(cfg))
        self._model = model.to(self._device)
        self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
            cfg, self._model.parameters()
        )
        self._model_flops = 0  # determined lazily once the first batch of smashed data is seen
        self._metrics = create_metrics(
            cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
        )
        self._loss_fn = loss_fn
        self._cfg = cfg
        self.node_device: Optional[Device] = None
        self.latency_factor = latency_factor

    def set_device(self, node_device: Device):
        """Sets the device reference for the server."""
        self.node_device = node_device

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def set_weights(self, state_dict: StateDict):
        """Sets the weights of the server's model"""
        if state_dict is not None:
            self._model.load_state_dict(state_dict=state_dict)

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def get_weights(self):
        """Returns the weights of the server's model"""
        return self._model.state_dict()

    @check_device_set()
    def train(
        self,
        devices: List[str],
        epochs: int = 1,
        round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
    ) -> Tuple[
        Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
    ]:
        """Train the model on the given devices for the given number of epochs.
        Shares the weights among clients and saves the final weights to the configured paths.
        Args:
            devices: The devices to train on
            epochs: Optionally, the number of epochs to train.
            round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
            optimizer_state: Optionally, the optimizer_state to proceed from
        """
        client_weights = None
        metrics = ModelMetricResultContainer()
        diagnostic_metric_container = DiagnosticMetricResultContainer()
        if optimizer_state is not None:
            self._optimizer.load_state_dict(optimizer_state)
        for epoch in range(epochs):
            for device_id in devices:
                print(
                    f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
                )
                if client_weights is not None:
                    self.node_device.set_weights_on(
                        device_id=device_id,
                        state_dict=client_weights,
                        on_client=True,  # we want to set client weights
                    )
                train_epoch_response = self.node_device.train_epoch_on(
                    device_id, self.node_device.device_id, round_no=round_no + epoch
                )
                if (
                    train_epoch_response is not False
                    and train_epoch_response is not None
                ):
                    client_weights, diagnostic_metrics = train_epoch_response
                    train_metrics = self.finalize_metrics(str(device_id), "train")
                    diagnostic_metric_container.merge(diagnostic_metrics)

                    diagnostic_metrics = self.node_device.evaluate_on(
                        device_id, server_device=self.node_device.device_id, val=True
                    )
                    if diagnostic_metrics is not None:
                        diagnostic_metric_container.merge(diagnostic_metrics)
                    val_metrics = self.finalize_metrics(str(device_id), "val")

                    metrics.add_results(train_metrics)
                    metrics.add_results(val_metrics)
            if self._lr_scheduler is not None:
                if round_no != -1:
                    self._lr_scheduler.step(round_no + epoch)
                else:
                    self._lr_scheduler.step()
        return (
            client_weights,
            self.get_weights(),
            metrics,
            self._optimizer.state_dict(),
            diagnostic_metric_container,
        )

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
        """Trains the server model on the given batch of smashed data and labels.
        Returns the gradient of the loss w.r.t. the smashed data (for the client's backward pass) and the batch loss."""
        smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)

        self._set_model_flops(smashed_data)

        self._optimizer.zero_grad()

        # Account for the forward-pass FLOPs of this batch.
        self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
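        # Wrap the smashed data so that gradients w.r.t. the client's cut-layer output are tracked;
        # these gradients are returned to the client to continue its backward pass.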
        smashed_data = Variable(smashed_data, requires_grad=True)
        output_train = self._model(smashed_data)

        loss_train = self._loss_fn(output_train, labels)

        # The backward pass is accounted as roughly twice the forward-pass FLOPs.
        self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
        loss_train.backward()
        self._optimizer.step()

        # Capturing training metrics for the current batch.
        self.node_device.log({"loss": loss_train.item()})
        self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())

        return smashed_data.grad, loss_train.item()

    def _set_model_flops(self, smashed_data):
        """Estimates the server model's FLOPs per sample the first time smashed data become available."""
        if self._model_flops == 0:
            # Normalize by the batch size so the estimate is independent of how many samples arrive per batch.
            self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
                smashed_data
            )

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def finalize_metrics(self, device_id: str, phase: str):
        """Computes the total results of the metrics. Logs the results clears the cached predictions.
        Returns a list of results."""
        metric_result_list = self._metrics.compute_metrics(
            phase=phase, device_id=device_id
        )
        for metric_result in metric_result_list:
            self.node_device.log(metric_result.as_loggable_dict())
        self._metrics.reset_metrics()
        return metric_result_list

    @check_device_set()
    def evaluate_global(
        self, devices: List[str], val: bool
    ) -> Tuple[ModelMetricResultContainer, DiagnosticMetricResultContainer]:
        """Evaluates on the given devices using the own server model. Returns the gathered metrics."""
        result_metrics = ModelMetricResultContainer()
        diagnostic_metric_results = DiagnosticMetricResultContainer()
        for device_id in devices:
            phase = "val" if val else "test"
            print(
                f"Evaluate with {phase} data on client {device_id} with server {self.node_device.device_id}"
            )

            diagnostic_metrics = self.node_device.evaluate_on(
                device_id, server_device=self.node_device.device_id, val=val
            )

            metrics = self.finalize_metrics(str(device_id), f"{phase}")
            result_metrics.add_results(metrics)
            diagnostic_metric_results.merge(diagnostic_metrics)
        return result_metrics, diagnostic_metric_results

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def evaluate_batch(self, smashed_data, labels):
        """Evaluates the model on the given batch of data and labels"""
        with torch.no_grad():
            smashed_data = smashed_data.to(self._device)
            self._set_model_flops(smashed_data)
            self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
            pred = self._model(smashed_data)
        self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def train_parallel_split_learning(
        self,
        clients: List[str],
        round_no: int,
        adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
    ):
        def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
            return self.node_device.train_batch_on_client_only_on(
                device_id=client_id, batch_index=batch_index
            )

        def client_backpropagation_job(client_id: str, gradients: Any):
            return self.node_device.backpropagation_on_client_only_on(
                client_id=client_id, gradients=gradients
            )

        if optimizer_state is not None:
            self._optimizer.load_state_dict(optimizer_state)

        num_threads = len(clients)
        executor = create_executor_with_threads(num_threads)

        model_metrics = ModelMetricResultContainer()
        diagnostic_metrics = DiagnosticMetricResultContainer()

        # We iterate over each batch index, starting the training step on all clients in parallel and processing the
        # results once every client has finished.
        num_batches = self.node_device.client.get_approximated_num_batches()
        print(f"\n\n:: BATCHES :: {num_batches}\n\n")
        for batch_index in range(num_batches):
            batches = []

            futures = [
                executor.submit(client_training_job, client_id, batch_index)
                for client_id in clients
            ]
            for future in concurrent.futures.as_completed(futures):
                batches.append(future.result())

            if _empty_batches(batches):
                # All clients returned None, i.e., they have run out of data; stop the batch loop.
                break

            print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
            server_batch = _concat_smashed_data(
                [b[0].to(self._device) for b in batches]
            )
            server_labels = _concat_smashed_data(
                [b[1].to(self._device) for b in batches]
            )

            # Train the server part of the model on the concatenated batch. Then send the gradients back to each
            # client to continue the backward pass. The gradients are split back into per-client, batch-sized chunks
            # and averaged before being sent to the clients.
            server_gradients, server_loss, server_metrics = (
                self.node_device.train_batch(server_batch, server_labels)
            )  # DiagnosticMetricResultContainer

            # If an adaptive learning threshold is configured, only propagate gradients back to the clients when the
            # current server loss is at or above the threshold; otherwise skip straight to the next batch.
            print(
                f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
            )
            if (
                adaptive_learning_threshold
                and server_loss < adaptive_learning_threshold
            ):
                print(
                    f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, SKIPPING TO NEXT BATCH\n{Fore.RESET}"
                )
                self.node_device.log({"adaptive_learning_threshold_applied": True})
                continue

            num_client_gradients = len(batches)
            print(
                f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
            )

            client_gradients = torch.chunk(server_gradients, num_client_gradients)
            print(f"::: result shape: {client_gradients[1].shape}")
            concatenated_client_gradients = torch.stack(client_gradients, dim=0)
            mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
            print(f"::: -> {mean_tensor.shape}")

            futures = [
                executor.submit(client_backpropagation_job, client_id, mean_tensor)
                for client_id in clients
            ]
            diagnostic_results = []
            for future in concurrent.futures.as_completed(futures):
                diagnostic_results.append(future.result())
            # Has to be done outside the loop due to thread-safety.
            # for diagnostic_result in diagnostic_results:
            #     diagnostic_metrics.merge(diagnostic_result)


        # Now we have to determine the model metrics for each client.
        for client_id in clients:
            train_metrics = self.finalize_metrics(str(client_id), "train")

            print(f"::: evaluating on {client_id}")
            evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                device_id=client_id,
                server_device=self.node_device.device_id,
                val=True,
            )
            # if evaluation_diagnostics_metrics:
            #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
            val_metrics = self.finalize_metrics(str(client_id), "val")

            model_metrics.add_results(train_metrics)
            model_metrics.add_results(val_metrics)

        optimizer_state = self._optimizer.state_dict()
        if self._lr_scheduler is not None:
            if round_no != -1:
                self._lr_scheduler.step(round_no + 1)  # epoch=1
            else:
                self._lr_scheduler.step()
        # delete references and free GPU memory manually
        server_batch = None
        server_labels = None
        server_gradients = None
        client_gradients = None
        concatenated_client_gradients = None
        mean_tensor = None
        torch.cuda.empty_cache()
        torch.cuda.set_device(self._device)
        return (
            self.node_device.client.get_weights(),
            self.get_weights(),
            model_metrics,
            optimizer_state,
            diagnostic_metrics,
        )


def _concat_smashed_data(data: List[Any]) -> Any:
    """Creates a single batch tensor from a list of tensors."""
    return torch.cat(data, dim=0)


def _empty_batches(batches):
    """Checks if all the list entries are `None`."""
    return batches.count(None) == len(batches)
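

# Usage sketch (illustrative only): a minimal, hypothetical example of how a DeviceServer
# might be wired up. The `server_model`, `cfg`, and `node_device` objects are assumptions
# here and are provided by the surrounding edml framework, not by this module.
#
#   server = DeviceServer(model=server_model, loss_fn=nn.CrossEntropyLoss(), cfg=cfg)
#   server.set_device(node_device)  # Device that provides the communication interface
#   client_weights, server_weights, metrics, optimizer_state, diagnostics = server.train(
#       devices=["client-0", "client-1"], epochs=1, round_no=0
#   )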