Sven Michael Lechner authored
This is a quick and dirty fix to make it easier to run experiments. A proper implementation should relocate the parallelism into the corresponding dispatcher classes. Closes #1
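As a rough illustration of the follow-up suggested above, here is a minimal sketch, assuming a hypothetical `ThreadedClientDispatcher` class and `map_clients` method that do not exist in this repository, of how the per-client parallelism could be moved out of `DeviceServer` into a dispatcher:

import concurrent.futures
from typing import Any, Callable, Iterable, List, Tuple


class ThreadedClientDispatcher:
    """Hypothetical dispatcher that owns the thread pool instead of DeviceServer."""

    def __init__(self, max_workers: int):
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    def map_clients(
        self, job: Callable[..., Any], client_ids: Iterable[str], *args: Any
    ) -> List[Tuple[str, Any]]:
        """Runs job(client_id, *args) for every client in parallel and returns (client_id, result) pairs."""
        futures = {
            self._executor.submit(job, client_id, *args): client_id
            for client_id in client_ids
        }
        return [
            (futures[future], future.result())
            for future in concurrent.futures.as_completed(futures)
        ]

    def shutdown(self) -> None:
        self._executor.shutdown(wait=True)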
server.py 23.86 KiB
from __future__ import annotations
import concurrent.futures
from typing import List, Optional, Tuple, Any, TYPE_CHECKING
import torch
from omegaconf import DictConfig
from colorama import Fore
from torch import nn
from torch.autograd import Variable
from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import check_device_set, simulate_latency_decorator, Timer
from edml.helpers.executor import create_executor_with_threads
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import (
create_metrics,
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, LossFn
if TYPE_CHECKING:
from edml.core.device import Device
class DeviceServer:
"""Split learning server that runs on a device and communicates with clients on (potentially) other devices
through the provided interface by its device."""
def __init__(
self,
model: nn.Module,
loss_fn: LossFn,
cfg: DictConfig,
latency_factor: float = 0.0,
):
"""Initializes the server with the given model, loss function, configuration and reference to its device."""
self._device = torch.device(get_torch_device_id(cfg))
self._model = model.to(self._device)
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
cfg, self._model.parameters()
)
self._model_flops = 0 # determine later
self._metrics = create_metrics(
cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
)
self._loss_fn = loss_fn
self._cfg = cfg
self.node_device: Optional[Device] = None
self.latency_factor = latency_factor
def set_device(self, node_device: Device):
"""Sets the device reference for the server."""
self.node_device = node_device
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def set_weights(self, state_dict: StateDict):
"""Sets the weights of the server's model"""
if state_dict is not None:
self._model.load_state_dict(state_dict=state_dict)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def get_weights(self):
"""Returns the weights of the server's model"""
return self._model.state_dict()
@check_device_set()
def train(
self,
devices: List[str],
epochs: int = 1,
round_no: int = -1,
        optimizer_state: Optional[dict[str, Any]] = None,
) -> Tuple[
Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
]:
"""Train the model on the given devices for the given number of epochs.
Shares the weights among clients and saves the final weights to the configured paths.
Args:
devices: The devices to train on
epochs: Optionally, the number of epochs to train.
round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
            optimizer_state: Optionally, the optimizer state to resume from
"""
client_weights = None
metrics = ModelMetricResultContainer()
diagnostic_metric_container = DiagnosticMetricResultContainer()
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
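        # Sequential split learning: each epoch visits the clients one after another, carrying the
        # client-side weights from the previous client over to the next one.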
for epoch in range(epochs):
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + epoch)
else:
self._lr_scheduler.step()
for device_id in devices:
print(
f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
)
if client_weights is not None:
self.node_device.set_weights_on(
device_id=device_id,
state_dict=client_weights,
on_client=True, # we want to set client weights
)
train_epoch_response = self.node_device.train_epoch_on(
device_id, self.node_device.device_id, round_no=round_no + epoch
)
if (
train_epoch_response is not False
and train_epoch_response is not None
):
client_weights, diagnostic_metrics = train_epoch_response
train_metrics = self.finalize_metrics(str(device_id), "train")
diagnostic_metric_container.merge(diagnostic_metrics)
diagnostic_metrics = self.node_device.evaluate_on(
device_id, server_device=self.node_device.device_id, val=True
)
if diagnostic_metrics is not None:
diagnostic_metric_container.merge(diagnostic_metrics)
val_metrics = self.finalize_metrics(str(device_id), "val")
metrics.add_results(train_metrics)
metrics.add_results(val_metrics)
return (
client_weights,
self.get_weights(),
metrics,
self._optimizer.state_dict(),
diagnostic_metric_container,
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
"""Train the model on the given batch of data and labels.
Returns the gradients of the model's parameters."""
smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
self._set_model_flops(smashed_data)
self._optimizer.zero_grad()
self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
smashed_data = Variable(smashed_data, requires_grad=True)
output_train = self._model(smashed_data)
loss_train = self._loss_fn(output_train, labels)
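        # The extra factor of 2 presumably accounts for the backward pass costing roughly twice the forward pass.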
self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
loss_train.backward()
self._optimizer.step()
# Capturing training metrics for the current batch.
self.node_device.log({"loss": loss_train.item()})
self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
return smashed_data.grad, loss_train.item()
def _set_model_flops(self, smashed_data):
"""Helper to determine the model flops when smashed data are available for the first time."""
if self._model_flops == 0:
self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
smashed_data
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def finalize_metrics(self, device_id: str, phase: str):
"""Computes the total results of the metrics. Logs the results clears the cached predictions.
Returns a list of results."""
metric_result_list = self._metrics.compute_metrics(
phase=phase, device_id=device_id
)
for metric_result in metric_result_list:
self.node_device.log(metric_result.as_loggable_dict())
self._metrics.reset_metrics()
return metric_result_list
@check_device_set()
def evaluate_global(
self, devices: List[str], val: bool
) -> Tuple[ModelMetricResultContainer, DiagnosticMetricResultContainer]:
"""Evaluates on the given devices using the own server model. Returns the gathered metrics."""
result_metrics = ModelMetricResultContainer()
diagnostic_metric_results = DiagnosticMetricResultContainer()
for device_id in devices:
phase = "val" if val else "test"
print(
f"Evaluate with {phase} data on client {device_id} with server {self.node_device.device_id}"
)
diagnostic_metrics = self.node_device.evaluate_on(
device_id, server_device=self.node_device.device_id, val=val
)
            metrics = self.finalize_metrics(str(device_id), phase)
result_metrics.add_results(metrics)
diagnostic_metric_results.merge(diagnostic_metrics)
return result_metrics, diagnostic_metric_results
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def evaluate_batch(self, smashed_data, labels):
"""Evaluates the model on the given batch of data and labels"""
with torch.no_grad():
smashed_data = smashed_data.to(self._device)
self._set_model_flops(smashed_data)
self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
pred = self._model(smashed_data)
self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def train_parallel_split_learning(
self,
clients: List[str],
round_no: int,
adaptive_learning_threshold: Optional[float] = None,
        optimizer_state: Optional[dict[str, Any]] = None,
):
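        # Per-client jobs that are executed either sequentially (simulated parallelism) or via the thread pool below.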
def client_training_job(client_id: str, batch_index: int):
result = self.node_device.train_batch_on_client_only_on(
device_id=client_id,
batch_index=batch_index,
round_no=round_no,
# round_no is taken from outer method arg
)
return (client_id, result)
def client_backpropagation_job(client_id: str, gradients: Any):
return self.node_device.backpropagation_on_client_only_on(
client_id=client_id, gradients=gradients
)
def client_set_gradient_and_finalize_training_job(
client_id: str, gradients: Any
):
return (
self.node_device.set_gradient_and_finalize_training_on_client_only_on(
client_id=client_id, gradients=gradients
)
)
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + 1) # epoch=1
else:
self._lr_scheduler.step()
num_threads = len(clients)
executor = create_executor_with_threads(num_threads)
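        # One worker thread per client so that every client's job for a batch can run concurrently.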
# batches = []
model_metrics = ModelMetricResultContainer()
diagnostic_metrics = DiagnosticMetricResultContainer()
        # We iterate over each batch, starting all client training jobs at once and processing the results afterward.
num_batches = self.node_device.client.get_approximated_num_batches()
print(f":: BATCHES :: {num_batches}")
if self._cfg.simulate_parallelism:
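            # Simulated parallelism: run the client jobs sequentially but record per-client wall-clock
            # times, logging their maximum as the time a truly parallel run would have needed.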
for batch_index in range(num_batches):
client_forward_pass_responses = []
parallel_times = []
with Timer() as elapsed_time:
for client_id in clients:
with Timer() as individual_time:
(client_id, result) = client_training_job(
client_id, batch_index
)
if result is not None and result is not False:
client_forward_pass_responses.append(
(client_id, result)
)
parallel_times.append(individual_time.execution_time)
self.node_device.log(
{
"parallel_client_train_time": {
"elapsed_time": elapsed_time.execution_time,
"parallel_time": max(parallel_times),
}
}
)
# We want to split up the responses into a list of client IDs and batches again.
client_ids = [b[0] for b in client_forward_pass_responses]
client_batches = [b[1] for b in client_forward_pass_responses]
server_batch = _concat_smashed_data(
[b[0].to(self._device) for b in client_batches]
)
server_labels = _concat_smashed_data(
[b[1].to(self._device) for b in client_batches]
)
                # Train the server part of the model, then split the resulting gradients back into
                # per-client chunks so that each client can continue its own backward pass.
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
                # If the adaptive learning threshold is set, skip the client backpropagation for this
                # batch whenever the current server loss falls below the threshold.
print(
f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
)
if (
adaptive_learning_threshold
and server_loss < adaptive_learning_threshold
):
print(
f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
)
self.node_device.log(
{
"adaptive_learning_threshold_applied": server_gradients.size(
0
)
}
)
continue
num_client_gradients = len(client_forward_pass_responses)
print(
f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
)
# clone single client gradients so that client_gradients is not a list of views of server_gradients
# if we just use torch.chunk, each client will receive the whole server_gradients
client_gradients = [
t.clone().detach()
for t in torch.chunk(server_gradients, num_client_gradients)
]
client_backpropagation_gradients = []
parallel_times = []
with Timer() as elapsed_time:
for idx, client_id in enumerate(client_ids):
with Timer() as individual_time:
_, grads = client_backpropagation_job(
client_id, client_gradients[idx]
)
if grads is not None and grads is not False:
client_backpropagation_gradients.append(grads)
parallel_times.append(individual_time.execution_time)
self.node_device.log(
{
"parallel_client_backprop_time": {
"elapsed_time": elapsed_time.execution_time,
"parallel_time": max(parallel_times),
}
}
)
                # Average the clients' backpropagation gradients and send the result back to every
                # client to finalize the current training step.
averaged_gradient = _calculate_gradient_mean(
client_backpropagation_gradients, self._device
)
parallel_times = []
with Timer() as elapsed_time:
for client_id in clients:
with Timer() as individual_time:
client_set_gradient_and_finalize_training_job(
client_id, averaged_gradient
)
parallel_times.append(individual_time.execution_time)
self.node_device.log(
{
"parallel_client_model_update_time": {
"elapsed_time": elapsed_time.execution_time,
"parallel_time": max(parallel_times),
}
}
)
else:
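            # Actual parallelism: submit the client jobs to the thread-pool executor and collect results as they complete.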
for batch_index in range(num_batches):
client_forward_pass_responses = []
futures = [
executor.submit(client_training_job, client_id, batch_index)
for client_id in clients
]
for future in concurrent.futures.as_completed(futures):
(client_id, result) = future.result()
if result is not None and result is not False:
client_forward_pass_responses.append((client_id, result))
# We want to split up the responses into a list of client IDs and batches again.
client_ids = [b[0] for b in client_forward_pass_responses]
client_batches = [b[1] for b in client_forward_pass_responses]
server_batch = _concat_smashed_data(
[b[0].to(self._device) for b in client_batches]
)
server_labels = _concat_smashed_data(
[b[1].to(self._device) for b in client_batches]
)
                # Train the server part of the model, then split the resulting gradients back into
                # per-client chunks so that each client can continue its own backward pass.
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
                # If the adaptive learning threshold is set, skip the client backpropagation for this
                # batch whenever the current server loss falls below the threshold.
print(
f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
)
if (
adaptive_learning_threshold
and server_loss < adaptive_learning_threshold
):
print(
f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
)
self.node_device.log(
{
"adaptive_learning_threshold_applied": server_gradients.size(
0
)
}
)
continue
num_client_gradients = len(client_forward_pass_responses)
print(
f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
)
# clone single client gradients so that client_gradients is not a list of views of server_gradients
# if we just use torch.chunk, each client will receive the whole server_gradients
client_gradients = [
t.clone().detach()
for t in torch.chunk(server_gradients, num_client_gradients)
]
futures = [
executor.submit(
client_backpropagation_job, client_id, client_gradients[idx]
)
for (idx, client_id) in enumerate(client_ids)
]
client_backpropagation_gradients = []
for future in concurrent.futures.as_completed(futures):
_, grads = future.result()
if grads is not None and grads is not False:
client_backpropagation_gradients.append(grads)
                # Average the clients' backpropagation gradients and send the result back to every
                # client to finalize the current training step.
averaged_gradient = _calculate_gradient_mean(
client_backpropagation_gradients, self._device
)
futures = [
executor.submit(
client_set_gradient_and_finalize_training_job,
client_id,
averaged_gradient,
)
for client_id in clients
]
for future in concurrent.futures.as_completed(futures):
future.result()
# Now we have to determine the model metrics for each client.
if self._cfg.simulate_parallelism:
parallel_times = []
with Timer() as elapsed_time:
for client_id in clients:
with Timer() as individual_time:
train_metrics = self.finalize_metrics(str(client_id), "train")
evaluation_diagnostics_metrics = self.node_device.evaluate_on(
device_id=client_id,
server_device=self.node_device.device_id,
val=True,
)
# if evaluation_diagnostics_metrics:
# diagnostic_metrics.merge(evaluation_diagnostics_metrics)
val_metrics = self.finalize_metrics(str(client_id), "val")
model_metrics.add_results(train_metrics)
model_metrics.add_results(val_metrics)
parallel_times.append(individual_time.execution_time)
self.node_device.log(
{
"parallel_client_eval_time": {
"elapsed_time": elapsed_time.execution_time,
"parallel_time": max(parallel_times),
}
}
)
else:
for client_id in clients:
train_metrics = self.finalize_metrics(str(client_id), "train")
evaluation_diagnostics_metrics = self.node_device.evaluate_on(
device_id=client_id,
server_device=self.node_device.device_id,
val=True,
)
# if evaluation_diagnostics_metrics:
# diagnostic_metrics.merge(evaluation_diagnostics_metrics)
val_metrics = self.finalize_metrics(str(client_id), "val")
model_metrics.add_results(train_metrics)
model_metrics.add_results(val_metrics)
optimizer_state = self._optimizer.state_dict()
# delete references and free GPU memory manually
server_batch = None
server_labels = None
server_gradients = None
client_gradients = None
concatenated_client_gradients = None
mean_tensor = None
torch.cuda.empty_cache()
torch.cuda.set_device(self._device)
return (
self.node_device.client.get_weights(),
self.get_weights(),
model_metrics,
optimizer_state,
diagnostic_metrics,
)
def _calculate_gradient_mean(
    gradients: List[List[Variable]], device: str = "cpu"
) -> List[Variable]:
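    """Averages the per-client gradient lists element-wise, weighting every client equally, and returns the result on the given device."""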
num_devices = len(gradients)
weights = [1 / num_devices] * num_devices
# We need to move all tensors to the same device to do calculations.
for i, client_gradients in enumerate(gradients):
for j, grad in enumerate(client_gradients):
gradients[i][j] = grad.to(device)
return [
sum(gradients[i][j] * weights[i] for i in range(num_devices))
for j in range(len(gradients[0]))
]
def _concat_smashed_data(data: List[Any]) -> Any:
"""Creates a single batch tensor from a list of tensors."""
return torch.cat(data, dim=0)