diff --git a/edml/helpers/proto_helpers.py b/edml/helpers/proto_helpers.py index cc14aa50ca6926745d5cc00fd8eb801aee83a73f..a7a5ad26e8276bf3ddabede6128ad0373335a1a9 100644 --- a/edml/helpers/proto_helpers.py +++ b/edml/helpers/proto_helpers.py @@ -27,7 +27,8 @@ def _tensor_to_bytes(tensor: torch.Tensor) -> bytes: def _bytes_to_tensor(raw_bytes: bytes) -> torch.Tensor: - return pickle.loads(raw_bytes) + bs = io.BytesIO(raw_bytes) + return CpuUnpickler(bs).load() def _state_dict_to_bytes(state_dict: StateDict) -> bytes: @@ -46,7 +47,8 @@ def _metrics_to_bytes( def _bytes_to_metrics(raw_bytes: bytes): - return pickle.loads(raw_bytes) + bs = io.BytesIO(raw_bytes) + return CpuUnpickler(bs).load() def tensor_to_proto(tensor: torch.Tensor) -> datastructures_pb2.Tensor: