1
+
from functools import wraps
2
+
import threading
3
+
1
4
import numpy as np
2
5
3
6
from pandas import (
30
33
from pandas._libs import algos
31
34
except ImportError:
32
35
from pandas import algos
33
-
try:
34
-
from pandas._testing import test_parallel # noqa: PDF014
35
36
36
-
have_real_test_parallel = True
37
-
except ImportError:
38
-
have_real_test_parallel = False
39
37
40
-
def test_parallel(num_threads=1):
41
-
def wrapper(fname):
42
-
return fname
38
+
from .pandas_vb_common import BaseIO # isort:skip
43
39
44
-
return wrapper
45
40
41
+
def test_parallel(num_threads=2, kwargs_list=None):
    """
    Decorator to run the same function multiple times in parallel.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    kwargs_list : list of dicts, optional
        The list of kwargs to update original
        function kwargs on different threads. When given, it must contain
        exactly ``num_threads`` entries; entry ``i`` overrides the call's
        keyword arguments on thread ``i``.

    Returns
    -------
    callable
        A decorator. The decorated function starts ``num_threads`` threads
        that each call the original function, waits for all of them to
        finish, and returns ``None``.

    Raises
    ------
    AssertionError
        If ``num_threads`` is not positive, or if ``kwargs_list`` is given
        and its length differs from ``num_threads``.

    Notes
    -----
    This decorator does not pass the return value of the decorated function.

    Exceptions raised by the decorated function inside a worker thread are
    not re-raised in the calling thread; they are handled by
    ``threading.excepthook``.

    Original from scikit-image:

    https://github.com/scikit-image/scikit-image/pull/1519

    """
    # Raise explicitly instead of using ``assert`` statements so that the
    # validation is still performed under ``python -O`` while keeping the
    # exception type that existing callers may rely on.
    if not num_threads > 0:
        raise AssertionError("num_threads must be a positive integer")
    has_kwargs_list = kwargs_list is not None
    if has_kwargs_list and len(kwargs_list) != num_threads:
        raise AssertionError("kwargs_list must contain num_threads entries")

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if has_kwargs_list:
                # Each thread gets its own dict: the call's kwargs overlaid
                # with that thread's overrides.
                per_thread_kwargs = [
                    dict(kwargs, **kwargs_list[i]) for i in range(num_threads)
                ]
            else:
                # Every thread receives the call's kwargs unchanged.
                per_thread_kwargs = [kwargs] * num_threads

            # Build every thread before starting any of them so that the
            # start times are as close together as possible.
            threads = [
                threading.Thread(target=func, args=args, kwargs=thread_kwargs)
                for thread_kwargs in per_thread_kwargs
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        return inner

    return wrapper


# The ``test_`` prefix is historical; this is a helper, not a test case, so
# keep pytest from collecting it when it is imported into a test module.
test_parallel.__test__ = False
48
87
49
88
50
89
class ParallelGroupbyMethods:
@@ -53,8 +92,7 @@ class ParallelGroupbyMethods:
53
92
param_names = ["threads", "method"]
54
93
55
94
def setup(self, threads, method):
56
-
if not have_real_test_parallel:
57
-
raise NotImplementedError
95
+
58
96
N = 10**6
59
97
ngroups = 10**3
60
98
df = DataFrame(
@@ -86,8 +124,7 @@ class ParallelGroups:
86
124
param_names = ["threads"]
87
125
88
126
def setup(self, threads):
89
-
if not have_real_test_parallel:
90
-
raise NotImplementedError
127
+
91
128
size = 2**22
92
129
ngroups = 10**3
93
130
data = Series(np.random.randint(0, ngroups, size=size))
@@ -108,8 +145,7 @@ class ParallelTake1D:
108
145
param_names = ["dtype"]
109
146
110
147
def setup(self, dtype):
111
-
if not have_real_test_parallel:
112
-
raise NotImplementedError
148
+
113
149
N = 10**6
114
150
df = DataFrame({"col": np.arange(N, dtype=dtype)})
115
151
indexer = np.arange(100, len(df) - 100)
@@ -131,8 +167,7 @@ class ParallelKth:
131
167
repeat = 5
132
168
133
169
def setup(self):
134
-
if not have_real_test_parallel:
135
-
raise NotImplementedError
170
+
136
171
N = 10**7
137
172
k = 5 * 10**5
138
173
kwargs_list = [{"arr": np.random.randn(N)}, {"arr": np.random.randn(N)}]
@@ -149,8 +184,7 @@ def time_kth_smallest(self):
149
184
150
185
class ParallelDatetimeFields:
151
186
def setup(self):
152
-
if not have_real_test_parallel:
153
-
raise NotImplementedError
187
+
154
188
N = 10**6
155
189
self.dti = date_range("1900-01-01", periods=N, freq="T")
156
190
self.period = self.dti.to_period("D")
@@ -204,8 +238,7 @@ class ParallelRolling:
204
238
param_names = ["method"]
205
239
206
240
def setup(self, method):
207
-
if not have_real_test_parallel:
208
-
raise NotImplementedError
241
+
209
242
win = 100
210
243
arr = np.random.rand(100000)
211
244
if hasattr(DataFrame, "rolling"):
@@ -248,8 +281,7 @@ class ParallelReadCSV(BaseIO):
248
281
param_names = ["dtype"]
249
282
250
283
def setup(self, dtype):
251
-
if not have_real_test_parallel:
252
-
raise NotImplementedError
284
+
253
285
rows = 10000
254
286
cols = 50
255
287
data = {
@@ -284,8 +316,6 @@ class ParallelFactorize:
284
316
param_names = ["threads"]
285
317
286
318
def setup(self, threads):
287
-
if not have_real_test_parallel:
288
-
raise NotImplementedError
289
319
290
320
strings = tm.makeStringIndex(100000)
291
321
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4