diff --git a/qutil/concurrent.py b/qutil/concurrent.py
index e1e8e4cbc6ee4112084389a2dc3244f3ac1a9a52..d3ef8250cfd4c2fba187f0c3d4b2fafd702f5512 100644
--- a/qutil/concurrent.py
+++ b/qutil/concurrent.py
@@ -11,12 +11,13 @@ or certain simulations.
 
 import concurrent.futures
 import typing
-from typing import Protocol, Sequence, Tuple, Any, Mapping
+from typing import Protocol, Sequence, Tuple, Any, Mapping, Callable, Iterable, Optional, MutableSequence
 
 
 class BatchFuture(concurrent.futures.Future):
     """This Future subclass is used by the :class:`BatchExecutor`. It triggers the execution of all submitted futures
     with the same callable."""
+
     def __init__(self, batch_executor, fn, args, kwargs):
         super().__init__()
         self.batch_executor = batch_executor
@@ -33,15 +34,17 @@ class BatchFuture(concurrent.futures.Future):
 
 class BatchCaller(Protocol):
     """The implementation of this protocol defines how a batch of arguments is called for a certain callable."""
-    def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
+
+    def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
         raise NotImplementedError
 
 
 class SimpleBatchCaller:
     """This implementation calls the callable in the order it was created"""
-    def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
+
+    def call_batched(self, fn: callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
         results = []
-        for args, kwargs in zip(args_batch, kwargs_batch):
+        for args, kwargs in batch:
             try:
                 result = fn(*args, **kwargs)
                 results.append((result, True))
@@ -52,13 +55,20 @@ class SimpleBatchCaller:
 
 class SortedArgsBatchCaller:
     """This implementation of the :class:`BatchCaller` protocol sorts the calls by their arguments."""
-    def call_batched(self, fn, args_batch, kwargs_batch) -> Sequence[Tuple[Any, bool]]:
-        batch_size = len(args_batch)
-        sort_idx = sorted(range(batch_size), key=args_batch.__getitem__)
-        results = [None] * len(args_batch)
-        for idx in sort_idx:
-            args = args_batch[idx]
-            kwargs = kwargs_batch[idx]
+
+    @staticmethod
+    def default_sorter(batch: Sequence[Tuple[tuple, dict]]) -> Iterable[int]:
+        """Returns indices that sort the batch arguments in non-descending order."""
+        return sorted(range(len(batch)), key=batch.__getitem__)
+
+    def __init__(self, sorter: Callable[[Sequence[Tuple[tuple, dict]]], Iterable[int]] = None):
+        self.sorter = sorter or self.default_sorter
+
+    def call_batched(self, fn: Callable, batch: Sequence[Tuple[tuple, dict]]) -> Sequence[Tuple[Any, bool]]:
+        index_order = self.sorter(batch)
+        results: MutableSequence[Optional[tuple]] = [None] * len(batch)
+        for idx in index_order:
+            args, kwargs = batch[idx]
             try:
                 result = fn(*args, **kwargs)
                 success = True
@@ -66,7 +76,7 @@ class SortedArgsBatchCaller:
                 result = err
                 success = False
             results[idx] = (result, success)
-        return typing.cast(Tuple[Any, bool], results)
+        return results
 
 
 class BatchExecutor(concurrent.futures.Executor):
@@ -74,6 +84,12 @@ class BatchExecutor(concurrent.futures.Executor):
     evaluated can be defined via the ``batch_callers`` argument/attribute. The default caller can be changed via
     ``BatchExecutor.DEFAULT_BATCH_CALLER``. It is :class:`SimpleBatchCaller` by default.
 
+    You are encouraged to implement your own batch caller. An example of how this can be done is given in
+    :py:class:`~SortedArgsBatchCaller`.
+
+    Examples
+    --------
+
     >>> executor = BatchExecutor()
     >>> f1 = executor.submit(print, "evaluated first")
     >>> f2 = executor.submit(print, "evaluated second")
@@ -101,9 +117,17 @@ class BatchExecutor(concurrent.futures.Executor):
         self.batch_callers = {} if batch_callers is None else batch_callers
 
     def trigger_execution(self, fn: callable) -> int:
-        """Trigger the execution of all futures with the given callable (that are not cancelled)."""
-        batch_args = []
-        batch_kwargs = []
+        """Trigger the execution of all futures with the given callable (that are not cancelled).
+
+        Parameters
+        ----------
+        fn : Submitted futures with this callable are executed
+
+        Returns
+        -------
+        The number of executed futures.
+        """
+        batch_arguments = []
         batch_futures = []
 
         caller = self.batch_callers.get(fn, self.DEFAULT_BATCH_CALLER)
@@ -115,12 +139,12 @@ class BatchExecutor(concurrent.futures.Executor):
 
         for future in futures:
             if not future.set_running_or_notify_cancel():
+                # ignore futures that have been cancelled
                 continue
             batch_futures.append(future)
-            batch_args.append(future.args)
-            batch_kwargs.append(future.kwargs)
+            batch_arguments.append((future.args, future.kwargs))
 
-        results = caller.call_batched(fn, batch_args, batch_kwargs)
+        results = caller.call_batched(fn, batch_arguments)
         for future, (result, success) in zip(batch_futures, results):
             if success:
                 future.set_result(result)