Skip to content
Snippets Groups Projects
Commit 0e8c7aa6 authored by Sven Michael Lechner's avatar Sven Michael Lechner
Browse files

Merge branch 'psl-performance' into 'main'

PSL performance

See merge request !19
parents dcc9296f 8552a410
No related branches found
No related tags found
1 merge request!19PSL performance
......@@ -220,7 +220,7 @@ class DeviceClient:
gradients = []
for param in self._model.parameters():
if param is not None:
if param.grad is not None:
gradients.append(param.grad)
else:
gradients.append(torch.zeros_like(param))
......
......@@ -575,7 +575,6 @@ class RPCDeviceServicer(DeviceServicer):
def TrainSingleBatchOnClient(self, request, context):
batch_index = request.batch_index
print(f"Starting single batch@{batch_index}")
smashed_data, labels = self.device.client.train_single_batch(batch_index)
......
from __future__ import annotations
import concurrent.futures
import time
from typing import List, Optional, Tuple, Any, TYPE_CHECKING
import torch
from omegaconf import DictConfig
from colorama import init, Fore
from colorama import Fore
from torch import nn
from torch.autograd import Variable
......@@ -20,7 +19,7 @@ from edml.helpers.metrics import (
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn
from edml.helpers.types import StateDict, LossFn
if TYPE_CHECKING:
from edml.core.device import Device
......@@ -251,10 +250,9 @@ class DeviceServer:
# We iterate over each batch, initializing all client training at once and processing the results afterward.
num_batches = self.node_device.client.get_approximated_num_batches()
print(f"\n\n:: BATCHES :: {num_batches}\n\n")
print(f":: BATCHES :: {num_batches}")
for batch_index in range(num_batches):
client_forward_pass_responses = []
futures = [
executor.submit(client_training_job, client_id, batch_index)
for client_id in clients
......@@ -268,20 +266,17 @@ class DeviceServer:
client_ids = [b[0] for b in client_forward_pass_responses]
client_batches = [b[1] for b in client_forward_pass_responses]
print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
server_batch = _concat_smashed_data(
[b[0].to(self._device) for b in client_batches]
)
server_labels = _concat_smashed_data(
[b[1].to(self._device) for b in client_batches]
)
# Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
# to split the gradients back into batch-sized tensors to average them before sending them to the client.
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
# We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
# do the client propagation once the current loss value is larger than the threshold.
print(
......@@ -301,25 +296,23 @@ class DeviceServer:
print(
f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
)
client_gradients = torch.chunk(server_gradients, num_client_gradients)
# clone single client gradients so that client_gradients is not a list of views of server_gradients
# if we just use torch.chunk, each client will receive the whole server_gradients
client_gradients = [
t.clone().detach()
for t in torch.chunk(server_gradients, num_client_gradients)
]
futures = [
executor.submit(
client_backpropagation_job, client_id, client_gradients[idx]
)
for (idx, client_id) in enumerate(client_ids)
]
client_backpropagation_results = []
client_backpropagation_gradients = []
for future in concurrent.futures.as_completed(futures):
client_backpropagation_results.append(future.result())
client_backpropagation_gradients = [
result[1]
for result in client_backpropagation_results
if result is not None and result is not False
]
_, grads = future.result()
if grads is not None and grads is not False:
client_backpropagation_gradients.append(grads)
# We want to average the client's backpropagation gradients and send them over again to finalize the
# current training step.
averaged_gradient = _calculate_gradient_mean(
......@@ -340,7 +333,6 @@ class DeviceServer:
for client_id in clients:
train_metrics = self.finalize_metrics(str(client_id), "train")
print(f"::: evaluating on {client_id}")
evaluation_diagnostics_metrics = self.node_device.evaluate_on(
device_id=client_id,
server_device=self.node_device.device_id,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment