from __future__ import annotations
import itertools
import time
from typing import Optional, Tuple, TYPE_CHECKING, Any
import torch
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader
from edml.helpers.config_helpers import get_torch_device_id
from edml.helpers.decorators import (
check_device_set,
simulate_latency_decorator,
LatencySimulator,
)
from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
from edml.helpers.types import StateDict, SLTrainBatchResult
if TYPE_CHECKING:
from edml.core.device import Device
class DeviceClient:
"""
A client in the context of split learning. I.e., a device that trains the first `n`-th layers on the client. The
remaining layers will then be trained on the server device.
Attributes:
- TODO: latency_factor and node_device?
See:
- py:meth:`~edml.core.server.DeviceServer`
- py:meth:`~edml.controllers.split_learning.SplitController`
- py:meth:`~edml.controllers.swarm_learning.SwarmController`
Split learning client that runs on a device and communicates with servers on (potentially) other devices
through the provided interface by its device."""
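    # Illustrative usage sketch (the model, config, data loaders and node device are assumed to be constructed
    # elsewhere; the server device id is a placeholder):
    #
    #   client = DeviceClient(model, cfg, train_dl, val_dl, test_dl, latency_factor=0.0)
    #   client.set_device(node_device)
    #   weights, metrics = client.train_epoch(server_device_id="server-0", round_no=0)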
def __init__(
self,
model: nn.Module,
cfg: DictConfig,
train_dl: DataLoader,
val_dl: DataLoader,
test_dl: DataLoader,
latency_factor: float = 0.0,
):
"""
Initializes the split learning client with its (partial) model and training, validation and test data.
Args:
model (nn.Module): The pytorch neural network trained by the client.
cfg (DictConfig): The experiment's configuration.
train_dl (DataLoader): The data loader responsible for loading training data.
val_dl (DataLoader): The data loader responsible for loading validation data.
test_dl (DataLoader): The data loader responsible for loading testing data.
latency_factor (float):
Notes:
This class moves the model to the GPU if CUDA is available. If not, the model will be moved to the CPU.
"""
self._train_data, self._val_data, self._test_data = train_dl, val_dl, test_dl
self._batchable_data_loader = None
self._device = torch.device(get_torch_device_id(cfg))
self._model = model.to(self._device)
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
cfg, self._model.parameters()
)
# get first sample from train data to estimate model flops
sample = self._train_data.dataset.__getitem__(0)[0]
if not isinstance(sample, torch.Tensor):
sample = torch.tensor(data=sample)
self._model_flops = estimate_model_flops(
self._model, sample.to(self._device).unsqueeze(0)
)
self._cfg = cfg
self.node_device: Optional[Device] = None
self.latency_factor = latency_factor
self._psl_cache = None
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def set_device(self, node_device: Device):
"""
Sets the device that this client-side is part of.
Notes:
If a latency factor is specified, this function sleeps for said amount before returning.
"""
self.node_device = node_device
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def set_weights(self, state_dict: StateDict):
"""
        Updates the model's weights with the given state dict.
Args:
state_dict (StateDict): The model's parameters and buffers.
Notes:
If a latency factor is specified, this function sleeps for said amount before returning.
"""
if state_dict is not None:
self._model.load_state_dict(state_dict=state_dict)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def get_weights(self) -> StateDict:
"""
Returns the model's parameters and buffers, including its weights.
Returns:
StateDict
Notes:
If a latency factor is specified, this function sleeps for said amount before returning.
"""
return self._model.state_dict()
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def get_num_samples(self) -> int:
"""
Returns the number of samples in the client's training data.
Returns:
int: The length of the training data set.
Notes:
If a latency factor is specified, this function sleeps for said amount before returning.
"""
return len(self._train_data.dataset)
@check_device_set()
    def train_single_batch(
        self, batch_index: int
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Runs the client-side forward pass for the batch with the given index and returns the smashed data (the
        activations of the client's partial model) together with the batch labels. The activations and timings
        are cached so that the backward pass can later be finished via `backward_single_batch`.
        """
        if self._device.type == "cuda":
            torch.cuda.set_device(self._device)
# We have to re-initialize the data loader in the case that we do another epoch.
if batch_index == 0:
self._batchable_data_loader = iter(self._train_data)
# Used to measure training time. The problem we have with parallel split learning is that forward- and backward-
# passes are orchestrated by the current server.
# Thus, we need to cache the time required for the forward pass to ensure that we collect the right execution
# time.
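        # A rough sketch of the sequence an orchestrating server is expected to drive for one parallel split
        # learning step (the server-side gradient computation and the cross-client averaging are assumptions,
        # not part of this class):
        #
        #   smashed_data, labels = client.train_single_batch(batch_index)    # forward pass, timings cached
        #   gradients = <server computes gradients w.r.t. smashed_data>      # assumed server-side step
        #   metrics, client_grads = client.backward_single_batch(gradients)  # backward pass on cached activations
        #   client.set_gradient_and_finalize_training(averaged_grads)        # optimizer step, cache cleared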
start_time = time.time()
self._model.train()
# We need to get the number of batches that the DataLoader can provide us with to properly index and retrieve
# the correct batch.
#
# TODO: is there another way to do this? gRPC streaming does not work here, since we have to keep the streams
# alive while doing other RPC calls like settings weights, sending/averaging gradient data, ...
num_batches = self.get_approximated_num_batches()
        # Safety check to ensure that we train same-sized batches only.
        assert 0 <= batch_index < num_batches
batch_data, batch_labels = next(self._batchable_data_loader)
# Updates the battery capacity by simulating the required energy consumption for conducting the training step.
self.node_device.battery.update_flops(self._model_flops * len(batch_data))
        # We train the model using the single batch and return the activations and labels. These get sent over to
        # the server to be processed further.
with LatencySimulator(latency_factor=self.latency_factor):
batch_data_to = batch_data.to(self._device)
self._optimizer.zero_grad()
smashed_data = self._model(batch_data_to)
end_time = time.time()
self._psl_cache = {
"batch_data": batch_data,
"smashed_data": smashed_data,
"start_time": start_time,
"end_time": end_time,
}
return smashed_data, batch_labels
@check_device_set()
    def backward_single_batch(
        self, gradients
    ) -> Tuple[DiagnosticMetricResultContainer, list[torch.Tensor]]:
        """
        Runs the client-side backward pass on the cached activations using the gradients received from the server.
        Returns timing metrics and the client model's parameter gradients; the optimizer step itself is deferred
        to `set_gradient_and_finalize_training`.
        """
        if self._device.type == "cuda":
            torch.cuda.set_device(self._device)
batch_data, smashed_data, start_time, end_time = (
self._psl_cache["batch_data"],
self._psl_cache["smashed_data"],
self._psl_cache["start_time"],
self._psl_cache["end_time"],
)
start_time_2 = time.time()
self.node_device.battery.update_flops(
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
gradients = gradients.to(self._device)
smashed_data.backward(gradients)
# self._optimizer.step()
# We need to store a reference to the smashed_data to make it possible to finalize the training step.
self._psl_cache["smashed_data"] = smashed_data
end_time_2 = time.time()
metric = DiagnosticMetricResult(
device_id=self.node_device.device_id,
name="comp_time",
value=end_time - start_time + (end_time_2 - start_time_2),
method="client_train_batch_time",
)
metrics_container = DiagnosticMetricResultContainer([metric])
gradients = []
for param in self._model.parameters():
if param.grad is not None:
gradients.append(param.grad)
else:
gradients.append(torch.zeros_like(param))
return metrics_container, gradients
    def get_approximated_num_batches(self) -> int:
        """Returns the number of batches provided by the training data loader."""
        return len(self._train_data)
@check_device_set()
def train_epoch(
self, server_device_id: str, round_no: int = -1
) -> Tuple[StateDict, DiagnosticMetricResultContainer]:
"""
Trains the model on the client's data for one epoch, returning the new weights and training metrics.
The server model is run on the device with the given id.
Args:
server_device_id (str): The id of the device on which the server model is run.
round_no (int, optional): The current epoch number. Required when using a learning rate scheduler.
Returns:
StateDict: The updated weights of the client's model.
DiagnosticMetricResultContainer: The diagnostic metrics collected when training on the server and the actual
client model execution time.
        Notes:
            If configured, runtime latency is simulated for neural network operations.
            For optimizing the server device selection, the training time of the client model is needed. Therefore,
            the execution time (excluding the time the server needs to process the batches) is measured and added
            as a diagnostic metric.
            Typical designs measure the execution time at the device level, including the batch processing time.
            In contrast, this approach does not require deducing the server's batch processing time from such a
            "traditional" measurement afterwards.
"""
client_train_start_time = time.time()
server_train_batch_times = (
[]
) # collects the time for the server to process the batches
self._model.train()
diagnostic_metric_container = DiagnosticMetricResultContainer()
for idx, (batch_data, batch_labels) in enumerate(self._train_data):
self.node_device.battery.update_flops(self._model_flops * len(batch_data))
with LatencySimulator(latency_factor=self.latency_factor):
batch_data = batch_data.to(self._device)
batch_labels = batch_labels.to(self._device)
self._optimizer.zero_grad()
smashed_data = self._model(batch_data)
# measure the time for the server to process the batch
start_time = time.time()
train_batch_response = self.node_device.train_batch_on(
server_device_id, smashed_data, batch_labels
)
server_train_batch_times.append(time.time() - start_time)
with LatencySimulator(latency_factor=self.latency_factor):
if (
train_batch_response is False or train_batch_response is None
): # server device unavailable
break
server_grad, _server_loss, diagnostic_metrics = train_batch_response
diagnostic_metric_container.merge(diagnostic_metrics)
self.node_device.battery.update_flops(
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
server_grad = server_grad.to(self._device)
smashed_data.backward(server_grad)
self._optimizer.step()
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no)
else:
self._lr_scheduler.step()
client_train_time = (
time.time() - client_train_start_time - sum(server_train_batch_times)
)
diagnostic_metric_container.add_result(
DiagnosticMetricResult(
device_id=self.node_device.device_id,
name="comp_time",
value=client_train_time,
method="client_train_epoch_time",
)
)
return self._model.state_dict(), diagnostic_metric_container
@check_device_set()
def evaluate(self, server_device: str, val=True) -> DiagnosticMetricResultContainer:
"""
Evaluates the model on the client's data.
Args:
server_device (str): The server device to run the server model on.
val (bool, optional): If `True`, uses the validation data, otherwise the test data.
Set to `True` by default.
        Returns:
            DiagnosticMetricResultContainer: The diagnostic metrics collected when evaluating on the server and the
            actual client model execution time.
        Notes:
            If configured, runtime latency is simulated for neural network operations.
            For optimizing the server device selection, the execution time of the client model is needed. Therefore,
            the execution time (excluding the time the server needs to process the batches) is measured and added
            as a diagnostic metric.
            Typical designs measure the execution time at the device level, including the batch processing time.
            In contrast, this approach does not require deducing the server's batch processing time from such a
            "traditional" measurement afterwards.
"""
client_eval_start_time = time.time()
server_eval_batch_times = (
[]
) # collects the time for the server to process the batches
self._model.eval()
diagnostic_metric_results = DiagnosticMetricResultContainer()
with torch.no_grad():
dataloader = self._val_data if val else self._test_data
for b, (batch_data, batch_labels) in enumerate(dataloader):
with LatencySimulator(latency_factor=self.latency_factor):
self.node_device.battery.update_flops(
self._model_flops * len(batch_data)
)
batch_data = batch_data.to(self._device)
batch_labels = batch_labels.to(self._device)
# measure the time for the server to process the batch
start_time = time.time()
diagnostic_metrics = self.node_device.evaluate_batch_on(
server_device, self._model(batch_data), batch_labels
)
server_eval_batch_times.append(time.time() - start_time)
diagnostic_metric_results.merge(diagnostic_metrics)
client_eval_time = (
time.time() - client_eval_start_time - sum(server_eval_batch_times)
)
diagnostic_metric_results.add_result(
DiagnosticMetricResult(
device_id=self.node_device.device_id,
name="comp_time",
value=client_eval_time,
method="client_eval_epoch_time",
)
)
return diagnostic_metric_results
    def set_gradient_and_finalize_training(self, gradients: Any):
        """Sets the given gradients on the client model's parameters, performs the optimizer step and clears the
        parallel split learning cache."""
for param, grad in zip(self._model.parameters(), gradients):
param.grad = grad.to(self._device)
self._optimizer.step()
self._psl_cache = None