diff --git a/edml/core/client.py b/edml/core/client.py index 7dc1617a1b0d283211b9be9900214831ab971e08..1a81dfb0d0c661da362e5b6c14eff28c21c935f8 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -229,7 +229,7 @@ class DeviceClient: metrics_container = DiagnosticMetricResultContainer([metric]) gradients = [] - for param in self._model.parameters(): + for param in self._model.get_optimizer_params(): if param.grad is not None: gradients.append(param.grad) else: @@ -389,7 +389,7 @@ class DeviceClient: return diagnostic_metric_results def set_gradient_and_finalize_training(self, gradients: Any): - for param, grad in zip(self._model.parameters(), gradients): + for param, grad in zip(self._model.get_optimizer_params(), gradients): param.grad = grad.to(self._device) self._optimizer.step() diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py index 2b28c7cf27953017e140f976b4b9d76f5f2bed00..9fa81abd5a7c58e880f163d2d88eb070c7271e47 100644 --- a/edml/models/provider/base.py +++ b/edml/models/provider/base.py @@ -1,4 +1,3 @@ -import torch from torch import nn @@ -8,16 +7,6 @@ def has_autoencoder(model: nn.Module): return False -def get_grads(model: nn.Module): - gradients = [] - for param in model.parameters(): - if param.grad is not None: - gradients.append(param.grad) - else: - gradients.append(torch.zeros_like(param)) - return gradients - - def add_optimizer_params_function(model: nn.Module): # exclude AE params if has_autoencoder(model): diff --git a/results/result_generation.py b/results/result_generation.py index caa42642976785d82f0313733b324c9cd63b1166..f7db5c83721b4f51fb6cd08b7f14e6a277a4387b 100644 --- a/results/result_generation.py +++ b/results/result_generation.py @@ -39,6 +39,8 @@ LABEL_MAPPING = { "device": "Device", } +DPI = 300 + def scale_parallel_time(run_df, scale_factor=1.0): """ @@ -401,7 +403,7 @@ def plot_remaining_devices(devices_per_round, save_path=None): save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) num_rounds = [0] max_devices = [] for (strategy, job), (rounds, num_devices) in devices_per_round.items(): @@ -492,7 +494,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): accuracies_per_round: (dict) the accuracy (list(float)) per round (list(int)) for each group save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) num_rounds = [0] for (strategy, job), (rounds, accs) in accuracies_per_round.items(): plt.plot(rounds, accs, label=f"{STRATEGY_MAPPING[strategy]}") @@ -517,7 +519,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train" accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) for (strategy, job), (time, accs) in accuracies_per_time.items(): plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}") plt.xlabel(LABEL_MAPPING["runtime"]) @@ -688,7 +690,7 @@ def plot_batteries_over_time( aggregated: (bool) whether the battery is aggregated or not """ if aggregated: - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) for (strategy, job), series in batteries_over_time.items(): runtime = max_runtimes[(strategy, job)] @@ -709,7 +711,7 @@ def plot_batteries_over_time( plt.close() else: for (strategy, job), series_dict in batteries_over_time.items(): - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) for device_id, series in series_dict.items(): runtime = max_runtimes[(strategy, job)] @@ -739,7 +741,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T aggregated: (bool) whether the battery is aggregated or not """ if aggregated: - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) num_rounds = [0] for (strategy, job), series in batteries_over_epoch.items(): @@ -760,7 +762,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T plt.close() else: for (strategy, job), series_dict in batteries_over_epoch.items(): - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) num_rounds = [0] for device_id, series in series_dict.items(): @@ -899,7 +901,7 @@ def plot_batteries_over_time_with_activity( 0, train_time_end(server_train_times, client_train_times) * 1.05, ) # set end 5% after last activity timestamp - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) # battery_plot plt.subplot(2, 1, 1) @@ -976,7 +978,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale( ], [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))], ) - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) # battery_plot plt.subplot(2, 1, 1) @@ -1048,7 +1050,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale( ).sort_values("start") # battery plot - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) plt.subplot(2, 1, 1) num_rounds = []