diff --git a/qutil/io.py b/qutil/io.py
index c81b964c328f5582b887617c0fa2925c0eede57d..210a893d55335aa17b8b00c9804b4a76447ea418 100644
--- a/qutil/io.py
+++ b/qutil/io.py
@@ -1,4 +1,9 @@
+from __future__ import annotations
+
+import asyncio
+import atexit
 import csv
+import importlib
 import os.path
 import pathlib
 import pickle
@@ -7,8 +12,24 @@ import string
 import sys
 import textwrap
 import warnings
+from concurrent.futures import Future
 from contextlib import contextmanager
+from threading import Thread, Lock
 from typing import Union
+from unittest import mock
+
+import numpy as np
+
+from .functools import partial
+
+try:
+    import dill
+except ImportError:
+    dill = mock.Mock()
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
 
 if os.name == 'nt':
     _POSIX = False
@@ -21,6 +42,8 @@ _VALID_YESNO_ANSWERS = {"yes": True, "y": True, "ye": True,
 _HOST_RE = re.compile(r"(janeway|ag-bluhm-[0-9]+)(?![.\d])", re.IGNORECASE)
 _HOST_SUFFIX = '.physik.rwth-aachen.de'
 
+PathType: TypeAlias = Union[str, os.PathLike]
+
 
 def _host_append_suffix(match: re.Match) -> str:
     return match.group() + _HOST_SUFFIX
@@ -319,3 +342,158 @@ class CsvLogger:
     def write(self, *args):
         with open(self.filename, 'a+') as file:
             csv.DictWriter(file, fieldnames=self.fieldnames, dialect=self.dialect).writerow(dict(zip(self.fieldnames, args)))
+
+
+class AsyncDatasaver:
+    """Class to handle asynchronous data saving using NumPy.
+
+    To save data, either call the instance or use :meth:`save_async`.
+    Data can also be saved synchronously using :meth:`save_sync`.
+
+    :mod:`dill` can be used instead of :mod:`pickle` via *pickle_lib*.
+
+    Parameters
+    ----------
+    pickle_lib :
+        Which library to use for pickling. Defaults to :mod:`pickle`,
+        but may also be :mod:`dill`, for instance.
+    compress :
+        Compress binary data. Uses :func:`numpy:numpy.savez_compressed`
+        if true and :func:`numpy:numpy.savez` otherwise.
+
+    Examples
+    --------
+    Generate some data and test performance. Set up the datasaver.
+
+    >>> import pathlib, tempfile, time
+    >>> tmpdir = pathlib.Path(tempfile.mkdtemp())
+    >>> data = np.random.randn(2**20)
+    >>> datasaver = AsyncDatasaver()
+
+    Baseline of the synchronous version (which just wraps
+    :func:`numpy:numpy.savez`):
+
+    >>> tic = time.perf_counter()
+    >>> datasaver.save_sync(tmpdir / 'foo.npz', data=data)
+    >>> print(f'Writing took {time.perf_counter() - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Writing took ... seconds.
+
+    Now the asynchronous version:
+
+    >>> tic = time.perf_counter()
+    >>> datasaver.save_async(tmpdir / 'foo.npz', data=data)
+    >>> tac = time.perf_counter()
+    >>> while datasaver.jobs:
+    ...     pass
+    >>> toe = time.perf_counter()
+    >>> print(f'Blocked for {tac - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Blocked for ... seconds.
+    >>> print(f'Writing took {toe - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Writing took ... seconds.
+
+    Evidently, there is some tradeoff between total I/O time and
+    blocking I/O time.
+
+    On exit, gracefully shut down the datasaver:
+
+    >>> datasaver.shutdown()
+
+    The datasaver can also be used as a context manager, in which case
+    cleanup is automatically taken care of. For syntactic sugar, calling
+    the datasaver has the same effect as
+    :meth:`AsyncDatasaver.save_async`.
+
+    >>> with AsyncDatasaver() as datasaver:
+    ...     datasaver(tmpdir / 'foo.npz', data=data/2)
+    ...     datasaver(tmpdir / 'bar.npz', data=2*data)
+
+    """
+    def __init__(self, pickle_lib: str = 'pickle', compress: bool = False):
+        self._pickle_lib = importlib.import_module(pickle_lib)
+        self._savefn = np.savez_compressed if compress else np.savez
+        self.jobs: dict[Future, PathType] = {}
+        self.loop = asyncio.new_event_loop()
+        self.lock = Lock()
+        self.thread = Thread(target=self._start_loop, daemon=True)
+        self.thread.start()
+        # Make sure all io tasks are done when shutting down the interpreter
+        atexit.register(self.shutdown, timeout=60)
+
+    def __call__(self, file: PathType, **data):
+        self.save_async(file, **data)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.shutdown()
+
+    def _start_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    def _mark_done(self, future: Future):
+        with self.lock:
+            try:
+                future.result()
+            except Exception as e:
+                print(f"Error in async task for file {self.jobs.get(future)}: {e}")
+            finally:
+                self.jobs.pop(future, None)
+
+    async def _save_data(self, file: PathType, allow_pickle: bool = True, **data):
+        """Coroutine to asynchronously save data in a thread.
+
+        Use :meth:`py:asyncio.loop.set_default_executor` to set the
+        executor used by this function.
+        """
+        # Use the default executor
+        await self.loop.run_in_executor(None, partial(self.save_sync, file,
+                                                      allow_pickle=allow_pickle, **data))
+
+    def save_sync(self, file: PathType, allow_pickle: bool = True, **data):
+        """Save arbitrary key-value pairs to *file* synchronously.
+
+        See :func:`numpy:numpy.savez` for more information.
+        """
+        with mock.patch.multiple(
+                'pickle',
+                Unpickler=self._pickle_lib.Unpickler,
+                Pickler=self._pickle_lib.Pickler
+        ):
+            self._savefn(str(file), allow_pickle=allow_pickle, **data)
+
+    def save_async(self, file: PathType, allow_pickle: bool = True, **data):
+        """Save arbitrary key-value pairs to *file* asynchronously.
+
+        See :func:`numpy:numpy.savez` for more information.
+        """
+        with self.lock:
+            self.jobs[
+                # parens necessary for py39
+                (future := asyncio.run_coroutine_threadsafe(
+                    self._save_data(file, allow_pickle=allow_pickle, **data),
+                    self.loop
+                ))
+            ] = file
+        future.add_done_callback(self._mark_done)
+
+    def shutdown(self, timeout: float | None = None):
+        """Shut down the object, waiting for all I/O tasks to finish."""
+        with self.lock:
+            jobs_snapshot = list(self.jobs.items())
+        for future, file in jobs_snapshot:
+            try:
+                future.result(timeout)
+            except Exception as e:
+                print(f"Exception or timeout while waiting for {file} to be written: {e}")
+
+        self.jobs.clear()
+        try:
+            self.loop.call_soon_threadsafe(self.loop.stop)
+        except RuntimeError:
+            # loop already closed
+            return
+        else:
+            self.thread.join(timeout)
+            self.loop.close()
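
Note on tuning the writer threads: ``_save_data`` dispatches the actual write
with ``run_in_executor(None, ...)``, i.e. to the loop's default executor, and
its docstring names :meth:`asyncio.loop.set_default_executor` as the
configuration hook. Below is a minimal sketch of bounding the number of
concurrent writes that way; the ``max_workers`` value and file names are
illustrative assumptions, not part of the change:

    import pathlib
    import tempfile
    from concurrent.futures import ThreadPoolExecutor

    import numpy as np

    from qutil.io import AsyncDatasaver

    tmpdir = pathlib.Path(tempfile.mkdtemp())

    with AsyncDatasaver() as datasaver:
        # Event loop methods are not thread-safe, so schedule the executor
        # swap on the loop's own thread.
        datasaver.loop.call_soon_threadsafe(
            datasaver.loop.set_default_executor,
            ThreadPoolExecutor(max_workers=2),
        )
        for i in range(4):
            # __call__ forwards to save_async; at most two writes overlap.
            datasaver(tmpdir / f'chunk_{i}.npz', data=np.random.randn(2**16))
    # Leaving the context manager calls shutdown(), which blocks until all
    # queued writes have finished.
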
diff --git a/tests/test_asyncdatasaver.py b/tests/test_asyncdatasaver.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b759f3a94fa24f62fc4283254c8bf9e684d2086
--- /dev/null
+++ b/tests/test_asyncdatasaver.py
@@ -0,0 +1,48 @@
+import pathlib
+import tempfile
+import time
+
+import numpy
+import pytest
+
+from qutil.io import AsyncDatasaver
+
+
+@pytest.fixture(name='data', scope='module')
+def fixture_data():
+    return numpy.random.randn(2**20)
+
+
+@pytest.fixture(name='filename', scope='module')
+def fixture_filename():
+    tmpdir = pathlib.Path(tempfile.mkdtemp())
+    return tmpdir / 'foo.npz'
+
+
+@pytest.fixture(name='datasaver', scope='module')
+def fixture_datasaver():
+    datasaver = AsyncDatasaver()
+    yield datasaver
+    datasaver.shutdown()
+
+
+def test_save_sync(filename, data, datasaver):
+    """Tests sync saving example from doctests"""
+    # tic = time.perf_counter()
+    datasaver.save_sync(filename, data=data)
+    # print(f'Writing took {time.perf_counter() - tic:.2g} seconds.')
+    dt = numpy.load(filename)
+    numpy.testing.assert_equal(dt['data'], data)
+
+
+def test_save_async(filename, data, datasaver):
+    """Tests async saving example from doctests"""
+    # tic = time.perf_counter()
+    datasaver.save_async(filename, data=data)
+    # tac = time.perf_counter()
+    while datasaver.jobs:
+        pass
+    # toe = time.perf_counter()
+    # print(f'Writing took {toe - tic:.2g} seconds after blocking for {tac - tic:.2g} seconds.')
+    dt = numpy.load(filename)
+    numpy.testing.assert_equal(dt['data'], data)
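
The tests, like the doctests, busy-wait on ``datasaver.jobs`` to detect
completion. Since ``jobs`` maps :class:`concurrent.futures.Future` objects to
their target paths, callers can also block on the futures directly. A short
sketch under that assumption, reusing the ``datasaver``, ``filename`` and
``data`` fixtures from above:

    from concurrent.futures import wait

    datasaver.save_async(filename, data=data)
    # Snapshot the outstanding futures and block until all of them are done,
    # instead of spinning on the jobs dict.
    wait(list(datasaver.jobs))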