diff --git a/qutil/concurrent.py b/qutil/concurrent.py index e1e8e4cbc6ee4112084389a2dc3244f3ac1a9a52..d3ef8250cfd4c2fba187f0c3d4b2fafd702f5512 100644 --- a/qutil/concurrent.py +++ b/qutil/concurrent.py @@ -11,12 +11,13 @@ or certain simulations. import concurrent.futures import typing -from typing import Protocol, Sequence, Tuple, Any, Mapping +from typing import Protocol, Sequence, Tuple, Any, Mapping, Callable, Iterable, Optional, MutableSequence class BatchFuture(concurrent.futures.Future): """This Future subclass is used by the :class:`BatchExecutor`. It triggers the execution of all submitted futures with the same callable.""" + def __init__(self, batch_executor, fn, args, kwargs): super().__init__() self.batch_executor = batch_executor @@ -33,15 +34,17 @@ class BatchFuture(concurrent.futures.Future): class BatchCaller(Protocol): """The implementation of this protocol defines how a batch of arguments is called for a certain callable.""" - def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]: + + def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]: raise NotImplementedError class SimpleBatchCaller: """This implementation calls the callable in the order it was created""" - def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]: + + def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]: results = [] - for args, kwargs in zip(args_batch, kwargs_batch): + for args, kwargs in batch: try: result = fn(*args, **kwargs) results.append((result, True)) @@ -52,13 +55,20 @@ class SimpleBatchCaller: class SortedArgsBatchCaller: """This implementation of the :class:`BatchCaller` protocol sorts the calls by their arguments.""" - def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]: - batch_size = len(args_batch) - sort_idx = sorted(range(batch_size), key=args_batch.__getitem__) - results = [None] * len(args_batch) - for idx in sort_idx: - args = args_batch[idx] - kwargs = kwargs_batch[idx] + + @staticmethod + def default_sorter(batch: Sequence[Tuple[tuple, dict]]) -> Iterable[int]: + """Returns indices that sort the batch arguments in non-descending order.""" + return sorted(range(len(batch)), key=batch.__getitem__) + + def __init__(self, sorter: Callable[[Sequence[Tuple[tuple, dict]]], Iterable[int]] = None): + self.sorter = sorter or self.default_sorter + + def call_batched(self, fn: Callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]: + index_order = self.sorter(batch) + results: MutableSequence[Optional[tuple]] = [None] * len(batch) + for idx in index_order: + args, kwargs = batch[idx] try: result = fn(*args, **kwargs) success = True @@ -66,7 +76,7 @@ class SortedArgsBatchCaller: result = err success = False results[idx] = (result, success) - return typing.cast(Tuple[Any, bool], results) + return results class BatchExecutor(concurrent.futures.Executor): @@ -74,6 +84,12 @@ class BatchExecutor(concurrent.futures.Executor): evaluated can be defined via the ``batch_callers`` argument/attribute. The default caller can be changed via ``BatchExecutor.DEFAULT_BATCH_CALLER``. It is :class:`SimpleBatchCaller` by default. + You are encouraged to implement your own batch caller. An example of how this can be done is given in + :py:class:`~SortedArgsBatchCaller`. + + Examples + -------- + >>> executor = BatchExecutor() >>> f1 = executor.submit(print, "evaluated first") >>> f2 = executor.submit(print, "evaluated second") @@ -101,9 +117,17 @@ class BatchExecutor(concurrent.futures.Executor): self.batch_callers = {} if batch_callers is None else batch_callers def trigger_execution(self, fn: callable) -> int: - """Trigger the execution of all futures with the given callable (that are not cancelled).""" - batch_args = [] - batch_kwargs = [] + """Trigger the execution of all futures with the given callable (that are not cancelled). + + Parameters + ---------- + fn : Submitted futures with this callable are executed + + Returns + ------- + The number of executed futures. + """ + batch_arguments = [] batch_futures = [] caller = self.batch_callers.get(fn, self.DEFAULT_BATCH_CALLER) @@ -115,12 +139,12 @@ class BatchExecutor(concurrent.futures.Executor): for future in futures: if not future.set_running_or_notify_cancel(): + # ignore futures that have been cancelled continue batch_futures.append(future) - batch_args.append(future.args) - batch_kwargs.append(future.kwargs) + batch_arguments.append((future.args, future.kwargs)) - results = caller.call_batched(fn, batch_args, batch_kwargs) + results = caller.call_batched(fn, batch_arguments) for future, (result, success) in zip(batch_futures, results): if success: future.set_result(result)