# Coverage for tests\unit\test_parallel.py: 96% (131 statements)
import pytest
import multiprocessing
import time
from typing import Any, List, Iterable

# Import the function to test
from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel

DATA: dict = dict(
    empty=[],
    single=[5],
    small=list(range(4)),
    medium=list(range(10)),
    large=list(range(50)),
)
SQUARE_RESULTS: dict = {k: [x**2 for x in v] for k, v in DATA.items()}
ADD_ONE_RESULTS: dict = {k: [x + 1 for x in v] for k, v in DATA.items()}


# Basic test functions
def square(x: int) -> int:
    return x**2


def add_one(x: int) -> int:
    return x + 1


def raise_value_error(x: int) -> int:
    if x == 5:
        raise ValueError("Test error")
    return x**2


def slow_square(x: int) -> int:
    time.sleep(0.0001)
    return x**2


def raise_on_negative(x: int) -> int:
    if x < 0:
        raise ValueError("Negative number")
    return x


def stateful_fn(x: list) -> list:
    x.append(1)
    return x


class ComplexObject:
    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, ComplexObject) and self.value == other.value


def dataset_decorator(keys: List[str]):
    def wrapper(test_func):
        return pytest.mark.parametrize(
            "input_range, expected",
            [(DATA[k], SQUARE_RESULTS[k]) for k in keys],
            ids=keys,
        )(test_func)

    return wrapper
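
# applying `@dataset_decorator(["small"])` to a test is equivalent to
# decorating it with:
#   pytest.mark.parametrize(
#       "input_range, expected",
#       [(DATA["small"], SQUARE_RESULTS["small"])],
#       ids=["small"],
#   )
# i.e. each listed key yields one (input, expected-squares) case below.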


@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    # skip combinations where multiprocess cannot be used:
    # serial execution, a single worker, or a single-element input
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        return

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # check the results
    assert set(results) == set(expected)
    if keep_ordered:
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    pbar_kwargs = {"disable": disable_flag}
    if pbar_type == "invalid" and not disable_flag:
        with pytest.raises(ValueError):
            run_maybe_parallel(square, input_range, False, pbar_kwargs, pbar=pbar_type)
    else:
        results = run_maybe_parallel(
            square, input_range, False, pbar_kwargs, pbar=pbar_type
        )
        assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    results = run_maybe_parallel(square, input_range, parallel, {}, chunksize=chunksize)
    assert results == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    with pytest.raises(ValueError):
        run_maybe_parallel(square, input_range, invalid_parallel)


def test_exception_in_func():
    # raise_value_error only raises for the input 5, so pass exactly that
    # value to confirm the exception propagates out of the parallel run
    error_input = [5]  # Will raise ValueError
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, error_input, True, {})


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    test_input = iterable_factory(input_range)
    result = run_maybe_parallel(square, test_input, False)
    if isinstance(test_input, set):
        assert set(result) == set(expected)
    else:
        assert result == expected


@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    # include negative values so raise_on_negative triggers the error
    input_data = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, input_data, parallel)


def _process_complex(obj):
    return ComplexObject(obj.value * 2)


COMPLEX_DATA: List[ComplexObject] = [ComplexObject(i) for i in range(5)]
EXPECTED_COMPLEX = [ComplexObject(i * 2) for i in range(5)]


@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    # use non-primitive objects as inputs for this test
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    assert all(a == b for a, b in zip(result, EXPECTED_COMPLEX))


@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    initial_processes = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    time.sleep(0.05)
    final_processes = len(multiprocessing.active_children())
    assert abs(final_processes - initial_processes) <= 2


@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        return iterable

    result = run_maybe_parallel(square, input_range, False, pbar=custom_progress_bar_fn)
    assert result == expected


@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        None,
        dict(),
        dict(desc="Processing"),
        dict(disable=True),
        dict(ascii=True),
        dict(config="default"),
        dict(config="bar"),
        dict(ascii=True, config="bar"),
        dict(message="Processing"),
        dict(message="Processing", desc="Processing"),
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    result = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert result == expected


@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    assert serial_result == parallel_result


@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    with pytest.raises(TypeError):
        run_maybe_parallel(square, input_range, False, pbar=12345)


def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    return list(iterable)


@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    results = run_maybe_parallel(square, input_range, False, pbar=custom_pbar)
    assert results == expected, "Manual callable pbar test failed."


@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + 1), True),
        (range(multiprocessing.cpu_count() - 1), True),
    ],
)
def test_edge_cases(input_data, parallel):
    result = run_maybe_parallel(square, input_data, parallel)
    assert result == [square(x) for x in input_data]
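

# Illustrative additions (sketches, not part of the original suite): the
# `add_one`, ADD_ONE_RESULTS, and `stateful_fn` helpers above are otherwise
# unused; these hypothetical tests exercise them with the same
# run_maybe_parallel(func, iterable, parallel) call pattern used elsewhere.
@pytest.mark.parametrize("parallel", [False, True])
def test_add_one_basic(parallel):
    results = run_maybe_parallel(add_one, DATA["small"], parallel)
    assert results == ADD_ONE_RESULTS["small"]


def test_stateful_fn_returns_appended_lists():
    # serial run: each input list comes back with 1 appended
    data = [[0], [1, 2], []]
    results = run_maybe_parallel(stateful_fn, data, False)
    assert results == [[0, 1], [1, 2, 1], [1]]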