# server.py
from __future__ import annotations
import concurrent.futures
import time
from typing import List, Optional, Tuple, Any, TYPE_CHECKING
import torch
from omegaconf import DictConfig
from colorama import init, Fore
from torch import nn
from torch.autograd import Variable
from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import check_device_set, simulate_latency_decorator
from edml.helpers.executor import create_executor_with_threads
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import (
create_metrics,
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn
if TYPE_CHECKING:
from edml.core.device import Device
class DeviceServer:
"""Split learning server that runs on a device and communicates with clients on (potentially) other devices
through the provided interface by its device."""
def __init__(
self,
model: nn.Module,
loss_fn: LossFn,
cfg: DictConfig,
latency_factor: float = 0.0,
):
"""Initializes the server with the given model, loss function, configuration and reference to its device."""
self._device = torch.device(get_torch_device_id(cfg))
self._model = model.to(self._device)
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
cfg, self._model.parameters()
)
self._model_flops = 0 # determine later
self._metrics = create_metrics(
cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
)
self._loss_fn = loss_fn
self._cfg = cfg
self.node_device: Optional[Device] = None
self.latency_factor = latency_factor
def set_device(self, node_device: Device):
"""Sets the device reference for the server."""
self.node_device = node_device
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def set_weights(self, state_dict: StateDict):
"""Sets the weights of the server's model"""
if state_dict is not None:
self._model.load_state_dict(state_dict=state_dict)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def get_weights(self):
"""Returns the weights of the server's model"""
return self._model.state_dict()
@check_device_set()
def train(
self,
devices: List[str],
epochs: int = 1,
round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
) -> Tuple[
Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
]:
"""Train the model on the given devices for the given number of epochs.
Shares the weights among clients and saves the final weights to the configured paths.
Args:
devices: The devices to train on
epochs: Optionally, the number of epochs to train.
round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
optimizer_state: Optionally, the optimizer_state to proceed from
"""
client_weights = None
metrics = ModelMetricResultContainer()
diagnostic_metric_container = DiagnosticMetricResultContainer()
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
for epoch in range(epochs):
for device_id in devices:
print(
f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
)
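                # Sequential split learning: carry the client weights produced by the previously trained client over to the next one.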
if client_weights is not None:
self.node_device.set_weights_on(
device_id=device_id,
state_dict=client_weights,
on_client=True, # we want to set client weights
)
train_epoch_response = self.node_device.train_epoch_on(
device_id, self.node_device.device_id, round_no=round_no + epoch
)
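                # The client may return None/False if it could not complete the epoch; in that case the previous client weights are kept.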
if (
train_epoch_response is not False
and train_epoch_response is not None
):
client_weights, diagnostic_metrics = train_epoch_response
train_metrics = self.finalize_metrics(str(device_id), "train")
diagnostic_metric_container.merge(diagnostic_metrics)
diagnostic_metrics = self.node_device.evaluate_on(
device_id, server_device=self.node_device.device_id, val=True
)
if diagnostic_metrics is not None:
diagnostic_metric_container.merge(diagnostic_metrics)
val_metrics = self.finalize_metrics(str(device_id), "val")
metrics.add_results(train_metrics)
metrics.add_results(val_metrics)
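            # Step the learning rate scheduler once per epoch; if a global round number is given, use it so the schedule stays consistent across rounds.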
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + epoch)
else:
self._lr_scheduler.step()
return (
client_weights,
self.get_weights(),
metrics,
self._optimizer.state_dict(),
diagnostic_metric_container,
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
"""Train the model on the given batch of data and labels.
Returns the gradients of the model's parameters."""
smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
self._set_model_flops(smashed_data)
self._optimizer.zero_grad()
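        # Charge the (simulated) battery for the forward pass: per-sample FLOPs times the batch size.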
self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
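        # Wrap the smashed data so that gradients w.r.t. the cut-layer activations are available to send back to the client.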
smashed_data = Variable(smashed_data, requires_grad=True)
output_train = self._model(smashed_data)
loss_train = self._loss_fn(output_train, labels)
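        # Charge the battery for the backward pass, which is assumed here to cost roughly twice the forward-pass FLOPs.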
self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
loss_train.backward()
self._optimizer.step()
# Capturing training metrics for the current batch.
self.node_device.log({"loss": loss_train.item()})
self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
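        # Return the gradients at the cut layer (so the client can continue backpropagation) together with the batch loss.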
return smashed_data.grad, loss_train.item()
def _set_model_flops(self, smashed_data):
"""Helper to determine the model flops when smashed data are available for the first time."""
if self._model_flops == 0:
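            # Estimate the FLOPs once on the first batch and normalize to a per-sample cost.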
self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
smashed_data
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def finalize_metrics(self, device_id: str, phase: str):
"""Computes the total results of the metrics. Logs the results clears the cached predictions.
Returns a list of results."""
metric_result_list = self._metrics.compute_metrics(
phase=phase, device_id=device_id
)
for metric_result in metric_result_list:
self.node_device.log(metric_result.as_loggable_dict())
self._metrics.reset_metrics()
return metric_result_list
@check_device_set()
def evaluate_global(
self, devices: List[str], val: bool
) -> Tuple[ModelMetricResultContainer, DiagnosticMetricResultContainer]:
"""Evaluates on the given devices using the own server model. Returns the gathered metrics."""
result_metrics = ModelMetricResultContainer()
diagnostic_metric_results = DiagnosticMetricResultContainer()
for device_id in devices:
phase = "val" if val else "test"
print(
f"Evaluate with {phase} data on client {device_id} with server {self.node_device.device_id}"
)
diagnostic_metrics = self.node_device.evaluate_on(
device_id, server_device=self.node_device.device_id, val=val
)
            metrics = self.finalize_metrics(str(device_id), phase)
result_metrics.add_results(metrics)
diagnostic_metric_results.merge(diagnostic_metrics)
return result_metrics, diagnostic_metric_results
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def evaluate_batch(self, smashed_data, labels):
"""Evaluates the model on the given batch of data and labels"""
with torch.no_grad():
smashed_data = smashed_data.to(self._device)
self._set_model_flops(smashed_data)
self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
pred = self._model(smashed_data)
self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def train_parallel_split_learning(
self,
clients: List[str],
round_no: int,
adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
):
def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
return self.node_device.train_batch_on_client_only_on(
device_id=client_id, batch_index=batch_index
)
def client_backpropagation_job(client_id: str, gradients: Any):
return self.node_device.backpropagation_on_client_only_on(
client_id=client_id, gradients=gradients
)
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
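        # Use one worker thread per client so that all clients can process their batches concurrently.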
num_threads = len(clients)
executor = create_executor_with_threads(num_threads)
# batches = []
model_metrics = ModelMetricResultContainer()
diagnostic_metrics = DiagnosticMetricResultContainer()
        # For every batch index, start the client-side forward pass on all clients in parallel and process the collected results afterwards.
num_batches = self.node_device.client.get_approximated_num_batches()
print(f"\n\n:: BATCHES :: {num_batches}\n\n")
for batch_index in range(num_batches):
batches = []
futures = [
executor.submit(client_training_job, client_id, batch_index)
for client_id in clients
]
for future in concurrent.futures.as_completed(futures):
batches.append(future.result())
if _empty_batches(batches):
                # All clients returned None, i.e. they have run out of data; this only happens for the last batch.
break
print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
# batches2 = [b for b in batches if b is not None]
# print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
server_batch = _concat_smashed_data(
[b[0].to(self._device) for b in batches]
)
server_labels = _concat_smashed_data(
[b[1].to(self._device) for b in batches]
)
            # Train the server part of the model on the concatenated batch. The gradients at the cut layer then need to be
            # split back into per-client chunks and averaged before being sent back to the clients to continue backpropagation.
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
            # If an adaptive learning threshold is configured, skip the client backpropagation for batches whose
            # server-side loss is already below the threshold.
print(
f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
)
if (
adaptive_learning_threshold
and server_loss < adaptive_learning_threshold
):
print(
f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
)
self.node_device.log({"adaptive_learning_threshold_applied": True})
continue
num_client_gradients = len(batches)
print(
f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
)
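            # Split the server-side gradients back into one chunk per client and average the chunks; every client then receives the same averaged gradient.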
client_gradients = torch.chunk(server_gradients, num_client_gradients)
print(f"::: result shape: {client_gradients[1].shape}")
concatenated_client_gradients = torch.stack(client_gradients, dim=0)
mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
print(f"::: -> {mean_tensor.shape}")
futures = [
executor.submit(client_backpropagation_job, client_id, mean_tensor)
for client_id in clients
]
diagnostic_results = []
for future in concurrent.futures.as_completed(futures):
diagnostic_results.append(future.result())
# Has to be done outside the loop due to thread-safety.
# for diagnostic_result in diagnostic_results:
# diagnostic_metrics.merge(diagnostic_result)
# # Resetting batches since we only need them per-batch-iteration.
# batches = []
# Now we have to determine the model metrics for each client.
for client_id in clients:
train_metrics = self.finalize_metrics(str(client_id), "train")
print(f"::: evaluating on {client_id}")
evaluation_diagnostics_metrics = self.node_device.evaluate_on(
device_id=client_id,
server_device=self.node_device.device_id,
val=True,
)
# if evaluation_diagnostics_metrics:
# diagnostic_metrics.merge(evaluation_diagnostics_metrics)
val_metrics = self.finalize_metrics(str(client_id), "val")
model_metrics.add_results(train_metrics)
model_metrics.add_results(val_metrics)
optimizer_state = self._optimizer.state_dict()
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + 1) # epoch=1
else:
self._lr_scheduler.step()
# delete references and free GPU memory manually
server_batch = None
server_labels = None
server_gradients = None
client_gradients = None
concatenated_client_gradients = None
mean_tensor = None
torch.cuda.empty_cache()
torch.cuda.set_device(self._device)
return (
self.node_device.client.get_weights(),
self.get_weights(),
model_metrics,
optimizer_state,
diagnostic_metrics,
)
def _concat_smashed_data(data: List[Any]) -> Any:
"""Creates a single batch tensor from a list of tensors."""
return torch.cat(data, dim=0)
def _empty_batches(batches):
"""Checks if all the list entries are `None`."""
return batches.count(None) == len(batches)