1""" 

2Provide classes to perform the groupby aggregate operations. 

3 

4These are not exposed to the user and provide implementations of the grouping 

5operations, primarily in cython. These classes (BaseGrouper and BinGrouper) 

6are contained *in* the SeriesGroupBy and DataFrameGroupBy objects. 

7""" 

8 

9import collections 

10from typing import List, Optional, Sequence, Tuple, Type 

11 

12import numpy as np 

13 

14from pandas._libs import NaT, iNaT, lib 

15import pandas._libs.groupby as libgroupby 

16import pandas._libs.reduction as libreduction 

17from pandas._typing import FrameOrSeries 

18from pandas.errors import AbstractMethodError 

19from pandas.util._decorators import cache_readonly 

20 

21from pandas.core.dtypes.common import ( 

22 ensure_float64, 

23 ensure_int64, 

24 ensure_int_or_float, 

25 ensure_platform_int, 

26 is_bool_dtype, 

27 is_categorical_dtype, 

28 is_complex_dtype, 

29 is_datetime64_any_dtype, 

30 is_datetime64tz_dtype, 

31 is_extension_array_dtype, 

32 is_integer_dtype, 

33 is_numeric_dtype, 

34 is_period_dtype, 

35 is_sparse, 

36 is_timedelta64_dtype, 

37 needs_i8_conversion, 

38) 

39from pandas.core.dtypes.missing import _maybe_fill, isna 

40 

41import pandas.core.algorithms as algorithms 

42from pandas.core.base import SelectionMixin 

43import pandas.core.common as com 

44from pandas.core.frame import DataFrame 

45from pandas.core.generic import NDFrame 

46from pandas.core.groupby import base, grouper 

47from pandas.core.indexes.api import Index, MultiIndex, ensure_index 

48from pandas.core.series import Series 

49from pandas.core.sorting import ( 

50 compress_group_index, 

51 decons_obs_group_ids, 

52 get_flattened_iterator, 

53 get_group_index, 

54 get_group_index_sorter, 

55 get_indexer_dict, 

56) 

57 

58 


class BaseGrouper:
    """
    This is an internal Grouper class, which actually holds
    the generated groups.

    Parameters
    ----------
    axis : Index
    groupings : Sequence[Grouping]
        all the grouping instances to handle in this grouper;
        when grouping by a list of groupers, pass the full list
    sort : bool, default True
        whether this grouper will give a sorted result
    group_keys : bool, default True
    mutated : bool, default False
    indexer : intp array, optional
        the indexer created by Grouper; some groupers (e.g. TimeGrouper)
        sort their axis, so their group_info is also sorted, and this
        indexer is needed to restore the original order
    """

    def __init__(
        self,
        axis: Index,
        groupings: "Sequence[grouper.Grouping]",
        sort: bool = True,
        group_keys: bool = True,
        mutated: bool = False,
        indexer: Optional[np.ndarray] = None,
    ):
        assert isinstance(axis, Index), axis

        self._filter_empty_groups = self.compressed = len(groupings) != 1
        self.axis = axis
        self._groupings: List[grouper.Grouping] = list(groupings)
        self.sort = sort
        self.group_keys = group_keys
        self.mutated = mutated
        self.indexer = indexer

    @property
    def groupings(self) -> List["grouper.Grouping"]:
        return self._groupings

    @property
    def shape(self):
        return tuple(ping.ngroups for ping in self.groupings)

    def __iter__(self):
        return iter(self.indices)

    @property
    def nkeys(self) -> int:
        return len(self.groupings)

    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        splitter = self._get_splitter(data, axis=axis)
        keys = self._get_group_keys()
        for key, (i, group) in zip(keys, splitter):
            yield key, group
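
    # Usage sketch for get_iterator (toy frame; hedged, since this internal
    # API may change between versions):
    #
    #   >>> df = DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})
    #   >>> [name for name, grp in df.groupby("key").grouper.get_iterator(df)]
    #   ['a', 'b']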

    def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter":
        comp_ids, _, ngroups = self.group_info
        return get_splitter(data, comp_ids, ngroups, axis=axis)

    def _get_grouper(self):
        """
        We are a grouper as part of another's groupings.

        We have a specific method of grouping, so cannot
        convert to an Index for our grouper.
        """
        return self.groupings[0].grouper

    def _get_group_keys(self):
        if len(self.groupings) == 1:
            return self.levels[0]
        else:
            comp_ids, _, ngroups = self.group_info

            # provide "flattened" iterator for multi-group setting
            return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)

    def apply(self, f, data: FrameOrSeries, axis: int = 0):
        mutated = self.mutated
        splitter = self._get_splitter(data, axis=axis)
        group_keys = self._get_group_keys()
        result_values = None

        sdata: FrameOrSeries = splitter._get_sorted_data()
        if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
            # calling splitter.fast_apply will raise TypeError via apply_frame_axis0
            # if we pass EA instead of ndarray
            # TODO: can we have a workaround for EAs backed by ndarray?
            pass

        elif (
            com.get_callable_name(f) not in base.plotting_methods
            and isinstance(splitter, FrameSplitter)
            and axis == 0
            # fast_apply/libreduction doesn't allow non-numpy backed indexes
            and not sdata.index._has_complex_internals
        ):
            try:
                result_values, mutated = splitter.fast_apply(f, group_keys)

            except libreduction.InvalidApply as err:
                # This Exception is raised if `f` triggers an exception
                # but it is preferable to raise the exception in Python.
                if "Let this error raise above us" not in str(err):
                    # TODO: can we infer anything about whether this is
                    # worth-retrying in pure-python?
                    raise

            else:
                # If the fast apply path could be used we can return here.
                # Otherwise we need to fall back to the slow implementation.
                if len(result_values) == len(group_keys):
                    return group_keys, result_values, mutated

        for key, (i, group) in zip(group_keys, splitter):
            object.__setattr__(group, "name", key)

            # result_values is None if the fast apply path wasn't taken
            # or fast apply aborted with an unexpected exception.
            # In either case, initialize the result list and perform
            # the slow iteration.
            if result_values is None:
                result_values = []

            # If result_values is not None, we're in the case that the
            # fast apply loop was broken prematurely but we already have
            # the result for the first group, which we can reuse.
            elif i == 0:
                continue

            # group might be modified
            group_axes = group.axes
            res = f(group)
            if not _is_indexed_like(res, group_axes):
                mutated = True
            result_values.append(res)

        return group_keys, result_values, mutated

    @cache_readonly
    def indices(self):
        """ dict {group name -> group indices} """
        if len(self.groupings) == 1:
            return self.groupings[0].indices
        else:
            codes_list = [ping.codes for ping in self.groupings]
            keys = [com.values_from_object(ping.group_index) for ping in self.groupings]
            return get_indexer_dict(codes_list, keys)
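
    # For example (toy values, single grouping): grouping by keys
    # ["a", "a", "b"] gives indices of {"a": array([0, 1]), "b": array([2])}.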

    @property
    def codes(self) -> List[np.ndarray]:
        return [ping.codes for ping in self.groupings]

    @property
    def levels(self) -> List[Index]:
        return [ping.group_index for ping in self.groupings]

    @property
    def names(self):
        return [ping.name for ping in self.groupings]

    def size(self) -> Series:
        """
        Compute group sizes.
        """
        ids, _, ngroup = self.group_info
        ids = ensure_platform_int(ids)
        if ngroup:
            out = np.bincount(ids[ids != -1], minlength=ngroup)
        else:
            out = []
        return Series(out, index=self.result_index, dtype="int64")
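
    # Sketch of the bincount trick above (toy values, assumed): with
    # ids = array([0, 0, 1, -1]) and ngroup = 2, the ids != -1 mask drops
    # the NA group, and np.bincount(ids[ids != -1], minlength=2) returns
    # array([2, 1]): one count per group, aligned with result_index.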

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """
        if len(self.groupings) == 1:
            return self.groupings[0].groups
        else:
            to_groupby = zip(*(ping.grouper for ping in self.groupings))
            to_groupby = Index(to_groupby)
            return self.axis.groupby(to_groupby)

    @cache_readonly
    def is_monotonic(self) -> bool:
        # return whether my group orderings are monotonic
        return Index(self.group_info[0]).is_monotonic

    @cache_readonly
    def group_info(self):
        comp_ids, obs_group_ids = self._get_compressed_codes()

        ngroups = len(obs_group_ids)
        comp_ids = ensure_int64(comp_ids)
        return comp_ids, obs_group_ids, ngroups

    @cache_readonly
    def codes_info(self) -> np.ndarray:
        # return the codes of items in original grouped axis
        codes, _, _ = self.group_info
        if self.indexer is not None:
            sorter = np.lexsort((codes, self.indexer))
            codes = codes[sorter]
        return codes

    def _get_compressed_codes(self) -> Tuple[np.ndarray, np.ndarray]:
        all_codes = self.codes
        if len(all_codes) > 1:
            group_index = get_group_index(all_codes, self.shape, sort=True, xnull=True)
            return compress_group_index(group_index, sort=self.sort)

        ping = self.groupings[0]
        return ping.codes, np.arange(len(ping.group_index))
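
    # Rough illustration of the multi-key compression above (toy values):
    # two groupings with codes [0, 0, 1] and [1, 0, 1] (shape (2, 2)) are
    # fused by get_group_index into flat ids [1, 0, 3], which
    # compress_group_index factorizes into dense comp_ids plus the observed
    # flat ids: (array([1, 0, 2]), array([0, 1, 3])).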

    @cache_readonly
    def ngroups(self) -> int:
        return len(self.result_index)

    @property
    def reconstructed_codes(self) -> List[np.ndarray]:
        codes = self.codes
        comp_ids, obs_ids, _ = self.group_info
        return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)

    @cache_readonly
    def result_index(self) -> Index:
        if not self.compressed and len(self.groupings) == 1:
            return self.groupings[0].result_index.rename(self.names[0])

        codes = self.reconstructed_codes
        levels = [ping.result_index for ping in self.groupings]
        result = MultiIndex(
            levels=levels, codes=codes, verify_integrity=False, names=self.names
        )
        return result

    def get_group_levels(self):
        if not self.compressed and len(self.groupings) == 1:
            return [self.groupings[0].result_index]

        name_list = []
        for ping, codes in zip(self.groupings, self.reconstructed_codes):
            codes = ensure_platform_int(codes)
            levels = ping.result_index.take(codes)

            name_list.append(levels)

        return name_list

    # ------------------------------------------------------------
    # Aggregation functions

    _cython_functions = {
        "aggregate": {
            "add": "group_add",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }

    _cython_arity = {"ohlc": 4}  # OHLC

    _name_functions = {"ohlc": ["open", "high", "low", "close"]}
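
    # For example, "ohlc" has arity 4: one input column aggregates into four
    # output columns, named per _name_functions ("open", "high", "low",
    # "close"). All other kernels map one input column to one output column.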

    def _is_builtin_func(self, arg):
        """
        if we define a builtin function for this argument, return it,
        otherwise return the arg
        """
        return SelectionMixin._builtin_table.get(arg, arg)
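
    # For example, the Python builtin ``sum`` maps to ``np.sum`` via
    # SelectionMixin._builtin_table, letting downstream code take the
    # faster numpy/cython path instead of calling the builtin per group.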

    def _get_cython_function(self, kind: str, how: str, values, is_numeric: bool):

        dtype_str = values.dtype.name
        ftype = self._cython_functions[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype, None)
        if f is not None and is_numeric:
            return f

        # otherwise find dtype-specific version, falling back to object
        for dt in [dtype_str, "object"]:
            f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
            if f2 is not None:
                return f2

        if hasattr(f, "__signatures__"):
            # inspect what fused types are implemented
            if dtype_str == "object" and "object" not in f.__signatures__:
                # disallow this function so we get a NotImplementedError below
                # instead of a TypeError at runtime
                f = None

        func = f

        if func is None:
            raise NotImplementedError(
                f"function is not implemented for this dtype: "
                f"[how->{how},dtype->{dtype_str}]"
            )

        return func

    def _get_cython_func_and_vals(
        self, kind: str, how: str, values: np.ndarray, is_numeric: bool
    ):
        """
        Find the appropriate cython function, casting if necessary.

        Parameters
        ----------
        kind : str
        how : str
        values : np.ndarray
        is_numeric : bool

        Returns
        -------
        func : callable
        values : np.ndarray
        """
        try:
            func = self._get_cython_function(kind, how, values, is_numeric)
        except NotImplementedError:
            if is_numeric:
                try:
                    values = ensure_float64(values)
                except TypeError:
                    if lib.infer_dtype(values, skipna=False) == "complex":
                        values = values.astype(complex)
                    else:
                        raise
                func = self._get_cython_function(kind, how, values, is_numeric)
            else:
                raise
        return func, values

    def _cython_operation(
        self, kind: str, values, how: str, axis, min_count: int = -1, **kwargs
    ) -> Tuple[np.ndarray, Optional[List[str]]]:
        """
        Returns the values of a cython operation as a Tuple of [data, names].

        Names is only useful when dealing with 2D results, like ohlc
        (see self._name_functions).
        """

        assert kind in ["transform", "aggregate"]
        orig_values = values

        if values.ndim > 2:
            raise NotImplementedError("number of dimensions is currently limited to 2")
        elif values.ndim == 2:
            # Note: it is *not* the case that axis is always 0 for 1-dim values,
            # as we can have 1D ExtensionArrays that we need to treat as 2D
            assert axis == 1, axis

        # can we do this operation with our cython functions?
        # if not, raise NotImplementedError

        # we raise NotImplementedError if this is an invalid operation
        # entirely, e.g. adding datetimes

        # categoricals are only 1d, so we
        # are not set up for dim transforming
        if is_categorical_dtype(values) or is_sparse(values):
            raise NotImplementedError(f"{values.dtype} dtype not supported")
        elif is_datetime64_any_dtype(values):
            if how in ["add", "prod", "cumsum", "cumprod"]:
                raise NotImplementedError(
                    f"datetime64 type does not support {how} operations"
                )
        elif is_timedelta64_dtype(values):
            if how in ["prod", "cumprod"]:
                raise NotImplementedError(
                    f"timedelta64 type does not support {how} operations"
                )

        if is_datetime64tz_dtype(values.dtype):
            # Cast to naive; we'll cast back at the end of the function
            # TODO: possible need to reshape? kludge can be avoided when
            # 2D EA is allowed.
            values = values.view("M8[ns]")

        is_datetimelike = needs_i8_conversion(values.dtype)
        is_numeric = is_numeric_dtype(values.dtype)

        if is_datetimelike:
            values = values.view("int64")
            is_numeric = True
        elif is_bool_dtype(values.dtype):
            values = ensure_float64(values)
        elif is_integer_dtype(values):
            # we use iNaT for the missing value on ints
            # so pre-convert to guard this condition
            if (values == iNaT).any():
                values = ensure_float64(values)
            else:
                values = ensure_int_or_float(values)
        elif is_numeric and not is_complex_dtype(values):
            values = ensure_float64(values)
        else:
            values = values.astype(object)

        arity = self._cython_arity.get(how, 1)

        vdim = values.ndim
        swapped = False
        if vdim == 1:
            values = values[:, None]
            out_shape = (self.ngroups, arity)
        else:
            if axis > 0:
                swapped = True
                assert axis == 1, axis
                values = values.T
            if arity > 1:
                raise NotImplementedError(
                    "arity of more than 1 is not supported for the 'how' argument"
                )
            out_shape = (self.ngroups,) + values.shape[1:]

        func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)

        if how == "rank":
            out_dtype = "float"
        else:
            if is_numeric:
                out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}"
            else:
                out_dtype = "object"

        codes, _, _ = self.group_info

        if kind == "aggregate":
            result = _maybe_fill(
                np.empty(out_shape, dtype=out_dtype), fill_value=np.nan
            )
            counts = np.zeros(self.ngroups, dtype=np.int64)
            result = self._aggregate(
                result, counts, values, codes, func, is_datetimelike, min_count
            )
        elif kind == "transform":
            result = _maybe_fill(
                np.empty_like(values, dtype=out_dtype), fill_value=np.nan
            )

            # TODO: min_count
            result = self._transform(
                result, values, codes, func, is_datetimelike, **kwargs
            )

        if is_integer_dtype(result) and not is_datetimelike:
            mask = result == iNaT
            if mask.any():
                result = result.astype("float64")
                result[mask] = np.nan
        elif (
            how == "add"
            and is_integer_dtype(orig_values.dtype)
            and is_extension_array_dtype(orig_values.dtype)
        ):
            # We need this to ensure that Series[Int64Dtype].resample().sum()
            # remains int64 dtype.
            # Two options for avoiding this special case:
            # 1. mask-aware ops and avoid casting to float with NaN above
            # 2. specify the result dtype when calling this method
            result = result.astype("int64")

        if kind == "aggregate" and self._filter_empty_groups and not counts.all():
            assert result.ndim != 2
            result = result[counts > 0]

        if vdim == 1 and arity == 1:
            result = result[:, 0]

        names: Optional[List[str]] = self._name_functions.get(how, None)

        if swapped:
            result = result.swapaxes(0, axis)

        if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype(
            orig_values.dtype
        ):
            # We need to use the constructors directly for these dtypes
            # since numpy won't recognize them
            # https://github.com/pandas-dev/pandas/issues/31471
            result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
        elif is_datetimelike and kind == "aggregate":
            result = result.astype(orig_values.dtype)

        return result, names

    def aggregate(
        self, values, how: str, axis: int = 0, min_count: int = -1
    ) -> Tuple[np.ndarray, Optional[List[str]]]:
        return self._cython_operation(
            "aggregate", values, how, axis, min_count=min_count
        )

    def transform(self, values, how: str, axis: int = 0, **kwargs):
        return self._cython_operation("transform", values, how, axis, **kwargs)
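
    # Aggregation sketch (toy values, hedged; internal API): for a grouper
    # whose group_info ids are [0, 0, 1] over three rows,
    # grouper.aggregate(np.array([1.0, 2.0, 3.0]), "add") would return
    # (array([3.0, 3.0]), None): the per-group sums, and no column names
    # since "add" is not in _name_functions.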

    def _aggregate(
        self,
        result,
        counts,
        values,
        comp_ids,
        agg_func,
        is_datetimelike: bool,
        min_count: int = -1,
    ):
        if agg_func is libgroupby.group_nth:
            # different signature from the others
            # TODO: should we be using min_count instead of hard-coding it?
            agg_func(result, counts, values, comp_ids, rank=1, min_count=-1)
        else:
            agg_func(result, counts, values, comp_ids, min_count)

        return result

    def _transform(
        self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs
    ):

        comp_ids, _, ngroups = self.group_info
        transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

        return result

    def agg_series(self, obj: Series, func):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0

        if len(obj) == 0:
            # SeriesGrouper would raise if we were to call _aggregate_series_fast
            return self._aggregate_series_pure_python(obj, func)

        elif is_extension_array_dtype(obj.dtype):
            # _aggregate_series_fast would raise TypeError when
            # calling libreduction.Slider
            # In the datetime64tz case it would incorrectly cast to tz-naive
            # TODO: can we get a performant workaround for EAs backed by ndarray?
            return self._aggregate_series_pure_python(obj, func)

        elif obj.index._has_complex_internals:
            # Pre-empt TypeError in _aggregate_series_fast
            return self._aggregate_series_pure_python(obj, func)

        try:
            return self._aggregate_series_fast(obj, func)
        except ValueError as err:
            if "Function does not reduce" in str(err):
                # raised in libreduction
                pass
            else:
                raise
        return self._aggregate_series_pure_python(obj, func)

    def _aggregate_series_fast(self, obj: Series, func):
        # At this point we have already checked that
        #  - obj.index is not a MultiIndex
        #  - obj is backed by an ndarray, not ExtensionArray
        #  - len(obj) > 0
        #  - ngroups != 0
        func = self._is_builtin_func(func)

        group_index, _, ngroups = self.group_info

        # avoids object / Series creation overhead
        dummy = obj._get_values(slice(None, 0))
        indexer = get_group_index_sorter(group_index, ngroups)
        obj = obj.take(indexer)
        group_index = algorithms.take_nd(group_index, indexer, allow_fill=False)
        grouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups, dummy)
        result, counts = grouper.get_result()
        return result, counts

    def _aggregate_series_pure_python(self, obj: Series, func):

        group_index, _, ngroups = self.group_info

        counts = np.zeros(ngroups, dtype=int)
        result = None

        splitter = get_splitter(obj, group_index, ngroups, axis=0)

        for label, group in splitter:
            res = func(group)
            if result is None:
                if isinstance(res, (Series, Index, np.ndarray)):
                    if len(res) == 1:
                        # e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
                        # FIXME: are we potentially losing important res.index info?
                        res = res.item()
                    else:
                        raise ValueError("Function does not reduce")
                result = np.empty(ngroups, dtype="O")

            counts[label] = group.shape[0]
            result[label] = res

        assert result is not None
        result = lib.maybe_convert_objects(result, try_float=0)
        # TODO: try_cast back to EA?

        return result, counts


class BinGrouper(BaseGrouper):
    """
    This is an internal Grouper class.

    Parameters
    ----------
    bins : the split index of binlabels to group the items of the axis
    binlabels : the label list
    filter_empty : bool, default False
    mutated : bool, default False
    indexer : an intp array

    Examples
    --------
    bins: [2, 4, 6, 8, 10]
    binlabels: DatetimeIndex(['2005-01-01', '2005-01-03',
        '2005-01-05', '2005-01-07', '2005-01-09'],
        dtype='datetime64[ns]', freq='2D')

    the resulting group_info (the label of each item on the grouped axis,
    the index of each label in the label list, and the group count) is

    (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), array([0, 1, 2, 3, 4]), 5)

    meaning the grouped axis has 10 items which fall into 5 labels: the
    first and second items belong to the first label, the third and fourth
    items belong to the second label, and so on
    """

    def __init__(
        self,
        bins,
        binlabels,
        filter_empty: bool = False,
        mutated: bool = False,
        indexer=None,
    ):
        self.bins = ensure_int64(bins)
        self.binlabels = ensure_index(binlabels)
        self._filter_empty_groups = filter_empty
        self.mutated = mutated
        self.indexer = indexer

        # These lengths must match, otherwise we could call agg_series
        # with empty self.bins, which would raise in libreduction.
        assert len(self.binlabels) == len(self.bins)

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """

        # this is mainly for compat
        # GH 3881
        result = {
            key: value
            for key, value in zip(self.binlabels, self.bins)
            if key is not NaT
        }
        return result

    @property
    def nkeys(self) -> int:
        return 1

    def _get_grouper(self):
        """
        We are a grouper as part of another's groupings.

        We have a specific method of grouping, so cannot
        convert to an Index for our grouper.
        """
        return self

    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        slicer = lambda start, edge: data._slice(slice(start, edge), axis=axis)
        length = len(data.axes[axis])

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            if label is not NaT:
                yield label, slicer(start, edge)
            start = edge

        if start < length:
            yield self.binlabels[-1], slicer(start, None)

    @cache_readonly
    def indices(self):
        indices = collections.defaultdict(list)

        i = 0
        for label, bin in zip(self.binlabels, self.bins):
            if i < bin:
                if label is not NaT:
                    indices[label] = list(range(i, bin))
            i = bin
        return indices

    @cache_readonly
    def group_info(self):
        ngroups = self.ngroups
        obs_group_ids = np.arange(ngroups)
        rep = np.diff(np.r_[0, self.bins])

        rep = ensure_platform_int(rep)
        if ngroups == len(self.bins):
            comp_ids = np.repeat(np.arange(ngroups), rep)
        else:
            comp_ids = np.repeat(np.r_[-1, np.arange(ngroups)], rep)

        return (
            comp_ids.astype("int64", copy=False),
            obs_group_ids.astype("int64", copy=False),
            ngroups,
        )
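
    # Worked example of the repeat trick above (toy bins, assumed): with
    # bins = [2, 5] and both bin labels observed (ngroups == len(bins)),
    # rep = np.diff([0, 2, 5]) = [2, 3], so
    # comp_ids = np.repeat([0, 1], [2, 3]) = [0, 0, 1, 1, 1]: the first two
    # items fall in bin 0 and the next three in bin 1.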

    @cache_readonly
    def reconstructed_codes(self) -> List[np.ndarray]:
        # get unique result indices, and prepend 0 as groupby starts from the first
        return [np.r_[0, np.flatnonzero(self.bins[1:] != self.bins[:-1]) + 1]]

    @cache_readonly
    def result_index(self):
        if len(self.binlabels) != 0 and isna(self.binlabels[0]):
            return self.binlabels[1:]

        return self.binlabels

    @property
    def levels(self):
        return [self.binlabels]

    @property
    def names(self):
        return [self.binlabels.name]

    @property
    def groupings(self) -> "List[grouper.Grouping]":
        return [
            grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
            for lvl, name in zip(self.levels, self.names)
        ]

    def agg_series(self, obj: Series, func):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0
        assert len(self.bins) > 0  # otherwise we'd get IndexError in get_result

        if is_extension_array_dtype(obj.dtype):
            # pre-empt SeriesBinGrouper from raising TypeError
            return self._aggregate_series_pure_python(obj, func)

        dummy = obj[:0]
        grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy)
        return grouper.get_result()


def _is_indexed_like(obj, axes) -> bool:
    if isinstance(obj, Series):
        if len(axes) > 1:
            return False
        return obj.index.equals(axes[0])
    elif isinstance(obj, DataFrame):
        return obj.index.equals(axes[0])

    return False


# ----------------------------------------------------------------------
# Splitting / application


class DataSplitter:
    def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0):
        self.data = data
        self.labels = ensure_int64(labels)
        self.ngroups = ngroups

        self.axis = axis
        assert isinstance(axis, int), axis

    @cache_readonly
    def slabels(self):
        # Sorted labels
        return algorithms.take_nd(self.labels, self.sort_idx, allow_fill=False)

    @cache_readonly
    def sort_idx(self):
        # Counting sort indexer
        return get_group_index_sorter(self.labels, self.ngroups)

    def __iter__(self):
        sdata = self._get_sorted_data()

        if self.ngroups == 0:
            # we are inside a generator; rather than raise StopIteration
            # we merely return to signal the end
            return

        starts, ends = lib.generate_slices(self.slabels, self.ngroups)

        for i, (start, end) in enumerate(zip(starts, ends)):
            yield i, self._chop(sdata, slice(start, end))
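
    # Sketch of the slicing above (toy values; lib.generate_slices is
    # internal, so this is illustrative only): for sorted labels
    # slabels = [0, 0, 1, 1, 1] and ngroups = 2, starts = [0, 2] and
    # ends = [2, 5], so group 0 is rows 0:2 of the sorted data and
    # group 1 is rows 2:5.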

    def _get_sorted_data(self) -> FrameOrSeries:
        return self.data.take(self.sort_idx, axis=self.axis)

    def _chop(self, sdata, slice_obj: slice) -> NDFrame:
        raise AbstractMethodError(self)


class SeriesSplitter(DataSplitter):
    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
        return sdata._get_values(slice_obj)


class FrameSplitter(DataSplitter):
    def fast_apply(self, f, names):
        # must return keys::list, values::list, mutated::bool
        starts, ends = lib.generate_slices(self.slabels, self.ngroups)

        sdata = self._get_sorted_data()
        return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)

    def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
        if self.axis == 0:
            return sdata.iloc[slice_obj]
        else:
            return sdata._slice(slice_obj, axis=1)


def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
    if isinstance(data, Series):
        klass: Type[DataSplitter] = SeriesSplitter
    else:
        # i.e. DataFrame
        klass = FrameSplitter

    return klass(data, *args, **kwargs)
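

# Rough end-to-end sketch of the splitter machinery (toy data, hedged;
# these are internal helpers, not a stable API):
#
#   >>> s = Series([10, 20, 30], index=["x", "y", "z"])
#   >>> splitter = get_splitter(s, np.array([1, 0, 1]), 2)
#   >>> [(i, list(chunk)) for i, chunk in splitter]
#   [(0, [20]), (1, [10, 30])]
#
# Labels are sorted first (sort_idx), the data is reordered to match, and
# each group is then a contiguous slice of the sorted data.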