aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyecsca/misc/utils.py24
-rw-r--r--test/misc/test_utils.py20
2 files changed, 34 insertions, 10 deletions
diff --git a/pyecsca/misc/utils.py b/pyecsca/misc/utils.py
index fe45784..40849b3 100644
--- a/pyecsca/misc/utils.py
+++ b/pyecsca/misc/utils.py
@@ -79,12 +79,18 @@ class TaskExecutor(ProcessPoolExecutor):
"""A list of tasks that were submitted to this executor."""
return list(zip(self.keys, self.futures))
- def as_completed(self) -> Generator[tuple[Any, Future], Any, None]:
- """Like `concurrent.futures.as_completed`, but yields a pair of key and future."""
- for future in as_completed(self.futures):
- i = self.futures.index(future)
- yield self.keys[i], future
- del self.keys[i]
- del self.futures[i]
- self.futures = []
- self.keys = []
+ def as_completed(self, wait: bool = True) -> Generator[tuple[Any, Future], Any, None]:
+ """
+ Like `concurrent.futures.as_completed`, but yields a pair of key and future.
+
+ If `wait` is True, it will block until all futures are done.
+ If `wait` is False, it will return immediately with futures that are already done.
+ """
+ try:
+ for future in as_completed(self.futures, timeout=None if wait else 0):
+ i = self.futures.index(future)
+ yield self.keys[i], future
+ del self.keys[i]
+ del self.futures[i]
+ except TimeoutError:
+ pass
diff --git a/test/misc/test_utils.py b/test/misc/test_utils.py
index 72ce711..d4409ed 100644
--- a/test/misc/test_utils.py
+++ b/test/misc/test_utils.py
@@ -1,4 +1,4 @@
-
+import time
from pyecsca.misc.utils import TaskExecutor
@@ -6,6 +6,11 @@ def run(a, b):
return a + b
+def wait(a, b):
+ time.sleep(1)
+ return a + b
+
+
def test_executor():
with TaskExecutor(max_workers=2) as pool:
for i in range(10):
@@ -15,3 +20,16 @@ def test_executor():
for i, future in pool.as_completed():
res = future.result()
assert res == i + 5
+
+
+def test_executor_no_wait():
+ with TaskExecutor(max_workers=2) as pool:
+ for i in range(2):
+ pool.submit_task(i,
+ wait,
+ i, 5)
+ futures = list(pool.as_completed(wait=False))
+ assert len(futures) == 0
+ time.sleep(2.5)
+ futures = list(pool.as_completed(wait=False))
+ assert len(futures) == 2