Commit 0e8c7aa6 authored by Sven Michael Lechner

Merge branch 'psl-performance' into 'main'

PSL performance

See merge request !19
parents dcc9296f 8552a410
@@ -220,7 +220,7 @@ class DeviceClient:
         gradients = []
         for param in self._model.parameters():
-            if param is not None:
+            if param.grad is not None:
                 gradients.append(param.grad)
             else:
                 gradients.append(torch.zeros_like(param))
...
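The old check could never fail: an nn.Parameter is never None, but its .grad attribute is None for any parameter that did not take part in the preceding backward pass, so the old branch could append None to the gradient list. A minimal standalone sketch of the corrected pattern, with a toy nn.Linear standing in for the client's model part:

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for self._model

gradients = []
for param in model.parameters():
    if param.grad is not None:  # grad is None before the first backward pass or for unused params
        gradients.append(param.grad)
    else:
        gradients.append(torch.zeros_like(param))  # zero gradient of matching shape

# Every entry is now a real tensor, never None.
assert all(g.shape == p.shape for g, p in zip(gradients, model.parameters()))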
@@ -575,7 +575,6 @@ class RPCDeviceServicer(DeviceServicer):
     def TrainSingleBatchOnClient(self, request, context):
         batch_index = request.batch_index
-        print(f"Starting single batch@{batch_index}")
         smashed_data, labels = self.device.client.train_single_batch(batch_index)
...
 from __future__ import annotations
 import concurrent.futures
-import time
 from typing import List, Optional, Tuple, Any, TYPE_CHECKING
 import torch
 from omegaconf import DictConfig
-from colorama import init, Fore
+from colorama import Fore
 from torch import nn
 from torch.autograd import Variable
@@ -20,7 +19,7 @@ from edml.helpers.metrics import (
     ModelMetricResultContainer,
     DiagnosticMetricResultContainer,
 )
-from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn
+from edml.helpers.types import StateDict, LossFn
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -251,10 +250,9 @@ class DeviceServer:
             # We iterate over each batch, initializing all client training at once and processing the results afterward.
             num_batches = self.node_device.client.get_approximated_num_batches()
-            print(f"\n\n:: BATCHES :: {num_batches}\n\n")
+            print(f":: BATCHES :: {num_batches}")
             for batch_index in range(num_batches):
                 client_forward_pass_responses = []
                 futures = [
                     executor.submit(client_training_job, client_id, batch_index)
                     for client_id in clients
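This fan-out/fan-in pattern (submit one job per client, then collect results as they complete) is plain concurrent.futures. A minimal sketch with a hypothetical client_training_job standing in for the real RPC call:

import concurrent.futures

def client_training_job(client_id, batch_index):
    # Hypothetical stand-in: the real job asks the remote client for a forward pass.
    return client_id, f"smashed-data@{batch_index}"

clients = ["client-0", "client-1", "client-2"]
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(client_training_job, cid, 0) for cid in clients]
    responses = [f.result() for f in concurrent.futures.as_completed(futures)]

# as_completed yields futures in completion order, not submission order, which is
# why each response carries its client_id instead of relying on list position.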
@@ -268,20 +266,17 @@ class DeviceServer:
                 client_ids = [b[0] for b in client_forward_pass_responses]
                 client_batches = [b[1] for b in client_forward_pass_responses]
-                print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
                 server_batch = _concat_smashed_data(
                     [b[0].to(self._device) for b in client_batches]
                 )
                 server_labels = _concat_smashed_data(
                     [b[1].to(self._device) for b in client_batches]
                 )
                 # Train the part on the server, then send the gradients to each client to continue the calculation.
                 # We need to split the gradients back into batch-sized tensors to average them before sending them
                 # to the clients.
                 server_gradients, server_loss, server_metrics = (
                     self.node_device.train_batch(server_batch, server_labels)
                 )  # DiagnosticMetricResultContainer
                 # We check whether the server should activate the adaptive learning threshold; if so, we only run
                 # the client propagation once the current loss value exceeds the threshold.
                 print(
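The hunk is truncated here, but the comment describes a gate: when the adaptive learning threshold is active, client backpropagation is skipped until the server loss exceeds it. A hypothetical sketch of such a gate; the names are illustrative, not the project's actual API:

from typing import Optional

def should_propagate_to_clients(server_loss: float, threshold: Optional[float]) -> bool:
    # Hypothetical helper: with the threshold disabled (None), always propagate;
    # otherwise propagate only once the current loss exceeds the threshold.
    if threshold is None:
        return True
    return server_loss > threshold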
@@ -301,25 +296,23 @@ class DeviceServer:
                 print(
                     f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                 )
-                client_gradients = torch.chunk(server_gradients, num_client_gradients)
+                # Clone each chunk so that client_gradients is not a list of views of server_gradients;
+                # if we just use torch.chunk, each client will receive the whole server_gradients.
+                client_gradients = [
+                    t.clone().detach()
+                    for t in torch.chunk(server_gradients, num_client_gradients)
+                ]
                 futures = [
                     executor.submit(
                         client_backpropagation_job, client_id, client_gradients[idx]
                     )
                     for (idx, client_id) in enumerate(client_ids)
                 ]
-                client_backpropagation_results = []
+                client_backpropagation_gradients = []
                 for future in concurrent.futures.as_completed(futures):
-                    client_backpropagation_results.append(future.result())
-                client_backpropagation_gradients = [
-                    result[1]
-                    for result in client_backpropagation_results
-                    if result is not None and result is not False
-                ]
+                    _, grads = future.result()
+                    if grads is not None and grads is not False:
+                        client_backpropagation_gradients.append(grads)
                 # We want to average the clients' backpropagation gradients and send them back to finalize the
                 # current training step.
                 averaged_gradient = _calculate_gradient_mean(
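The clone().detach() change above addresses a real pitfall: torch.chunk returns views that share storage with the full tensor, and serializing a view (for example with torch.save, or when pickling an RPC payload) drags the entire underlying storage along. A small sketch showing the difference:

import torch

server_gradients = torch.arange(12.0).reshape(6, 2)
views = torch.chunk(server_gradients, 3)      # three views into the same storage
copies = [t.clone().detach() for t in views]  # three independent tensors

# The first chunk starts at the same memory address as the full tensor;
# the cloned copy owns its own, chunk-sized storage.
assert views[0].data_ptr() == server_gradients.data_ptr()
assert copies[0].data_ptr() != server_gradients.data_ptr()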
@@ -340,7 +333,6 @@ class DeviceServer:
         for client_id in clients:
             train_metrics = self.finalize_metrics(str(client_id), "train")
-            print(f"::: evaluating on {client_id}")
             evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                 device_id=client_id,
                 server_device=self.node_device.device_id,
...
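_calculate_gradient_mean's implementation is not part of this diff. For intuition only, a plausible sketch, assuming each element of client_backpropagation_gradients is a list of per-parameter gradient tensors:

from typing import List
import torch

def gradient_mean(per_client_grads: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    # Hypothetical reading: average the k-th parameter's gradient across all clients.
    return [torch.stack(kth, dim=0).mean(dim=0) for kth in zip(*per_client_grads)]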