diff --git a/qutil/io.py b/qutil/io.py
index c81b964c328f5582b887617c0fa2925c0eede57d..ade2a7fa94f4995be116baba5459974741d7e6a2 100644
--- a/qutil/io.py
+++ b/qutil/io.py
@@ -1,4 +1,8 @@
+from __future__ import annotations
+
+import asyncio
+import atexit
 import csv
 import os.path
 import pathlib
 import pickle
@@ -7,8 +11,25 @@ import string
 import sys
 import textwrap
 import warnings
+from concurrent.futures import Future
 from contextlib import contextmanager
-from typing import Union
+from threading import Thread
+from types import ModuleType
+from typing import Union, Literal
+from unittest import mock
+
+import numpy as np
+
+from .functools import partial
+
+try:
+    import dill
+except ImportError:
+    dill = mock.Mock()
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
 
 if os.name == 'nt':
     _POSIX = False
@@ -21,6 +42,8 @@ _VALID_YESNO_ANSWERS = {"yes": True, "y": True, "ye": True,
 _HOST_RE = re.compile(r"(janeway|ag-bluhm-[0-9]+)(?![.\d])", re.IGNORECASE)
 _HOST_SUFFIX = '.physik.rwth-aachen.de'
 
+PathType: TypeAlias = Union[str, os.PathLike]
+
 
 def _host_append_suffix(match: re.Match) -> str:
     return match.group() + _HOST_SUFFIX
@@ -319,3 +342,144 @@ class CsvLogger:
     def write(self, *args):
         with open(self.filename, 'a+') as file:
             csv.DictWriter(file, fieldnames=self.fieldnames, dialect=self.dialect).writerow(dict(zip(self.fieldnames, args)))
+
+
+class AsyncDatasaver:
+    """Class to handle asynchronous data saving using NumPy.
+
+    To save data, either call an instance or use :meth:`save_async`.
+    Data can also be saved synchronously using :meth:`save_sync`.
+
+    If ``pickle_lib='dill'``, object arrays are serialized with
+    :mod:`dill` instead of :mod:`pickle`.
+
+    Parameters
+    ----------
+    pickle_lib :
+        Which library to use for pickling. Defaults to :mod:`pickle`,
+        but may also be :mod:`dill`.
+    compress :
+        Compress binary data. Uses :func:`numpy:numpy.savez_compressed`
+        if true and :func:`numpy:numpy.savez` otherwise.
+
+    Examples
+    --------
+    Generate some data and test performance. Set up the datasaver.
+
+    >>> import pathlib, tempfile, time
+    >>> tmpdir = pathlib.Path(tempfile.mkdtemp())
+    >>> data = np.random.randn(2**20)
+    >>> datasaver = AsyncDatasaver()
+
+    Baseline of the synchronous version (which just wraps
+    :func:`numpy:numpy.savez`):
+
+    >>> tic = time.perf_counter()
+    >>> datasaver.save_sync(tmpdir / 'foo.npz', data=data)
+    >>> print(f'Writing took {time.perf_counter() - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Writing took ... seconds.
+
+    Now the asynchronous version:
+
+    >>> tic = time.perf_counter()
+    >>> datasaver.save_async(tmpdir / 'foo.npz', data=data)
+    >>> tac = time.perf_counter()
+    >>> while datasaver.jobs:
+    ...     pass
+    >>> toe = time.perf_counter()
+    >>> print(f'Blocked for {tac - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Blocked for ... seconds.
+    >>> print(f'Writing took {toe - tic:.2g} seconds.') # doctest: +ELLIPSIS
+    Writing took ... seconds.
+
+    Evidently, there is some tradeoff between total I/O time and
+    blocking I/O time.
+    """
+    def __init__(self, pickle_lib: Literal['pickle', 'dill'] = 'pickle', compress: bool = False):
+        if pickle_lib == 'pickle':
+            self._pickle_lib = pickle
+        elif pickle_lib == 'dill':
+            if isinstance(dill, ModuleType):
+                self._pickle_lib = dill
+            else:
+                raise ValueError("pickle_lib 'dill' requested but could not be imported.")
+        else:
+            raise ValueError("pickle_lib must be either 'pickle' or 'dill'.")
+
+        self._savefn = np.savez_compressed if compress else np.savez
+        self.jobs: dict[Future, PathType] = {}
+        self.loop = asyncio.new_event_loop()
+        self.thread = Thread(target=self._start_loop, daemon=True)
+        self.thread.start()
+        # Make sure all io tasks are done when shutting down the interpreter
+        atexit.register(self.shutdown, timeout=60)
+
+    def __call__(self, file: PathType, **data):
+        self.save_async(file, **data)
+
+    def _start_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    def _mark_done(self, future: Future):
+        try:
+            future.result()
+        except Exception as e:
+            print(f"Error in async task for file {self.jobs.get(future)}: {e}")
+        finally:
+            self.jobs.pop(future, None)
+
+    async def _save_data(self, file: PathType, allow_pickle: bool = True, **data):
+        """Coroutine to asynchronously save data in a thread.
+
+        Use :meth:`py:asyncio.loop.set_default_executor` to set the
+        executor used by this function.
+        """
+        # Use the default executor
+        await self.loop.run_in_executor(None, partial(self.save_sync, file,
+                                                      allow_pickle=allow_pickle, **data))
+
+    def save_sync(self, file: PathType, allow_pickle: bool = True, **data):
+        """Save arbitrary key-value pairs to *file* synchronously.
+
+        See :func:`numpy:numpy.savez` for more information.
+        """
+        with mock.patch.multiple(
+                'pickle',
+                Unpickler=self._pickle_lib.Unpickler,
+                Pickler=self._pickle_lib.Pickler
+        ):
+            self._savefn(str(file), allow_pickle=allow_pickle, **data)
+
+    def save_async(self, file: PathType, allow_pickle: bool = True, **data):
+        """Save arbitrary key-value pairs to *file* asynchronously.
+
+        See :func:`numpy:numpy.savez` for more information.
+        """
+        self.jobs[
+            # Parentheses around the walrus expression are necessary on py39.
+            (future := asyncio.run_coroutine_threadsafe(
+                self._save_data(file, allow_pickle=allow_pickle, **data),
+                self.loop
+            ))
+        ] = file
+        future.add_done_callback(self._mark_done)
+
+    def shutdown(self, timeout: float | None = None):
+        """Shut down the object, waiting for all I/O tasks to finish."""
+        # Iterate over a copy since done callbacks pop entries concurrently.
+        for future, file in list(self.jobs.items()):
+            try:
+                future.result(timeout)
+            except Exception as e:
+                print(f"Exception or timeout while waiting for {file} to be written: {e}")
+
+        self.jobs.clear()
+        try:
+            self.loop.call_soon_threadsafe(self.loop.stop)
+        except RuntimeError:
+            # loop already closed
+            return
+        else:
+            self.thread.join(timeout)
+            self.loop.close()