75
75
import pandas.core.indexes.base as ibase
76
76
from pandas.core.internals import BlockManager, make_block
77
77
from pandas.core.series import Series
78
+
from pandas.core.util.numba_ import (
79
+
check_kwargs_and_nopython,
80
+
get_jit_arguments,
81
+
jit_user_function,
82
+
split_for_numba,
83
+
validate_udf,
84
+
)
78
85
79
86
from pandas.plotting import boxplot_frame_groupby
80
87
@@ -154,6 +161,8 @@ def pinner(cls):
154
161
class SeriesGroupBy(GroupBy[Series]):
155
162
_apply_whitelist = base.series_apply_whitelist
156
163
164
+
_numba_func_cache: Dict[Callable, Callable] = {}
165
+
157
166
def _iterate_slices(self) -> Iterable[Series]:
158
167
yield self._selected_obj
159
168
@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):
463
472
464
473
@Substitution(klass="Series", selected="A.")
465
474
@Appender(_transform_template)
466
-
def transform(self, func, *args, **kwargs):
475
+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
467
476
func = self._get_cython_func(func) or func
468
477
469
478
if not isinstance(func, str):
470
-
return self._transform_general(func, *args, **kwargs)
479
+
return self._transform_general(
480
+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
481
+
)
471
482
472
483
elif func not in base.transform_kernel_whitelist:
473
484
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
482
493
result = getattr(self, func)(*args, **kwargs)
483
494
return self._transform_fast(result, func)
484
495
485
-
def _transform_general(self, func, *args, **kwargs):
496
+
def _transform_general(
497
+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
498
+
):
486
499
"""
487
500
Transform with a non-str `func`.
488
501
"""
502
+
503
+
if engine == "numba":
504
+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
505
+
check_kwargs_and_nopython(kwargs, nopython)
506
+
validate_udf(func)
507
+
numba_func = self._numba_func_cache.get(
508
+
func, jit_user_function(func, nopython, nogil, parallel)
509
+
)
510
+
489
511
klass = type(self._selected_obj)
490
512
491
513
results = []
492
514
for name, group in self:
493
515
object.__setattr__(group, "name", name)
494
-
res = func(group, *args, **kwargs)
516
+
if engine == "numba":
517
+
values, index = split_for_numba(group)
518
+
res = numba_func(values, index, *args)
519
+
if func not in self._numba_func_cache:
520
+
self._numba_func_cache[func] = numba_func
521
+
else:
522
+
res = func(group, *args, **kwargs)
495
523
496
524
if isinstance(res, (ABCDataFrame, ABCSeries)):
497
525
res = res._values
@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
819
847
820
848
_apply_whitelist = base.dataframe_apply_whitelist
821
849
850
+
_numba_func_cache: Dict[Callable, Callable] = {}
851
+
822
852
_agg_see_also_doc = dedent(
823
853
"""
824
854
See Also
@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
1355
1385
# Handle cases like BinGrouper
1356
1386
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
1357
1387
1358
-
def _transform_general(self, func, *args, **kwargs):
1388
+
def _transform_general(
1389
+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
1390
+
):
1359
1391
from pandas.core.reshape.concat import concat
1360
1392
1361
1393
applied = []
1362
1394
obj = self._obj_with_exclusions
1363
1395
gen = self.grouper.get_iterator(obj, axis=self.axis)
1364
-
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
1396
+
if engine == "numba":
1397
+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
1398
+
check_kwargs_and_nopython(kwargs, nopython)
1399
+
validate_udf(func)
1400
+
numba_func = self._numba_func_cache.get(
1401
+
func, jit_user_function(func, nopython, nogil, parallel)
1402
+
)
1403
+
else:
1404
+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
1365
1405
1366
-
path = None
1367
1406
for name, group in gen:
1368
1407
object.__setattr__(group, "name", name)
1369
1408
1370
-
if path is None:
1409
+
if engine == "numba":
1410
+
values, index = split_for_numba(group)
1411
+
res = numba_func(values, index, *args)
1412
+
if func not in self._numba_func_cache:
1413
+
self._numba_func_cache[func] = numba_func
1414
+
# Return the result as a DataFrame for concatenation later
1415
+
res = DataFrame(res, index=group.index, columns=group.columns)
1416
+
else:
1371
1417
# Try slow path and fast path.
1372
1418
try:
1373
1419
path, res = self._choose_path(fast_path, slow_path, group)
@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):
1376
1422
except ValueError as err:
1377
1423
msg = "transform must return a scalar value for each group"
1378
1424
raise ValueError(msg) from err
1379
-
else:
1380
-
res = path(group)
1381
1425
1382
1426
if isinstance(res, Series):
1383
1427
@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):
1411
1455
1412
1456
@Substitution(klass="DataFrame", selected="")
1413
1457
@Appender(_transform_template)
1414
-
def transform(self, func, *args, **kwargs):
1458
+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
1415
1459
1416
1460
# optimized transforms
1417
1461
func = self._get_cython_func(func) or func
1418
1462
1419
1463
if not isinstance(func, str):
1420
-
return self._transform_general(func, *args, **kwargs)
1464
+
return self._transform_general(
1465
+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1466
+
)
1421
1467
1422
1468
elif func not in base.transform_kernel_whitelist:
1423
1469
msg = f"'{func}' is not a valid function name for transform(name)"
@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):
1439
1485
):
1440
1486
return self._transform_fast(result, func)
1441
1487
1442
-
return self._transform_general(func, *args, **kwargs)
1488
+
return self._transform_general(
1489
+
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
1490
+
)
1443
1491
1444
1492
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
1445
1493
"""
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4