aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sec_certs/utils/parallel_processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/sec_certs/utils/parallel_processing.py')
-rw-r--r--src/sec_certs/utils/parallel_processing.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/src/sec_certs/utils/parallel_processing.py b/src/sec_certs/utils/parallel_processing.py
new file mode 100644
index 00000000..50806a67
--- /dev/null
+++ b/src/sec_certs/utils/parallel_processing.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+import time
+from multiprocessing.pool import ThreadPool
+from typing import Any, Callable, Iterable
+
+from billiard.pool import Pool
+
+from sec_certs.utils.tqdm import tqdm
+
+
+def process_parallel(
+ func: Callable,
+ items: Iterable,
+ max_workers: int,
+ callback: Callable | None = None,
+ use_threading: bool = True,
+ progress_bar: bool = True,
+ unpack: bool = False,
+ progress_bar_desc: str | None = None,
+) -> list[Any]:
+
+ pool: Pool | ThreadPool = ThreadPool(max_workers) if use_threading else Pool(max_workers)
+ results = (
+ [pool.apply_async(func, (*i,), callback=callback) for i in items]
+ if unpack
+ else [pool.apply_async(func, (i,), callback=callback) for i in items]
+ )
+
+ if progress_bar is True and items:
+ bar = tqdm(total=len(results), desc=progress_bar_desc)
+ while not all(all_done := [x.ready() for x in results]):
+ done_count = len(list(filter(lambda x: x, all_done)))
+ bar.update(done_count - bar.n)
+ time.sleep(1)
+ bar.update(len(results) - bar.n)
+ bar.close()
+
+ pool.close()
+ pool.join()
+ pool.terminate()
+
+ return [r.get() for r in results]