Skip to content
Snippets Groups Projects
Commit 3ecbf5eb authored by Simon Sebastian Humpohl's avatar Simon Sebastian Humpohl
Browse files

Rework batch call signature and document intention.

parent 0ba9a3c2
No related branches found
No related tags found
1 merge request!114Add a lazy batched "concurrent" executor implementation
......@@ -11,12 +11,13 @@ or certain simulations.
import concurrent.futures
import typing
from typing import Protocol, Sequence, Tuple, Any, Mapping
from typing import Protocol, Sequence, Tuple, Any, Mapping, Callable, Iterable, Optional, MutableSequence
class BatchFuture(concurrent.futures.Future):
"""This Future subclass is used by the :class:`BatchExecutor`. It triggers the execution of all submitted futures
with the same callable."""
def __init__(self, batch_executor, fn, args, kwargs):
super().__init__()
self.batch_executor = batch_executor
......@@ -33,15 +34,17 @@ class BatchFuture(concurrent.futures.Future):
class BatchCaller(Protocol):
"""The implementation of this protocol defines how a batch of arguments is called for a certain callable."""
def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
raise NotImplementedError
class SimpleBatchCaller:
"""This implementation calls the callable in the order it was created"""
def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
results = []
for args, kwargs in zip(args_batch, kwargs_batch):
for args, kwargs in batch:
try:
result = fn(*args, **kwargs)
results.append((result, True))
......@@ -52,13 +55,20 @@ class SimpleBatchCaller:
class SortedArgsBatchCaller:
"""This implementation of the :class:`BatchCaller` protocol sorts the calls by their arguments."""
def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
batch_size = len(args_batch)
sort_idx = sorted(range(batch_size), key=args_batch.__getitem__)
results = [None] * len(args_batch)
for idx in sort_idx:
args = args_batch[idx]
kwargs = kwargs_batch[idx]
@staticmethod
def default_sorter(batch: Sequence[Tuple[tuple, dict]]) -> Iterable[int]:
"""Returns indices that sort the batch arguments in non-descending order."""
return sorted(range(len(batch)), key=batch.__getitem__)
def __init__(self, sorter: Callable[[Sequence[Tuple[tuple, dict]]], Iterable[int]] = None):
self.sorter = sorter or self.default_sorter
def call_batched(self, fn: Callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
index_order = self.sorter(batch)
results: MutableSequence[Optional[tuple]] = [None] * len(batch)
for idx in index_order:
args, kwargs = batch[idx]
try:
result = fn(*args, **kwargs)
success = True
......@@ -66,7 +76,7 @@ class SortedArgsBatchCaller:
result = err
success = False
results[idx] = (result, success)
return typing.cast(Tuple[Any, bool], results)
return results
class BatchExecutor(concurrent.futures.Executor):
......@@ -74,6 +84,12 @@ class BatchExecutor(concurrent.futures.Executor):
evaluated can be defined via the ``batch_callers`` argument/attribute. The default caller can be changed via
``BatchExecutor.DEFAULT_BATCH_CALLER``. It is :class:`SimpleBatchCaller` by default.
You are encouraged to implement your own batch caller. An example of how this can be done is given in
:py:class:`~SortedArgsBatchCaller`.
Examples
--------
>>> executor = BatchExecutor()
>>> f1 = executor.submit(print, "evaluated first")
>>> f2 = executor.submit(print, "evaluated second")
......@@ -101,9 +117,17 @@ class BatchExecutor(concurrent.futures.Executor):
self.batch_callers = {} if batch_callers is None else batch_callers
def trigger_execution(self, fn: callable) -> int:
"""Trigger the execution of all futures with the given callable (that are not cancelled)."""
batch_args = []
batch_kwargs = []
"""Trigger the execution of all futures with the given callable (that are not cancelled).
Parameters
----------
fn : Submitted futures with this callable are executed
Returns
-------
The number of executed futures.
"""
batch_arguments = []
batch_futures = []
caller = self.batch_callers.get(fn, self.DEFAULT_BATCH_CALLER)
......@@ -115,12 +139,12 @@ class BatchExecutor(concurrent.futures.Executor):
for future in futures:
if not future.set_running_or_notify_cancel():
# ignore futures that have been cancelled
continue
batch_futures.append(future)
batch_args.append(future.args)
batch_kwargs.append(future.kwargs)
batch_arguments.append((future.args, future.kwargs))
results = caller.call_batched(fn, batch_args, batch_kwargs)
results = caller.call_batched(fn, batch_arguments)
for future, (result, success) in zip(batch_futures, results):
if success:
future.set_result(result)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment