Skip to content
Snippets Groups Projects
Commit be7a70f9 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Added adaptive threshold for SwarmSL. Fixed setting the optimizer state for SwarmSL and PSL

parent 4824b08d
No related branches found
No related tags found
1 merge request: !24 Autoencoder, ATM and global optimizer
......@@ -3,3 +3,4 @@ _target_: edml.controllers.swarm_controller.SwarmController
_partial_: true
defaults:
- scheduler: sequential
- adaptive_threshold_fn: !!null
from typing import Any
from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
from edml.controllers.adaptive_threshold_mechanism.static import (
StaticAdaptiveThresholdFn,
)
from edml.controllers.base_controller import BaseController
from edml.controllers.scheduler.base import NextServerScheduler
from edml.helpers.config_helpers import get_device_index_by_id
......@@ -7,10 +11,16 @@ from edml.helpers.config_helpers import get_device_index_by_id
class SwarmController(BaseController):
def __init__(self, cfg, scheduler: NextServerScheduler):
def __init__(
    self,
    cfg,
    scheduler: NextServerScheduler,
    adaptive_threshold_fn: AdaptiveThresholdFn = None,
):
    """Set up the swarm controller.

    Args:
        cfg: Configuration object forwarded to ``BaseController``.
        scheduler: Scheduler that is initialized with this controller and
            kept as the next-server scheduler.
        adaptive_threshold_fn: Object whose ``invoke(round_no)`` result is
            sent as the adaptive learning threshold of each training round.
            When omitted or ``None``, a ``StaticAdaptiveThresholdFn(0.0)``
            is used.
    """
    super().__init__(cfg)
    scheduler.initialize(self)
    self._next_server_scheduler = scheduler
    # Build the fallback per instance instead of in the signature: a call in
    # a default argument is evaluated once at definition time and would be
    # shared by every controller. Treating None as "use the default" also
    # keeps a null config value from surfacing later as an AttributeError.
    if adaptive_threshold_fn is None:
        adaptive_threshold_fn = StaticAdaptiveThresholdFn(0.0)
    self._adaptive_threshold_fn = adaptive_threshold_fn
