Skip to content
Snippets Groups Projects
Verified Commit 77405c72 authored by Tobias Hangleiter
Browse files

Make jobs dict thread-safe

parent 1180e8c4
No related branches found
No related tags found
1 merge request: !177 "IO: async datasaver"
......@@ -14,7 +14,7 @@ import textwrap
import warnings
from concurrent.futures import Future
from contextlib import contextmanager
from threading import Thread
from threading import Thread, Lock
from types import ModuleType
from typing import Union, Literal
from unittest import mock
......@@ -424,6 +424,7 @@ class AsyncDatasaver:
self._savefn = np.savez_compressed if compress else np.savez
self.jobs: dict[Future, PathType] = {}
self.loop = asyncio.new_event_loop()
self.lock = Lock()
self.thread = Thread(target=self._start_loop, daemon=True)
self.thread.start()
# Make sure all io tasks are done when shutting down the interpreter
......@@ -443,12 +444,13 @@ class AsyncDatasaver:
self.loop.run_forever()
def _mark_done(self, future: Future):
try:
future.result()
except Exception as e:
print(f"Error in async task for file {self.jobs.get(future)}: {e}")
finally:
self.jobs.pop(future, None)
with self.lock:
try:
future.result()
except Exception as e:
print(f"Error in async task for file {self.jobs.get(future)}: {e}")
finally:
self.jobs.pop(future, None)
async def _save_data(self, file: PathType, allow_pickle: bool = True, **data):
"""Coroutine to asynchronously save data in a thread.
......@@ -477,18 +479,21 @@ class AsyncDatasaver:
See :func:`numpy:numpy.savez` for more information.
"""
self.jobs[
# parens necessary for py39..
(future := asyncio.run_coroutine_threadsafe(
self._save_data(file, allow_pickle=allow_pickle, **data),
self.loop
))
] = file
with self.lock:
self.jobs[
# parens necessary for py39..
(future := asyncio.run_coroutine_threadsafe(
self._save_data(file, allow_pickle=allow_pickle, **data),
self.loop
))
] = file
future.add_done_callback(self._mark_done)
def shutdown(self, timeout: float | None = None):
"""Shut down the object, waiting for all I/O tasks to finish."""
for future, file in self.jobs.items():
with self.lock:
jobs_snapshot = list(self.jobs.items())
for future, file in jobs_snapshot:
try:
future.result(timeout)
except Exception as e:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment