Commit 791e787e authored by Sven Michael Lechner

feat: add config for simulated parallelism

This is a quick and dirty fix to make it easier to run experiments.
A proper implementation should relocate the parallelism into the
corresponding dispatcher classes.

Closes #1
parent be776bcc
Merge request !23: Simulating parallel execution
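For readers skimming the diff: the new flag switches between running the per-device work sequentially (timing each call so the parallel runtime can be reconstructed afterwards) and running it in a real thread pool. The sketch below is a minimal, self-contained illustration of that pattern under assumed names (run_tasks and the toy task list are hypothetical, not part of the repository):

import concurrent.futures
import time


def run_tasks(tasks, simulate_parallelism: bool):
    """Run callables either sequentially (timing each one) or truly in parallel."""
    if simulate_parallelism:
        # Sequential execution: record each task's runtime so the "parallel"
        # runtime can later be reconstructed as max(individual times).
        individual_times, results = [], []
        for task in tasks:
            start = time.perf_counter()
            results.append(task())
            individual_times.append(time.perf_counter() - start)
        print(f"simulated parallel time: {max(individual_times):.3f}s")
        return results
    # Real parallelism via a thread pool; at least one worker avoids a ValueError.
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=max(len(tasks), 1)
    ) as executor:
        futures = [executor.submit(task) for task in tasks]
        return [f.result() for f in concurrent.futures.as_completed(futures)]


if __name__ == "__main__":
    jobs = [lambda d=d: time.sleep(d) or d for d in (0.2, 0.1, 0.3)]
    print(run_tasks(jobs, simulate_parallelism=True))
    print(run_tasks(jobs, simulate_parallelism=False))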
@@ -14,6 +14,9 @@ defaults:
  - wandb: default
  - _self_
# If true, controllers run devices sequentially and their runtime is corrected in post-processing
# to account for the parallelism. If false, controllers run devices in parallel.
simulate_parallelism: False
own_device_id: "d0"
num_devices: ${len:${topology.devices}}
...
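Since the surrounding config uses Hydra-style defaults and ${...} resolvers, the new flag can presumably also be overridden per run from the command line instead of editing the YAML (the entry-point name here is hypothetical):

python main.py simulate_parallelism=True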
import concurrent.futures
import contextlib
import functools
import threading
from typing import Dict, List

@@ -36,6 +38,14 @@ def fed_average(model_weights: list[Dict], weighting_scheme: List[float] = None)
    return None

# Placeholder decorator: wraps f, but the inner function does not yet call it.
def simulate_parallelism(f):
    @functools.wraps(f)
    def inner(self):
        pass

    return inner

class FedController(BaseController):
    """Controller for federated learning."""

@@ -48,13 +58,48 @@ class FedController(BaseController):
        server_weights = []
        samples_count = []
        metrics_container = ModelMetricResultContainer()

        if self.cfg.simulate_parallelism:
            parallel_times = []
            with Timer() as elapsed_time:
                for device_id in self.active_devices:
                    with Timer() as individual_time:
                        response = self.request_dispatcher.federated_train_on(
                            device_id, round_no
                        )
                    if response is not False:
                        (
                            new_client_weights,
                            new_server_weights,
                            num_samples,
                            metrics,
                            _,
                        ) = response  # skip diagnostic metrics
                        client_weights.append(new_client_weights)
                        server_weights.append(new_server_weights)
                        samples_count.append(num_samples)
                        metrics_container.merge(metrics)
                    parallel_times.append(individual_time.execution_time)
            self.logger.log(
                {
                    "parallel_fed_time": {
                        "elapsed_time": elapsed_time.execution_time,
                        "parallel_time": max(parallel_times),
                    }
                }
            )
        else:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=max(len(self.active_devices), 1)
            ) as executor:  # avoid exception when setting 0 workers
                futures = [
                    executor.submit(
                        self.request_dispatcher.federated_train_on, device_id, round_no
                    )
                    for device_id in self.active_devices
                ]
                for future in concurrent.futures.as_completed(futures):
                    response = future.result()
                    if response is not False:
                        (
                            new_client_weights,
@@ -67,17 +112,8 @@ class FedController(BaseController):
                        server_weights.append(new_server_weights)
                        samples_count.append(num_samples)
                        metrics_container.merge(metrics)

        print(f"samples count {samples_count}")
        return (
            fed_average(model_weights=client_weights, weighting_scheme=samples_count),
            fed_average(model_weights=server_weights, weighting_scheme=samples_count),
...
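To make the logged numbers concrete: when simulate_parallelism is enabled, elapsed_time is the measured wall-clock time of the sequential loop (roughly the sum of the per-device times plus overhead), while parallel_time = max(parallel_times) is the runtime an ideally parallel run would have needed. With illustrative per-device times:

parallel_times = [2.0, 3.5, 1.2]     # per-device runtimes collected by the sequential loop (made-up values)
elapsed_time = sum(parallel_times)   # ~6.7 s of wall-clock time for the sequential run
parallel_time = max(parallel_times)  # 3.5 s, the simulated parallel runtime used in post-processing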
@@ -260,136 +260,245 @@ class DeviceServer:
        # We iterate over each batch, initializing all client training at once and processing the results afterward.
        num_batches = self.node_device.client.get_approximated_num_batches()
        print(f":: BATCHES :: {num_batches}")
        if self._cfg.simulate_parallelism:
            for batch_index in range(num_batches):
                client_forward_pass_responses = []
                parallel_times = []
                with Timer() as elapsed_time:
                    for client_id in clients:
                        with Timer() as individual_time:
                            (client_id, result) = client_training_job(
                                client_id, batch_index
                            )
                        if result is not None and result is not False:
                            client_forward_pass_responses.append(
                                (client_id, result)
                            )
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_train_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
                # We want to split up the responses into a list of client IDs and batches again.
                client_ids = [b[0] for b in client_forward_pass_responses]
                client_batches = [b[1] for b in client_forward_pass_responses]

                server_batch = _concat_smashed_data(
                    [b[0].to(self._device) for b in client_batches]
                )
                server_labels = _concat_smashed_data(
                    [b[1].to(self._device) for b in client_batches]
                )
                # Train the part on the server. Then send the gradients to each client, continuing the calculation.
                # We need to split the gradients back into batch-sized tensors to average them before sending them
                # to the client.
                server_gradients, server_loss, server_metrics = (
                    self.node_device.train_batch(server_batch, server_labels)
                )  # DiagnosticMetricResultContainer
                # We check if the server should activate the adaptive learning threshold. If so, we only do the
                # client propagation once the current loss value is larger than the threshold.
                print(
                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
                )
                if (
                    adaptive_learning_threshold
                    and server_loss < adaptive_learning_threshold
                ):
                    print(
                        f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                    )
                    self.node_device.log(
                        {
                            "adaptive_learning_threshold_applied": server_gradients.size(
                                0
                            )
                        }
                    )
                    continue

                num_client_gradients = len(client_forward_pass_responses)
                print(
                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                )
                # Clone the single client gradients so that client_gradients is not a list of views of
                # server_gradients; if we just use torch.chunk, each client would receive the whole server_gradients.
                client_gradients = [
                    t.clone().detach()
                    for t in torch.chunk(server_gradients, num_client_gradients)
                ]
                client_backpropagation_gradients = []
                parallel_times = []
                with Timer() as elapsed_time:
                    for idx, client_id in enumerate(client_ids):
                        with Timer() as individual_time:
                            _, grads = client_backpropagation_job(
                                client_id, client_gradients[idx]
                            )
                        if grads is not None and grads is not False:
                            client_backpropagation_gradients.append(grads)
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_backprop_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
                # We want to average the clients' backpropagation gradients and send them over again to finalize
                # the current training step.
                averaged_gradient = _calculate_gradient_mean(
                    client_backpropagation_gradients, self._device
                )
                parallel_times = []
                with Timer() as elapsed_time:
                    for client_id in clients:
                        with Timer() as individual_time:
                            client_set_gradient_and_finalize_training_job(
                                client_id, averaged_gradient
                            )
                        parallel_times.append(individual_time.execution_time)
                self.node_device.log(
                    {
                        "parallel_client_model_update_time": {
                            "elapsed_time": elapsed_time.execution_time,
                            "parallel_time": max(parallel_times),
                        }
                    }
                )
        else:
            for batch_index in range(num_batches):
                client_forward_pass_responses = []
                futures = [
                    executor.submit(client_training_job, client_id, batch_index)
                    for client_id in clients
                ]
                for future in concurrent.futures.as_completed(futures):
                    (client_id, result) = future.result()
                    if result is not None and result is not False:
                        client_forward_pass_responses.append((client_id, result))
                # We want to split up the responses into a list of client IDs and batches again.
                client_ids = [b[0] for b in client_forward_pass_responses]
                client_batches = [b[1] for b in client_forward_pass_responses]

                server_batch = _concat_smashed_data(
                    [b[0].to(self._device) for b in client_batches]
                )
                server_labels = _concat_smashed_data(
                    [b[1].to(self._device) for b in client_batches]
                )
                # Train the part on the server. Then send the gradients to each client, continuing the calculation.
                # We need to split the gradients back into batch-sized tensors to average them before sending them
                # to the client.
                server_gradients, server_loss, server_metrics = (
                    self.node_device.train_batch(server_batch, server_labels)
                )  # DiagnosticMetricResultContainer
                # We check if the server should activate the adaptive learning threshold. If so, we only do the
                # client propagation once the current loss value is larger than the threshold.
                print(
                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
                )
                if (
                    adaptive_learning_threshold
                    and server_loss < adaptive_learning_threshold
                ):
                    print(
                        f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                    )
                    self.node_device.log(
                        {
                            "adaptive_learning_threshold_applied": server_gradients.size(
                                0
                            )
                        }
                    )
                    continue

                num_client_gradients = len(client_forward_pass_responses)
                print(
                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                )
                # Clone the single client gradients so that client_gradients is not a list of views of
                # server_gradients; if we just use torch.chunk, each client would receive the whole server_gradients.
                client_gradients = [
                    t.clone().detach()
                    for t in torch.chunk(server_gradients, num_client_gradients)
                ]
                futures = [
                    executor.submit(
                        client_backpropagation_job, client_id, client_gradients[idx]
                    )
                    for (idx, client_id) in enumerate(client_ids)
                ]
                client_backpropagation_gradients = []
                for future in concurrent.futures.as_completed(futures):
                    _, grads = future.result()
                    if grads is not None and grads is not False:
                        client_backpropagation_gradients.append(grads)
                # We want to average the clients' backpropagation gradients and send them over again to finalize
                # the current training step.
                averaged_gradient = _calculate_gradient_mean(
                    client_backpropagation_gradients, self._device
                )
                futures = [
                    executor.submit(
                        client_set_gradient_and_finalize_training_job,
                        client_id,
                        averaged_gradient,
                    )
                    for client_id in clients
                ]
                for future in concurrent.futures.as_completed(futures):
                    future.result()

        # Now we have to determine the model metrics for each client.
        if self._cfg.simulate_parallelism:
            parallel_times = []
            with Timer() as elapsed_time:
                for client_id in clients:
                    with Timer() as individual_time:
                        train_metrics = self.finalize_metrics(str(client_id), "train")

                        evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                            device_id=client_id,
                            server_device=self.node_device.device_id,
                            val=True,
                        )
                        # if evaluation_diagnostics_metrics:
                        #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
                        val_metrics = self.finalize_metrics(str(client_id), "val")
                        model_metrics.add_results(train_metrics)
                        model_metrics.add_results(val_metrics)
                    parallel_times.append(individual_time.execution_time)
            self.node_device.log(
                {
                    "parallel_client_eval_time": {
                        "elapsed_time": elapsed_time.execution_time,
                        "parallel_time": max(parallel_times),
                    }
                }
            )
        else:
            for client_id in clients:
                train_metrics = self.finalize_metrics(str(client_id), "train")

                evaluation_diagnostics_metrics = self.node_device.evaluate_on(
                    device_id=client_id,
                    server_device=self.node_device.device_id,
                    val=True,
                )
                # if evaluation_diagnostics_metrics:
                #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
                val_metrics = self.finalize_metrics(str(client_id), "val")
                model_metrics.add_results(train_metrics)
                model_metrics.add_results(val_metrics)

        optimizer_state = self._optimizer.state_dict()
        # delete references and free GPU memory manually
...
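The commit message notes that a proper implementation should move the parallelism into the dispatcher classes. One way the currently empty simulate_parallelism decorator could be used for that (purely a sketch under assumed attribute names, not part of this commit) is to time the wrapped dispatcher call and collect the per-call durations:

import functools
import time


def simulate_parallelism(f):
    """Sketch: time the wrapped method and collect per-call durations on the instance."""

    @functools.wraps(f)
    def inner(self, *args, **kwargs):
        start = time.perf_counter()
        result = f(self, *args, **kwargs)
        # Hypothetical attribute; a real dispatcher would decide where these timings live.
        self.parallel_times = getattr(self, "parallel_times", [])
        self.parallel_times.append(time.perf_counter() - start)
        return result

    return inner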