def _train(self):
client_weights = None
......@@ -87,10 +97,13 @@ class SwarmController(BaseController):
device_id=server_device_id, state_dict=server_weights, on_client=False
)
adaptive_threshold = self._adaptive_threshold_fn.invoke(round_no)
self.logger.log({"adaptive-threshold": adaptive_threshold})
training_response = self.request_dispatcher.train_global_on(
server_device_id,
epochs=1,
round_no=round_no,
adaptive_learning_threshold=adaptive_threshold,
optimizer_state=optimizer_state,
)
......
......@@ -303,15 +303,16 @@ class DeviceClient:
break
server_grad, _server_loss, diagnostic_metrics = train_batch_response
diagnostic_metric_container.merge(diagnostic_metrics)
self.node_device.battery.update_flops(
self._model_flops["BW"] * len(batch_data)
)
server_grad = server_grad.to(self._device)
if has_autoencoder(self._model):
self._model.trainable_layers_output.backward(server_grad)
else:
smashed_data.backward(server_grad)
self._optimizer.step()
if server_grad is not None: # otherwise threshold was applied
self.node_device.battery.update_flops(
self._model_flops["BW"] * len(batch_data)
)
server_grad = server_grad.to(self._device)
if has_autoencoder(self._model):
self._model.trainable_layers_output.backward(server_grad)
else:
smashed_data.backward(server_grad)
self._optimizer.step()
client_train_time = (
time.time() - client_train_start_time - sum(server_train_batch_times)
......
......@@ -303,7 +303,11 @@ class NetworkDevice(Device):
@update_battery
@log_execution_time("logger", "train_global_time")
def train_global(
self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None
self,
epochs: int,
round_no: int = -1,
adaptive_learning_threshold: Optional[float] = None,
optimizer_state: dict[str, Any] = None,
) -> Tuple[
Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
]:
......@@ -311,6 +315,7 @@ class NetworkDevice(Device):
devices=self.__get_device_ids__(),
epochs=epochs,
round_no=round_no,
adaptive_learning_threshold=adaptive_learning_threshold,
optimizer_state=optimizer_state,
)
......@@ -448,7 +453,12 @@ class RPCDeviceServicer(DeviceServicer):
def TrainGlobal(self, request, context):
print(f"Called TrainGlobal on device {self.device.device_id}")
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
self.device.train_global(request.epochs, request.round_no)
self.device.train_global(
request.epochs,
request.round_no,
request.adaptive_learning_threshold,
proto_to_state_dict(request.optimizer_state),
)
)
response = connection_pb2.TrainGlobalResponse(
client_weights=Weights(weights=state_dict_to_proto(client_weights)),
......@@ -568,12 +578,14 @@ class RPCDeviceServicer(DeviceServicer):
clients = self.device.__get_device_ids__()
round_no = request.round_no
adaptive_learning_threshold = request.adaptive_learning_threshold
optimizer_state = proto_to_state_dict(request.optimizer_state)
cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
self.device.train_parallel_split_learning(
clients=clients,
round_no=round_no,
adaptive_learning_threshold=adaptive_learning_threshold,
optimizer_state=optimizer_state,
)
)
response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
......@@ -761,6 +773,7 @@ class DeviceRequestDispatcher:
device_id: str,
epochs: int,
round_no: int = -1,
adaptive_learning_threshold: Optional[float] = None,
optimizer_state: dict[str, Any] = None,
) -> Union[
Tuple[
......@@ -777,6 +790,7 @@ class DeviceRequestDispatcher:
connection_pb2.TrainGlobalRequest(
epochs=epochs,
round_no=round_no,
adaptive_learning_threshold=adaptive_learning_threshold,
optimizer_state=state_dict_to_proto(optimizer_state),
)
)
......
......@@ -51,6 +51,7 @@ class DeviceServer:
self._cfg = cfg
self.node_device: Optional[Device] = None
self.latency_factor = latency_factor
self.adaptive_learning_threshold = None
def set_device(self, node_device: Device):
"""Sets the device reference for the server."""
......@@ -73,6 +74,7 @@ class DeviceServer:
devices: List[str],
epochs: int = 1,
round_no: int = -1,
adaptive_learning_threshold: Optional[float] = None,
optimizer_state: dict[str, Any] = None,
) -> Tuple[
Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
......@@ -83,13 +85,19 @@ class DeviceServer:
devices: The devices to train on
epochs: Optionally, the number of epochs to train.
round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
adaptive_learning_threshold: Optionally, the loss threshold below which the gradients are not sent back to the client
optimizer_state: Optionally, the optimizer_state to proceed from
"""
client_weights = None
metrics = ModelMetricResultContainer()
diagnostic_metric_container = DiagnosticMetricResultContainer()
if optimizer_state is not None:
print(f"apply optimizer state {optimizer_state}")
self._optimizer.load_state_dict(optimizer_state)
else:
print("optimizer state is None")
if adaptive_learning_threshold is not None:
self.adaptive_learning_threshold = adaptive_learning_threshold
for epoch in range(epochs):
if self._lr_scheduler is not None:
if round_no != -1:
......@@ -98,7 +106,7 @@ class DeviceServer:
self._lr_scheduler.step()
for device_id in devices:
print(
f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id} and threshold {self.adaptive_learning_threshold}"
)
if client_weights is not None:
self.node_device.set_weights_on(
......@@ -135,7 +143,7 @@ class DeviceServer:
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
def train_batch(self, smashed_data, labels) -> Tuple[Optional[Variable], float]:
"""Train the model on the given batch of data and labels.
Returns the gradients for the client-side backward pass (None if the loss is below the adaptive learning threshold) and the loss value."""
smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
......@@ -161,12 +169,20 @@ class DeviceServer:
# Capturing training metrics for the current batch.
self.node_device.log({"loss": loss_train.item()})
self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
if has_autoencoder(self._model):
return self._model.trainable_layers_input.grad, loss_train.item()
return (
smashed_data.grad,
loss_train.item(),
) # with the autoencoder, the server model's gradients before the decoder should be returned here
gradients = self._model.trainable_layers_input.grad
else:
gradients = smashed_data.grad
if (
self.adaptive_learning_threshold
and loss_train.item() < self.adaptive_learning_threshold
):
self.node_device.log(
{"adaptive_learning_threshold_applied": gradients.size(0)}
)
return None, loss_train.item()
return gradients, loss_train.item()
def _set_model_flops(self, sample):
"""Helper to determine the model flops when smashed data are available for the first time."""
......@@ -261,7 +277,8 @@ class DeviceServer:
self._lr_scheduler.step(round_no + 1) # epoch=1
else:
self._lr_scheduler.step()
if adaptive_learning_threshold is not None:
self.adaptive_learning_threshold = adaptive_learning_threshold
num_threads = len(clients)
executor = create_executor_with_threads(num_threads)
......@@ -311,25 +328,9 @@ class DeviceServer:
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
# We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
# do the client propagation once the current loss value is larger than the threshold.
print(
f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
)
if (
adaptive_learning_threshold
and server_loss < adaptive_learning_threshold
):
print(
f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
)
self.node_device.log(
{
"adaptive_learning_threshold_applied": server_gradients.size(
0
)
}
)
server_gradients is None
): # loss threshold was reached, skip client backprop
continue
num_client_gradients = len(client_forward_pass_responses)
......@@ -409,25 +410,9 @@ class DeviceServer:
server_gradients, server_loss, server_metrics = (
self.node_device.train_batch(server_batch, server_labels)
) # DiagnosticMetricResultContainer
# We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
# do the client propagation once the current loss value is larger than the threshold.
print(
f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
)
if (
adaptive_learning_threshold
and server_loss < adaptive_learning_threshold
):
print(
f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
)
self.node_device.log(
{
"adaptive_learning_threshold_applied": server_gradients.size(
0
)
}
)
server_gradients is None
): # loss threshold was reached, skip client backprop
continue
num_client_gradients = len(client_forward_pass_responses)
......
This diff is collapsed.
......@@ -72,14 +72,16 @@ class TrainGlobalParallelSplitLearningResponse(_message.Message):
def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
class TrainGlobalRequest(_message.Message):
__slots__ = ["epochs", "round_no", "optimizer_state"]
__slots__ = ["epochs", "round_no", "adaptive_learning_threshold", "optimizer_state"]
EPOCHS_FIELD_NUMBER: _ClassVar[int]
ROUND_NO_FIELD_NUMBER: _ClassVar[int]
ADAPTIVE_LEARNING_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
epochs: int
round_no: int
adaptive_learning_threshold: float
optimizer_state: _datastructures_pb2.StateDict
def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
class TrainGlobalResponse(_message.Message):
__slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
......
......@@ -66,7 +66,8 @@ message TrainGlobalParallelSplitLearningResponse {
message TrainGlobalRequest {
int32 epochs = 1;
optional int32 round_no = 2;
optional StateDict optimizer_state = 3;
optional double adaptive_learning_threshold = 3;
optional StateDict optimizer_state = 4;
}
......
......@@ -37,7 +37,9 @@ class SwarmControllerTest(unittest.TestCase):
)
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
self.swarm_controller._swarm_train_round(None, None, "d1", 0)
self.swarm_controller._swarm_train_round(
None, None, "d1", 0, optimizer_state={"optimizer_state": 43}
)
)
self.assertEqual(client_weights, {"weights": 42})
......@@ -52,7 +54,11 @@ class SwarmControllerTest(unittest.TestCase):
]
)
self.mock.train_global_on.assert_called_once_with(
"d1", epochs=1, round_no=0, optimizer_state=None
"d1",
epochs=1,
round_no=0,
adaptive_learning_threshold=0.0,
optimizer_state={"optimizer_state": 43},
)
def test_split_train_round_with_inactive_server_device(self):
......@@ -74,7 +80,11 @@ class SwarmControllerTest(unittest.TestCase):
]
)
self.mock.train_global_on.assert_called_once_with(
"d1", epochs=1, round_no=0, optimizer_state=None
"d1",
epochs=1,
round_no=0,
adaptive_learning_threshold=0.0,
optimizer_state=None,
)
......
......@@ -131,7 +131,12 @@ class RPCDeviceServicerTest(unittest.TestCase):
{"optimizer_state": 44},
self.diagnostic_metrics,
)
request = connection_pb2.TrainGlobalRequest(epochs=42)
request = connection_pb2.TrainGlobalRequest(
epochs=42,
round_no=1,
adaptive_learning_threshold=3,
optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
)
response, metadata, code, details = self.make_call("TrainGlobal", request)
......@@ -147,7 +152,9 @@ class RPCDeviceServicerTest(unittest.TestCase):
proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
)
self.assertEqual(code, StatusCode.OK)
self.mock_device.train_global.assert_called_once_with(42, 0)
self.mock_device.train_global.assert_called_once_with(
42, 1, 3, {"optimizer_state": 42}
)
self.assertEqual(
proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
)
......@@ -357,7 +364,9 @@ class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
self.mock_device.train_global.side_effect = BatteryEmptyException(
"Battery empty"
)
request = connection_pb2.TrainGlobalRequest()
request = connection_pb2.TrainGlobalRequest(
optimizer_state=state_dict_to_proto(None)
)
self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal")
def test_stop_at_set_weights(self):
......@@ -504,7 +513,7 @@ class RequestDispatcherTest(unittest.TestCase):
)
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
self.dispatcher.train_global_on("1", 42, 43)
self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44})
)
self.assertEqual(client_weights, self.weights)
......@@ -515,19 +524,31 @@ class RequestDispatcherTest(unittest.TestCase):
self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
self.mock_stub.TrainGlobal.assert_called_once_with(
connection_pb2.TrainGlobalRequest(
epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
epochs=42,
round_no=43,
adaptive_learning_threshold=3,
optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
)
)
def test_train_global_on_with_error(self):
self.mock_stub.TrainGlobal.side_effect = grpc.RpcError()
response = self.dispatcher.train_global_on("1", 42, round_no=43)
response = self.dispatcher.train_global_on(
"1",
42,
round_no=43,
adaptive_learning_threshold=3,
optimizer_state={"optimizer_state": 44},
)
self.assertEqual(response, False)
self.mock_stub.TrainGlobal.assert_called_once_with(
connection_pb2.TrainGlobalRequest(
epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
epochs=42,
round_no=43,
adaptive_learning_threshold=3,
optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
)
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment