diff --git a/edml/config/experiment/cifar.yaml b/edml/config/experiment/cifar.yaml index d2c4755192042f945a6ac9e6aab388ab3be2dfe7..7e43040363e21fafd69686d7432cfd92d443e7ba 100644 --- a/edml/config/experiment/cifar.yaml +++ b/edml/config/experiment/cifar.yaml @@ -13,7 +13,6 @@ scheduler_gamma: 0.1 max_epochs: 1 max_rounds: 200 metrics: [ accuracy ] -load_weights: False save_weights: True server_model_load_path: "edml/models/weights/initial/Resnet18_Server_random_weights.pth" client_model_load_path: "edml/models/weights/initial/Resnet18_Client_random_weights.pth" @@ -24,6 +23,6 @@ fractions: !!null random_seed: 42 load_single_batch_for_debugging: False early_stopping: True -early_stopping_patience: 5 +early_stopping_patience: 200 early_stopping_metric: accuracy latency: !!null diff --git a/edml/config/model_provider/resnet110-with-autoencoder.yaml b/edml/config/model_provider/resnet110-with-autoencoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..318dc714baa29b050696b73cb4f94ca3c3079542 --- /dev/null +++ b/edml/config/model_provider/resnet110-with-autoencoder.yaml @@ -0,0 +1,22 @@ +_target_: edml.models.provider.autoencoder.AutoencoderModelProvider +model_provider: + # TODO: can this include other files next to it? + _target_: edml.models.provider.cut_layer.CutLayerModelProvider + model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 18, 18, 18 ] + num_classes: 100 + cut_layer: 4 +decoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Decoder + path: resnet_decoder.pth +encoder: + _target_: edml.models.provider.path.SerializedModel + model: + _target_: edml.models.partials.resnet.Encoder + path: resnet_encoder.pth diff --git a/edml/config/model_provider/resnet110.yaml b/edml/config/model_provider/resnet110.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a436b97dcbe01ccf3dc9ae1b50b4ebea084ad338 --- /dev/null +++ b/edml/config/model_provider/resnet110.yaml @@ -0,0 +1,9 @@ +_target_: edml.models.provider.cut_layer.CutLayerModelProvider +model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 18, 18, 18 ] + num_classes: 100 +cut_layer: 4 diff --git a/edml/config/model_provider/resnet20-with-autoencoder.yaml b/edml/config/model_provider/resnet20-with-autoencoder.yaml index da09bf1312842ba021bea2f34786b0ae04dcabd9..a56b77ed33320a14851f91edaa40286e8b9c32aa 100644 --- a/edml/config/model_provider/resnet20-with-autoencoder.yaml +++ b/edml/config/model_provider/resnet20-with-autoencoder.yaml @@ -1,18 +1,22 @@ _target_: edml.models.provider.autoencoder.AutoencoderModelProvider model_provider: # TODO: can this include other files next to it? - _target_: edml.models.provider.base.ModelProvider - client: - _target_: edml.models.mnist_models.ClientNet - server: - _target_: edml.models.mnist_models.ServerNet + _target_: edml.models.provider.cut_layer.CutLayerModelProvider + model: + _target_: edml.models.resnet_models.ResNet + block: + _target_: hydra.utils.get_class + path: edml.models.resnet_models.BasicBlock + num_blocks: [ 3, 3, 3 ] + num_classes: 100 + cut_layer: 4 decoder: _target_: edml.models.provider.path.SerializedModel model: - _target_: edml.models.partials.mnist.Decoder - path: decoder.pth + _target_: edml.models.partials.resnet.Decoder + path: resnet_decoder.pth encoder: _target_: edml.models.provider.path.SerializedModel model: - _target_: edml.models.partials.mnist.Encoder - path: encoder.pth + _target_: edml.models.partials.resnet.Encoder + path: resnet_encoder.pth diff --git a/edml/config/topology/100_devices.yaml b/edml/config/topology/100_devices.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f046bf89d624f602d5b2cf5bcfa4784d6aa1e99 --- /dev/null +++ b/edml/config/topology/100_devices.yaml @@ -0,0 +1,602 @@ +devices: [ + { + device_id: "d0", + address: "localhost:50051", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d1", + address: "localhost:50052", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d2", + address: "localhost:50053", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d3", + address: "localhost:50054", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d4", + address: "localhost:50055", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d5", + address: "localhost:50056", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d6", + address: "localhost:50057", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d7", + address: "localhost:50058", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d8", + address: "localhost:50059", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d9", + address: "localhost:50060", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d10", + address: "localhost:50061", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d11", + address: "localhost:50062", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d12", + address: "localhost:50063", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d13", + address: "localhost:50064", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d14", + address: "localhost:50065", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d15", + address: "localhost:50066", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d16", + address: "localhost:50067", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d17", + address: "localhost:50068", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d18", + address: "localhost:50069", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d19", + address: "localhost:50070", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d20", + address: "localhost:50071", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d21", + address: "localhost:50072", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d22", + address: "localhost:50073", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d23", + address: "localhost:50074", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d24", + address: "localhost:50075", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d25", + address: "localhost:50076", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d26", + address: "localhost:50077", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d27", + address: "localhost:50078", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d28", + address: "localhost:50079", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d29", + address: "localhost:50080", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d30", + address: "localhost:50081", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d31", + address: "localhost:50082", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d32", + address: "localhost:50083", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d33", + address: "localhost:50084", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d34", + address: "localhost:50085", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d35", + address: "localhost:50086", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d36", + address: "localhost:50087", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d37", + address: "localhost:50088", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d38", + address: "localhost:50089", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d39", + address: "localhost:50090", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d40", + address: "localhost:50091", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d41", + address: "localhost:50092", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d42", + address: "localhost:50093", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d43", + address: "localhost:50094", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d44", + address: "localhost:50095", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d45", + address: "localhost:50096", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d46", + address: "localhost:50097", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d47", + address: "localhost:50098", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d48", + address: "localhost:50099", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d49", + address: "localhost:50100", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d50", + address: "localhost:50101", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d51", + address: "localhost:50102", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d52", + address: "localhost:50103", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d53", + address: "localhost:50104", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d54", + address: "localhost:50105", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d55", + address: "localhost:50106", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d56", + address: "localhost:50107", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d57", + address: "localhost:50108", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d58", + address: "localhost:50109", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d59", + address: "localhost:50110", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d60", + address: "localhost:50111", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d61", + address: "localhost:50112", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d62", + address: "localhost:50113", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d63", + address: "localhost:50114", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d64", + address: "localhost:50115", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d65", + address: "localhost:50116", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d66", + address: "localhost:50117", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d67", + address: "localhost:50118", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d68", + address: "localhost:50119", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d69", + address: "localhost:50120", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d70", + address: "localhost:50121", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d71", + address: "localhost:50122", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d72", + address: "localhost:50123", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d73", + address: "localhost:50124", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d74", + address: "localhost:50125", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d75", + address: "localhost:50126", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d76", + address: "localhost:50127", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d77", + address: "localhost:50128", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d78", + address: "localhost:50129", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d79", + address: "localhost:50130", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d80", + address: "localhost:50131", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d81", + address: "localhost:50132", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d82", + address: "localhost:50133", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d83", + address: "localhost:50134", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d84", + address: "localhost:50135", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d85", + address: "localhost:50136", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d86", + address: "localhost:50137", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d87", + address: "localhost:50138", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d88", + address: "localhost:50139", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d89", + address: "localhost:50140", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d90", + address: "localhost:50141", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d91", + address: "localhost:50142", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d92", + address: "localhost:50143", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d93", + address: "localhost:50144", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d94", + address: "localhost:50145", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d95", + address: "localhost:50146", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d96", + address: "localhost:50147", + battery_capacity: 1000000, + torch_device: cuda:0 + }, + { + device_id: "d97", + address: "localhost:50148", + battery_capacity: 1000000, + torch_device: cuda:1 + }, + { + device_id: "d98", + address: "localhost:50149", + battery_capacity: 1000000, + torch_device: cuda:2 + }, + { + device_id: "d99", + address: "localhost:50150", + battery_capacity: 1000000, + torch_device: cuda:0 + }, +] diff --git a/edml/core/client.py b/edml/core/client.py index bcedd28435f91f237ff31cf3fc68b1aa1ddb2c07..3526e03c5e2c3c9852d556d385a8a303514082ea 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -159,8 +159,6 @@ class DeviceClient: # Safety check to ensure that we train same-sized batches only. batch_data, batch_labels = next(self._batchable_data_loader) - if len(batch_data) != self._cfg.experiment.batch_size: - return # This means that the last batch with size < batch_size gets discarded. # Updates the battery capacity by simulating the required energy consumption for conducting the training step. self.node_device.battery.update_flops(self._model_flops * len(batch_data)) @@ -256,8 +254,6 @@ class DeviceClient: self._model.train() diagnostic_metric_container = DiagnosticMetricResultContainer() for idx, (batch_data, batch_labels) in enumerate(self._train_data): - if len(batch_data) != self._cfg.experiment.batch_size: - break self.node_device.battery.update_flops(self._model_flops * len(batch_data)) with LatencySimulator(latency_factor=self.latency_factor): diff --git a/edml/models/partials/resnet.py b/edml/models/partials/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..af9c87cc07e8e8b892684322de3639836ee94d8f --- /dev/null +++ b/edml/models/partials/resnet.py @@ -0,0 +1,37 @@ +from torch import nn + + +class Decoder(nn.Module): + """ + decoder model + """ + + def __init__(self): + super(Decoder, self).__init__() + self.t_convx = nn.ConvTranspose2d(4, 8, 1, stride=1) + self.t_conva = nn.ConvTranspose2d(8, 16, 1, stride=1) + self.t_convb = nn.ConvTranspose2d(16, 16, 1, stride=1) + + def forward(self, x): + x = self.t_convx(x) + x = self.t_conva(x) + x = self.t_convb(x) + return x + + +class Encoder(nn.Module): + """ + encoder model + """ + + def __init__(self): + super(Encoder, self).__init__() + self.conva = nn.Conv2d(16, 16, 3, padding=1) + self.convb = nn.Conv2d(16, 8, 3, padding=1) + self.convc = nn.Conv2d(8, 4, 3, padding=1) + + def forward(self, x): + x = self.conva(x) + x = self.convb(x) + x = self.convc(x) + return x