A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://github.com/pandas-dev/pandas/commit/b8b6471f85054d5e6aa5ba77e8ff4fe615a6a73c below:

Add numba engine to groupby.transform (#32854) · pandas-dev/pandas@b8b6471 · GitHub

75 75

import pandas.core.indexes.base as ibase

76 76

from pandas.core.internals import BlockManager, make_block

77 77

from pandas.core.series import Series

78 +

from pandas.core.util.numba_ import (

79 +

check_kwargs_and_nopython,

80 +

get_jit_arguments,

81 +

jit_user_function,

82 +

split_for_numba,

83 +

validate_udf,

84 +

)

78 85 79 86

from pandas.plotting import boxplot_frame_groupby

80 87

@@ -154,6 +161,8 @@ def pinner(cls):

154 161

class SeriesGroupBy(GroupBy[Series]):

155 162

_apply_whitelist = base.series_apply_whitelist

156 163 164 +

_numba_func_cache: Dict[Callable, Callable] = {}

165 + 157 166

def _iterate_slices(self) -> Iterable[Series]:

158 167

yield self._selected_obj

159 168

@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):

463 472 464 473

@Substitution(klass="Series", selected="A.")

465 474

@Appender(_transform_template)

466 -

def transform(self, func, *args, **kwargs):

475 +

def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):

467 476

func = self._get_cython_func(func) or func

468 477 469 478

if not isinstance(func, str):

470 -

return self._transform_general(func, *args, **kwargs)

479 +

return self._transform_general(

480 +

func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs

481 +

)

471 482 472 483

elif func not in base.transform_kernel_whitelist:

473 484

msg = f"'{func}' is not a valid function name for transform(name)"

@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):

482 493

result = getattr(self, func)(*args, **kwargs)

483 494

return self._transform_fast(result, func)

484 495 485 -

def _transform_general(self, func, *args, **kwargs):

496 +

def _transform_general(

497 +

self, func, *args, engine="cython", engine_kwargs=None, **kwargs

498 +

):

486 499

"""

487 500

Transform with a non-str `func`.

488 501

"""

502 + 503 +

if engine == "numba":

504 +

nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

505 +

check_kwargs_and_nopython(kwargs, nopython)

506 +

validate_udf(func)

507 +

numba_func = self._numba_func_cache.get(

508 +

func, jit_user_function(func, nopython, nogil, parallel)

509 +

)

510 + 489 511

klass = type(self._selected_obj)

490 512 491 513

results = []

492 514

for name, group in self:

493 515

object.__setattr__(group, "name", name)

494 -

res = func(group, *args, **kwargs)

516 +

if engine == "numba":

517 +

values, index = split_for_numba(group)

518 +

res = numba_func(values, index, *args)

519 +

if func not in self._numba_func_cache:

520 +

self._numba_func_cache[func] = numba_func

521 +

else:

522 +

res = func(group, *args, **kwargs)

495 523 496 524

if isinstance(res, (ABCDataFrame, ABCSeries)):

497 525

res = res._values

@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):

819 847 820 848

_apply_whitelist = base.dataframe_apply_whitelist

821 849 850 +

_numba_func_cache: Dict[Callable, Callable] = {}

851 + 822 852

_agg_see_also_doc = dedent(

823 853

"""

824 854

See Also

@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

1355 1385

# Handle cases like BinGrouper

1356 1386

return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)

1357 1387 1358 -

def _transform_general(self, func, *args, **kwargs):

1388 +

def _transform_general(

1389 +

self, func, *args, engine="cython", engine_kwargs=None, **kwargs

1390 +

):

1359 1391

from pandas.core.reshape.concat import concat

1360 1392 1361 1393

applied = []

1362 1394

obj = self._obj_with_exclusions

1363 1395

gen = self.grouper.get_iterator(obj, axis=self.axis)

1364 -

fast_path, slow_path = self._define_paths(func, *args, **kwargs)

1396 +

if engine == "numba":

1397 +

nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

1398 +

check_kwargs_and_nopython(kwargs, nopython)

1399 +

validate_udf(func)

1400 +

numba_func = self._numba_func_cache.get(

1401 +

func, jit_user_function(func, nopython, nogil, parallel)

1402 +

)

1403 +

else:

1404 +

fast_path, slow_path = self._define_paths(func, *args, **kwargs)

1365 1405 1366 -

path = None

1367 1406

for name, group in gen:

1368 1407

object.__setattr__(group, "name", name)

1369 1408 1370 -

if path is None:

1409 +

if engine == "numba":

1410 +

values, index = split_for_numba(group)

1411 +

res = numba_func(values, index, *args)

1412 +

if func not in self._numba_func_cache:

1413 +

self._numba_func_cache[func] = numba_func

1414 +

# Return the result as a DataFrame for concatenation later

1415 +

res = DataFrame(res, index=group.index, columns=group.columns)

1416 +

else:

1371 1417

# Try slow path and fast path.

1372 1418

try:

1373 1419

path, res = self._choose_path(fast_path, slow_path, group)

@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):

1376 1422

except ValueError as err:

1377 1423

msg = "transform must return a scalar value for each group"

1378 1424

raise ValueError(msg) from err

1379 -

else:

1380 -

res = path(group)

1381 1425 1382 1426

if isinstance(res, Series):

1383 1427

@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):

1411 1455 1412 1456

@Substitution(klass="DataFrame", selected="")

1413 1457

@Appender(_transform_template)

1414 -

def transform(self, func, *args, **kwargs):

1458 +

def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):

1415 1459 1416 1460

# optimized transforms

1417 1461

func = self._get_cython_func(func) or func

1418 1462 1419 1463

if not isinstance(func, str):

1420 -

return self._transform_general(func, *args, **kwargs)

1464 +

return self._transform_general(

1465 +

func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs

1466 +

)

1421 1467 1422 1468

elif func not in base.transform_kernel_whitelist:

1423 1469

msg = f"'{func}' is not a valid function name for transform(name)"

@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):

1439 1485

):

1440 1486

return self._transform_fast(result, func)

1441 1487 1442 -

return self._transform_general(func, *args, **kwargs)

1488 +

return self._transform_general(

1489 +

func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs

1490 +

)

1443 1491 1444 1492

def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:

1445 1493

"""


RetroSearch is an open-source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4