Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
client.py 15.36 KiB
from __future__ import annotations

import itertools
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import torch
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader

from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import (
    check_device_set,
    simulate_latency_decorator,
    LatencySimulator,
)
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
from edml.helpers.types import StateDict, SLTrainBatchResult

if TYPE_CHECKING:
    from edml.core.device import Device


class DeviceClient:
    """
    A client in the context of split learning, i.e., a device that trains the first `n` layers of a model locally.
    The remaining layers are trained on the server device.

    The client runs on a device and communicates with servers on (potentially) other devices through the interface
    provided by its device (`node_device`).

    Attributes:
        node_device (Optional[Device]): The device this client is part of. `None` until `set_device` is called.
        latency_factor (float): Factor forwarded to the latency-simulation helpers (`LatencySimulator` and
            `simulate_latency_decorator`).

    See:
        - :py:class:`~edml.core.server.DeviceServer`
        - :py:class:`~edml.controllers.split_learning.SplitController`
        - :py:class:`~edml.controllers.swarm_learning.SwarmController`
    """

    def __init__(
        self,
        model: nn.Module,
        cfg: DictConfig,
        train_dl: DataLoader,
        val_dl: DataLoader,
        test_dl: DataLoader,
        latency_factor: float = 0.0,
    ):
        """
        Initializes the split learning client with its (partial) model and training, validation and test data.

        Args:
            model (nn.Module): The pytorch neural network trained by the client.
            cfg (DictConfig): The experiment's configuration.
            train_dl (DataLoader): The data loader responsible for loading training data.
            val_dl (DataLoader): The data loader responsible for loading validation data.
            test_dl (DataLoader): The data loader responsible for loading testing data.
            latency_factor (float): Factor used to simulate runtime latency; it is passed to `LatencySimulator` and
                read by `simulate_latency_decorator`. Defaults to `0.0`.

        Notes:
            The model is moved to the torch device returned by `get_torch_device_id(cfg)` (intended to be the GPU if
            CUDA is available, otherwise the CPU).

            The input of the first training sample is run through `estimate_model_flops` to estimate the model's
            FLOPs per sample, which is later fed to `node_device.battery.update_flops`. The training data set must
            therefore not be empty, and its items are expected to be `(input, label)` pairs.
        """
        self._train_data, self._val_data, self._test_data = train_dl, val_dl, test_dl
        # Iterator over the training data used by `train_single_batch`; (re-)created for batch index 0.
        self._batchable_data_loader = None
        self._device = torch.device(get_torch_device_id(cfg))
        self._model = model.to(self._device)
        self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
            cfg, self._model.parameters()
        )
        # Use the first training sample (with an added batch dimension) to estimate the model's FLOPs per sample.
        sample = self._train_data.dataset[0][0]
        if not isinstance(sample, torch.Tensor):
            sample = torch.tensor(data=sample)
        self._model_flops = estimate_model_flops(
            self._model, sample.to(self._device).unsqueeze(0)
        )
        self._cfg = cfg
        self.node_device: Optional[Device] = None
        self.latency_factor = latency_factor

        # State shared between the forward pass (`train_single_batch`) and the backward pass
        # (`backward_single_batch`) of parallel split learning; cleared by `set_gradient_and_finalize_training`.
        self._psl_cache = None

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def set_device(self, node_device: Device):
        """
        Sets the device that this client-side is part of.

        Args:
            node_device (Device): The device hosting this client.

        Notes:
            If a latency factor is specified, this function sleeps for said amount before returning.
        """
        self.node_device = node_device

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def set_weights(self, state_dict: StateDict):
        """
        Updates the model's weights with the one specified.

        Args:
            state_dict (StateDict): The model's parameters and buffers. `None` leaves the model unchanged.

        Notes:
            If a latency factor is specified, this function sleeps for said amount before returning.
        """
        if state_dict is not None:
            self._model.load_state_dict(state_dict=state_dict)

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def get_weights(self) -> StateDict:
        """
        Returns the model's parameters and buffers, including its weights.

        Returns:
            StateDict: The client model's `state_dict()`.

        Notes:
            If a latency factor is specified, this function sleeps for said amount before returning.
        """
        return self._model.state_dict()

    @simulate_latency_decorator(latency_factor_attr="latency_factor")
    def get_num_samples(self) -> int:
        """
        Returns the number of samples in the client's training data.

        Returns:
            int: The length of the training data set.

        Notes:
            If a latency factor is specified, this function sleeps for said amount before returning.
        """
        return len(self._train_data.dataset)

    def _set_cuda_device(self) -> None:
        """Makes `self._device` the current CUDA device; does nothing when the client runs on a non-CUDA device."""
        # `torch.cuda.set_device` raises a ValueError for non-CUDA devices such as the CPU.
        if self._device.type == "cuda":
            torch.cuda.set_device(self._device)

    @check_device_set()
    def train_single_batch(
        self, batch_index: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs the client-side forward pass for a single training batch (parallel split learning).

        No optimizer step is performed here. The server's gradients w.r.t. the returned activations are expected to
        be passed to `backward_single_batch`, and the parameters are updated in
        `set_gradient_and_finalize_training`.

        Args:
            batch_index (int): Index of the batch within the current epoch. Passing `0` (re-)starts iteration over
                the training data; batches are then taken sequentially, so they must be requested in order.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The smashed data (client model activations) and the batch labels.
                The labels are returned as produced by the data loader (not moved to the client's device).

        Raises:
            RuntimeError: If no epoch has been started yet with `batch_index == 0`.

        Notes:
            If a latency factor is specified, runtime latency is simulated on the forward pass.
        """
        self._set_cuda_device()
        # Re-create the iterator at the start of every epoch.
        if batch_index == 0:
            self._batchable_data_loader = iter(self._train_data)
        if self._batchable_data_loader is None:
            raise RuntimeError(
                "train_single_batch must be called with batch_index 0 before requesting further batches"
            )

        # In parallel split learning, forward and backward passes are orchestrated by the current server. The
        # forward pass timestamps are therefore cached and combined with the backward pass time in
        # `backward_single_batch` to report the correct execution time.
        start_time = time.time()

        self._model.train()

        # The number of batches the DataLoader can provide is needed to validate the requested batch index.
        #
        # TODO: is there another way to do this? gRPC streaming does not work here, since we have to keep the streams
        #       alive while doing other RPC calls like setting weights, sending/averaging gradient data, ...
        num_batches = self.get_approximated_num_batches()
        assert 0 <= batch_index < num_batches, (
            f"batch_index {batch_index} out of range for {num_batches} batches"
        )

        batch_data, batch_labels = next(self._batchable_data_loader)

        # Simulate the energy consumed by the forward pass.
        self.node_device.battery.update_flops(self._model_flops * len(batch_data))

        # Train the model on the single batch and return the activations and labels. These are sent to the server
        # for further processing.
        with LatencySimulator(latency_factor=self.latency_factor):
            self._optimizer.zero_grad()
            smashed_data = self._model(batch_data.to(self._device))

            # NOTE(review): measured inside the LatencySimulator context, so any delay it adds on exit is presumably
            # not part of the cached forward time (unlike in `train_epoch`) — confirm whether this is intended.
            end_time = time.time()

            self._psl_cache = {
                "batch_data": batch_data,
                "smashed_data": smashed_data,
                "start_time": start_time,
                "end_time": end_time,
            }
            return smashed_data, batch_labels

    @check_device_set()
    def backward_single_batch(
        self, gradients
    ) -> Tuple[DiagnosticMetricResultContainer, List[torch.Tensor]]:
        """
        Runs the client-side backward pass for the batch cached by the preceding `train_single_batch` call.

        The optimizer step is deferred to `set_gradient_and_finalize_training`, so the cache is kept until then.

        Args:
            gradients (torch.Tensor): The server's gradients w.r.t. the cached smashed data.

        Returns:
            Tuple[DiagnosticMetricResultContainer, List[torch.Tensor]]: A container with the `comp_time` metric
                (forward plus backward time of this batch), and one gradient tensor per client model parameter in
                `parameters()` order (zeros for parameters without a gradient).

        Raises:
            RuntimeError: If no forward pass has been cached by `train_single_batch`.
        """
        self._set_cuda_device()
        cache = self._psl_cache
        if cache is None:
            raise RuntimeError(
                "backward_single_batch called without a preceding train_single_batch"
            )
        batch_data = cache["batch_data"]
        smashed_data = cache["smashed_data"]
        forward_start, forward_end = cache["start_time"], cache["end_time"]

        backward_start = time.time()

        # The backward pass is accounted for as twice the forward FLOPs.
        self.node_device.battery.update_flops(self._model_flops * len(batch_data) * 2)
        # NOTE(review): unlike `train_epoch`, the backward pass here is not wrapped in a LatencySimulator — confirm.
        smashed_data.backward(gradients.to(self._device))

        backward_end = time.time()

        metric = DiagnosticMetricResult(
            device_id=self.node_device.device_id,
            name="comp_time",
            value=(forward_end - forward_start) + (backward_end - backward_start),
            method="client_train_batch_time",
        )
        metrics_container = DiagnosticMetricResultContainer([metric])

        client_gradients = [
            param.grad if param.grad is not None else torch.zeros_like(param)
            for param in self._model.parameters()
        ]
        return metrics_container, client_gradients

    def get_approximated_num_batches(self) -> int:
        """
        Returns the number of batches reported by the training DataLoader (`len(self._train_data)`).

        For map-style datasets this is exact; for iterable-style datasets PyTorch only provides an estimate.
        """
        return len(self._train_data)

    @check_device_set()
    def train_epoch(
        self, server_device_id: str, round_no: int = -1
    ) -> Tuple[StateDict, DiagnosticMetricResultContainer]:
        """
        Trains the model on the client's data for one epoch, returning the new weights and training metrics.
        The server model is run on the device with the given id.

        Args:
            server_device_id (str): The id of the device on which the server model is run.
            round_no (int, optional): The current epoch number, passed to the learning rate scheduler's `step` when
                not `-1`.

        Returns:
            StateDict: The updated weights of the client's model.
            DiagnosticMetricResultContainer: The diagnostic metrics collected when training on the server and the actual
                client model execution time.

        Notes:
            If configured, runtime latency is simulated on neural network operations.

            If the server device is unavailable (it returns `False` or `None` for a batch), the epoch is stopped early.

            For optimizing the server device selection, the training time for the client model is needed. Therefore, the
            execution time (without the time for the server to process the batches) is measured and added as a
            diagnostic metric.

            Usual designs measure the execution time at the device level (including batch processing time). Contrary to
            that, this approach does not require deducing the server batch processing time after a "traditional"
            measurement.
        """
        client_train_start_time = time.time()
        # Time the server needed to process each batch; subtracted from the total time below.
        server_train_batch_times = []
        self._model.train()
        diagnostic_metric_container = DiagnosticMetricResultContainer()
        for batch_data, batch_labels in self._train_data:
            self.node_device.battery.update_flops(self._model_flops * len(batch_data))

            with LatencySimulator(latency_factor=self.latency_factor):
                batch_data = batch_data.to(self._device)
                batch_labels = batch_labels.to(self._device)

                self._optimizer.zero_grad()
                smashed_data = self._model(batch_data)

            # Time the server call separately so that only client-side computation counts as client time.
            start_time = time.time()
            train_batch_response = self.node_device.train_batch_on(
                server_device_id, smashed_data, batch_labels
            )
            server_train_batch_times.append(time.time() - start_time)

            with LatencySimulator(latency_factor=self.latency_factor):
                if train_batch_response is False or train_batch_response is None:
                    # The server device is unavailable: stop training for this epoch.
                    break
                server_grad, _server_loss, diagnostic_metrics = train_batch_response
                diagnostic_metric_container.merge(diagnostic_metrics)
                # The backward pass is accounted for as twice the forward FLOPs.
                self.node_device.battery.update_flops(
                    self._model_flops * len(batch_data) * 2
                )
                smashed_data.backward(server_grad.to(self._device))
                self._optimizer.step()

        if self._lr_scheduler is not None:
            if round_no != -1:
                self._lr_scheduler.step(round_no)
            else:
                self._lr_scheduler.step()

        client_train_time = (
            time.time() - client_train_start_time - sum(server_train_batch_times)
        )
        diagnostic_metric_container.add_result(
            DiagnosticMetricResult(
                device_id=self.node_device.device_id,
                name="comp_time",
                value=client_train_time,
                method="client_train_epoch_time",
            )
        )
        return self._model.state_dict(), diagnostic_metric_container

    @check_device_set()
    def evaluate(self, server_device: str, val=True) -> DiagnosticMetricResultContainer:
        """
        Evaluates the model on the client's data.

        Args:
            server_device (str): The server device to run the server model on.
            val (bool, optional): If `True`, uses the validation data, otherwise the test data.
                Set to `True` by default.

        Returns:
            DiagnosticMetricResultContainer: The diagnostic metrics collected when evaluating on the server and the
                actual client model execution time.

        Notes:
            If configured, runtime latency is simulated on neural network operations.

            For optimizing the server device selection, the execution time of the client model is needed. Therefore, the
            execution time (without the time for the server to process the batches) is measured and added as a
            diagnostic metric.

            Usual designs measure the execution time at the device level (including batch processing time). Contrary to
            that, this approach does not require deducing the server batch processing time after a "traditional"
            measurement.
        """
        client_eval_start_time = time.time()
        # Time the server needed to process each batch; subtracted from the total time below.
        server_eval_batch_times = []
        self._model.eval()
        diagnostic_metric_results = DiagnosticMetricResultContainer()
        dataloader = self._val_data if val else self._test_data
        with torch.no_grad():
            for batch_data, batch_labels in dataloader:
                with LatencySimulator(latency_factor=self.latency_factor):
                    self.node_device.battery.update_flops(
                        self._model_flops * len(batch_data)
                    )
                    batch_data = batch_data.to(self._device)
                    batch_labels = batch_labels.to(self._device)
                    # Run the client model before the server timer starts, so its execution time counts as client
                    # time (and is subject to latency simulation), as in `train_epoch`.
                    smashed_data = self._model(batch_data)

                start_time = time.time()
                diagnostic_metrics = self.node_device.evaluate_batch_on(
                    server_device, smashed_data, batch_labels
                )
                server_eval_batch_times.append(time.time() - start_time)

                diagnostic_metric_results.merge(diagnostic_metrics)
        client_eval_time = (
            time.time() - client_eval_start_time - sum(server_eval_batch_times)
        )
        diagnostic_metric_results.add_result(
            DiagnosticMetricResult(
                device_id=self.node_device.device_id,
                name="comp_time",
                value=client_eval_time,
                method="client_eval_epoch_time",
            )
        )
        return diagnostic_metric_results

    def set_gradient_and_finalize_training(self, gradients: Any):
        """
        Applies the given gradients to the client model, performs an optimizer step and clears the forward-pass cache.

        Args:
            gradients (Any): An iterable with one gradient tensor per client model parameter, in `parameters()` order
                (the same layout as returned by `backward_single_batch`).

        Raises:
            ValueError: If the number of gradients does not match the number of model parameters.
        """
        params = list(self._model.parameters())
        grads = list(gradients)
        if len(grads) != len(params):
            raise ValueError(
                f"Expected {len(params)} gradients (one per model parameter), got {len(grads)}"
            )
        for param, grad in zip(params, grads):
            param.grad = grad.to(self._device)

        self._optimizer.step()
        self._psl_cache = None