diff --git a/edml/core/client.py b/edml/core/client.py index fe41e2cdb3c0be6847b37df8ae3c7cae8c4eba7d..b3c3e7491e718cf98c9ccde598e00fd4d2054996 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -220,7 +220,7 @@ class DeviceClient: gradients = [] for param in self._model.parameters(): - if param is not None: + if param.grad is not None: gradients.append(param.grad) else: gradients.append(torch.zeros_like(param)) diff --git a/edml/core/device.py b/edml/core/device.py index 51278532bfb7f1888f13671a24a289fc10e868b2..1348bea7e4a4fadb44a6bd304bfc466f44bbba41 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -575,7 +575,6 @@ class RPCDeviceServicer(DeviceServicer): def TrainSingleBatchOnClient(self, request, context): batch_index = request.batch_index - print(f"Starting single batch@{batch_index}") smashed_data, labels = self.device.client.train_single_batch(batch_index) diff --git a/edml/core/server.py b/edml/core/server.py index 8853bd92bc3d5dc438e7ada06cb4dce5e62967c5..5e2f8235de476c5a998f9c8c040b462ac9d3b198 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -1,12 +1,11 @@ from __future__ import annotations import concurrent.futures -import time from typing import List, Optional, Tuple, Any, TYPE_CHECKING import torch from omegaconf import DictConfig -from colorama import init, Fore +from colorama import Fore from torch import nn from torch.autograd import Variable @@ -20,7 +19,7 @@ from edml.helpers.metrics import ( ModelMetricResultContainer, DiagnosticMetricResultContainer, ) -from edml.helpers.types import StateDict, SLTrainBatchResult, LossFn +from edml.helpers.types import StateDict, LossFn if TYPE_CHECKING: from edml.core.device import Device @@ -251,10 +250,9 @@ class DeviceServer: # We iterate over each batch, initializing all client training at once and processing the results afterward. num_batches = self.node_device.client.get_approximated_num_batches() - print(f"\n\n:: BATCHES :: {num_batches}\n\n") + print(f":: BATCHES :: {num_batches}") for batch_index in range(num_batches): client_forward_pass_responses = [] - futures = [ executor.submit(client_training_job, client_id, batch_index) for client_id in clients @@ -268,20 +266,17 @@ class DeviceServer: client_ids = [b[0] for b in client_forward_pass_responses] client_batches = [b[1] for b in client_forward_pass_responses] - print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n") server_batch = _concat_smashed_data( [b[0].to(self._device) for b in client_batches] ) server_labels = _concat_smashed_data( [b[1].to(self._device) for b in client_batches] ) - # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need # to split the gradients back into batch-sized tensors to average them before sending them to the client. server_gradients, server_loss, server_metrics = ( self.node_device.train_batch(server_batch, server_labels) ) # DiagnosticMetricResultContainer - # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only # do the client propagation once the current loss value is larger than the threshold. print( @@ -301,25 +296,23 @@ class DeviceServer: print( f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}" ) - - client_gradients = torch.chunk(server_gradients, num_client_gradients) - + # clone single client gradients so that client_gradients is not a list of views of server_gradients + # if we just use torch.chunk, each client will receive the whole server_gradients + client_gradients = [ + t.clone().detach() + for t in torch.chunk(server_gradients, num_client_gradients) + ] futures = [ executor.submit( client_backpropagation_job, client_id, client_gradients[idx] ) for (idx, client_id) in enumerate(client_ids) ] - client_backpropagation_results = [] + client_backpropagation_gradients = [] for future in concurrent.futures.as_completed(futures): - client_backpropagation_results.append(future.result()) - - client_backpropagation_gradients = [ - result[1] - for result in client_backpropagation_results - if result is not None and result is not False - ] - + _, grads = future.result() + if grads is not None and grads is not False: + client_backpropagation_gradients.append(grads) # We want to average the client's backpropagation gradients and send them over again to finalize the # current training step. averaged_gradient = _calculate_gradient_mean( @@ -340,7 +333,6 @@ class DeviceServer: for client_id in clients: train_metrics = self.finalize_metrics(str(client_id), "train") - print(f"::: evaluating on {client_id}") evaluation_diagnostics_metrics = self.node_device.evaluate_on( device_id=client_id, server_device=self.node_device.device_id,