Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Operator classes for eval. 

2""" 

3 

4from datetime import datetime 

5from distutils.version import LooseVersion 

6from functools import partial 

7import operator 

8 

9import numpy as np 

10 

11from pandas._libs.tslibs import Timestamp 

12 

13from pandas.core.dtypes.common import is_list_like, is_scalar 

14 

15import pandas.core.common as com 

16from pandas.core.computation.common import _ensure_decoded, result_type_many 

17from pandas.core.computation.scope import _DEFAULT_GLOBALS 

18 

19from pandas.io.formats.printing import pprint_thing, pprint_thing_encoded 

20 

# Reduction names (not referenced in this part of the file).
_reductions = ("sum", "prod")

# Single-argument math function names; FuncNode resolves each with
# ``getattr(np, name)``.
_unary_math_ops = (
    "sin", "cos", "exp", "log", "expm1", "log1p", "sqrt",
    "sinh", "cosh", "tanh",
    "arcsin", "arccos", "arctan",
    "arccosh", "arcsinh", "arctanh",
    "abs", "log10", "floor", "ceil",
)

# Two-argument math function names.
_binary_math_ops = ("arctan2",)

# Every function name FuncNode accepts.
_mathops = _unary_math_ops + _binary_math_ops

# Prefix that marks a name as a local variable; Term.local_name strips it.
_LOCAL_TAG = "__pd_eval_local_"

51 

52 

class UndefinedVariableError(NameError):
    """
    Raised when a name used in an eval expression cannot be resolved.

    Parameters
    ----------
    name : object
        The unresolved name; it appears in the message via ``repr``.
    is_local : bool
        Whether the name was a local variable; this selects the message prefix.
    """

    def __init__(self, name, is_local: bool):
        prefix = "local variable" if is_local else "name"
        super().__init__(f"{prefix} {name!r} is not defined")

65 

66 

class Term:
    """
    An operand of an eval expression that is looked up by name in ``env``.

    Constructing a ``Term`` with a non-``str`` name yields a ``Constant``
    instead (see ``__new__``).

    Parameters
    ----------
    name : str (any value for subclasses)
        Variable name, possibly carrying the local-variable tag.
    env : scope-like object
        Must provide ``resolve`` and ``swapkey``, both of which are called here.
    side : optional
        Stored as ``self.side``; not used within this class.
    encoding : optional
        Stored as ``self.encoding``; not used within this class.

    Raises
    ------
    NotImplementedError
        If the resolved value has ``ndim > 2``.
    """

    def __new__(cls, name, env, side=None, encoding=None):
        # Anything that is not a string is a literal, so build a Constant.
        target = cls if isinstance(name, str) else Constant
        return super(Term, target).__new__(target)

    is_local: bool

    def __init__(self, name, env, side=None, encoding=None):
        # name is a str for Term, but may be something else for subclasses
        self._name = name
        self.env = env
        self.side = side
        as_text = str(name)
        # Tagged names and default globals are flagged local for env.resolve.
        self.is_local = as_text.startswith(_LOCAL_TAG) or as_text in _DEFAULT_GLOBALS
        self._value = self._resolve_name()
        self.encoding = encoding

    @property
    def local_name(self) -> str:
        # The name with every occurrence of the local tag removed.
        return self.name.replace(_LOCAL_TAG, "")

    def __repr__(self) -> str:
        return pprint_thing(self.name)

    def __call__(self, *args, **kwargs):
        # Evaluating a bare term yields its resolved value.
        return self.value

    def evaluate(self, *args, **kwargs):
        # A term is already fully evaluated.
        return self

    def _resolve_name(self):
        found = self.env.resolve(self.local_name, is_local=self.is_local)
        self.update(found)

        if getattr(found, "ndim", 0) > 2:
            raise NotImplementedError(
                "N-dimensional objects, where N > 2, are not supported with eval"
            )
        return found

    def update(self, value):
        """
        Set this term's value and, for string-named terms, rebind it in ``env``.

        Documented search order for local (i.e., @variable) variables:

            scope, key_variable
            [('locals', 'local_name'),
             ('globals', 'local_name'),
             ('locals', 'key'),
             ('globals', 'key')]
        """
        name = self.name
        if isinstance(name, str):
            # Only named variables (not constants) are swapped in env.
            self.env.swapkey(self.local_name, name, new_value=value)
        self.value = value

    @property
    def is_scalar(self) -> bool:
        return is_scalar(self._value)

    @property
    def type(self):
        val = self._value
        try:
            # Objects exposing ``.values``; potentially very slow for large,
            # mixed-dtype frames.
            return val.values.dtype
        except AttributeError:
            pass
        try:
            # ndarray
            return val.dtype
        except AttributeError:
            # scalar: fall back to the Python type
            return type(val)

    return_type = type

    @property
    def raw(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, type={self.type})"

    @property
    def is_datetime(self) -> bool:
        kind = self.type
        # Unwrap a dtype to its scalar type; plain classes are used as-is.
        scalar_type = getattr(kind, "type", kind)
        return issubclass(scalar_type, (datetime, np.datetime64))

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value

    @property
    def name(self):
        return self._name

    @property
    def ndim(self) -> int:
        return self._value.ndim

173 

174 

class Constant(Term):
    """
    A literal value appearing in an eval expression.

    The value doubles as the name, and ``_resolve_name`` returns it directly,
    so nothing is looked up in or written to ``env``.
    """

    def __init__(self, value, env, side=None, encoding=None):
        super().__init__(value, env, side=side, encoding=encoding)

    def _resolve_name(self):
        # The "name" of a constant is the literal itself.
        return self._name

    @property
    def name(self):
        return self.value

    def __repr__(self) -> str:
        # Use repr() rather than str(): under Python 2, str() of a float
        # could be shorter (less precise) than repr().
        return repr(self.name)

190 

191 

# Python keyword boolean operators mapped to the symbols Op stores instead.
_bool_op_map = dict([("not", "~"), ("and", "&"), ("or", "|")])

193 

194 

class Op:
    """
    Hold an operator of arbitrary arity.

    Parameters
    ----------
    op : str
        Operator token. ``"not"``, ``"and"`` and ``"or"`` are stored as
        ``"~"``, ``"&"`` and ``"|"`` respectively.
    operands : iterable of Term or Op
        The operator's operands; iterating the Op iterates these.
    encoding : optional, keyword
        Stored as ``self.encoding`` (``None`` when not given).
    """

    op: str

    def __init__(self, op: str, operands, *args, **kwargs):
        self.op = _bool_op_map.get(op, op)
        self.operands = operands
        self.encoding = kwargs.get("encoding", None)

    def __iter__(self):
        return iter(self.operands)

    def __repr__(self) -> str:
        """
        Print a generic n-ary operator and its operands using infix notation.
        """
        # recurse over the operands
        parened = (f"({pprint_thing(opr)})" for opr in self.operands)
        return pprint_thing(f" {self.op} ".join(parened))

    @property
    def return_type(self):
        # clobber types to bool if the op is a boolean operator
        if self.op in (_cmp_ops_syms + _bool_ops_syms):
            return np.bool_
        # presumably com.flatten recurses through nested Ops (via __iter__)
        # down to the leaf terms
        return result_type_many(*(term.type for term in com.flatten(self)))

    @property
    def has_invalid_return_type(self) -> bool:
        """
        True when the result type is object but some operand type is not
        object dtype.
        """
        types = self.operand_types
        obj_dtype_set = frozenset([np.dtype("object")])
        # Coerce explicitly: the ``and`` expression would otherwise return
        # the set difference (a frozenset) instead of a bool.
        return bool(self.return_type == object and types - obj_dtype_set)

    @property
    def operand_types(self):
        return frozenset(term.type for term in com.flatten(self))

    @property
    def is_scalar(self) -> bool:
        return all(operand.is_scalar for operand in self.operands)

    @property
    def is_datetime(self) -> bool:
        try:
            t = self.return_type.type
        except AttributeError:
            t = self.return_type

        return issubclass(t, (datetime, np.datetime64))

247 

248 

def _in(x, y):
    """
    Evaluate ``x in y``, vectorized through ``isin`` where possible.

    ``x.isin(y)`` is tried first. If ``x`` has no ``isin`` but is list-like,
    ``y.isin(x)`` is tried. Otherwise plain Python membership is used.
    """
    try:
        return x.isin(y)
    except AttributeError:
        if not is_list_like(x):
            return x in y
        try:
            return y.isin(x)
        except AttributeError:
            return x in y

262 

263 

def _not_in(x, y):
    """
    Evaluate ``x not in y``, vectorized through ``~isin`` where possible.

    ``~x.isin(y)`` is tried first. If ``x`` has no ``isin`` but is list-like,
    ``~y.isin(x)`` is tried. Otherwise plain Python non-membership is used.
    """
    try:
        return ~x.isin(y)
    except AttributeError:
        if not is_list_like(x):
            return x not in y
        try:
            return ~y.isin(x)
        except AttributeError:
            return x not in y

277 

278 

# Comparison operators: expression symbol -> implementing callable.
# ``in`` / ``not in`` use the vectorized membership helpers _in / _not_in.
_cmp_ops_syms = (">", "<", ">=", "<=", "==", "!=", "in", "not in")
_cmp_ops_funcs = (
    operator.gt,
    operator.lt,
    operator.ge,
    operator.le,
    operator.eq,
    operator.ne,
    _in,
    _not_in,
)
_cmp_ops_dict = dict(zip(_cmp_ops_syms, _cmp_ops_funcs))

# Boolean operators; the keyword spellings share the bitwise implementations.
_bool_ops_syms = ("&", "|", "and", "or")
_bool_ops_funcs = (operator.and_, operator.or_, operator.and_, operator.or_)
_bool_ops_dict = dict(zip(_bool_ops_syms, _bool_ops_funcs))

# Arithmetic operators.
_arith_ops_syms = ("+", "-", "*", "/", "**", "//", "%")
_arith_ops_funcs = (
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.pow,
    operator.floordiv,
    operator.mod,
)
_arith_ops_dict = dict(zip(_arith_ops_syms, _arith_ops_funcs))

# Subset of the arithmetic operators kept in a separate table (presumably
# given special handling by callers outside this section).
_special_case_arith_ops_syms = ("**", "//", "%")
_special_case_arith_ops_funcs = (operator.pow, operator.floordiv, operator.mod)
_special_case_arith_ops_dict = dict(
    zip(_special_case_arith_ops_syms, _special_case_arith_ops_funcs)
)

# Every binary operator symbol -> callable: comparison, then boolean, then
# arithmetic entries (no symbol appears in more than one table). Merged with
# unpacking so no loop variable leaks into the module namespace.
_binary_ops_dict = {**_cmp_ops_dict, **_bool_ops_dict, **_arith_ops_dict}

318 

319 

def _cast_inplace(terms, acceptable_dtypes, dtype):
    """
    Cast the values of ``terms`` to ``dtype`` in place.

    Parameters
    ----------
    terms : iterable of Term
        Each term must expose ``type``, ``value`` and ``update``.
    acceptable_dtypes : list
        Terms whose ``type`` compares equal to an entry here are left alone.
    dtype : str or numpy.dtype
        Target dtype.
    """
    target = np.dtype(dtype)
    for term in terms:
        if term.type not in acceptable_dtypes:
            try:
                # array-likes support astype
                converted = term.value.astype(target)
            except AttributeError:
                # plain scalars: construct the target scalar type directly
                converted = target.type(term.value)
            term.update(converted)

343 

344 

def is_term(obj) -> bool:
    """Report whether ``obj`` is a Term instance (Constant included)."""
    result = isinstance(obj, Term)
    return result

347 

348 

class BinOp(Op):
    """
    Hold a binary operator and its operands.

    On construction this:
    - rejects boolean operators applied to scalar, non-boolean operands;
    - converts a scalar Term compared against a datetime Term to ``Timestamp``;
    - looks up the implementing function.

    Parameters
    ----------
    op : str
        Operator symbol; must be a key of ``_binary_ops_dict``.
    lhs : Term or Op
        Left-hand operand.
    rhs : Term or Op
        Right-hand operand.
    **kwargs
        Forwarded to ``Op`` (which reads only ``encoding``).

    Raises
    ------
    NotImplementedError
        If ``op`` is a boolean operator, at least one side is scalar, and the
        two sides are not both boolean.
    ValueError
        If ``op`` is not a known binary operator.
    """

    def __init__(self, op: str, lhs, rhs, **kwargs):
        # Forward kwargs so ``encoding`` reaches Op instead of being dropped.
        super().__init__(op, (lhs, rhs), **kwargs)
        self.lhs = lhs
        self.rhs = rhs

        self._disallow_scalar_only_bool_ops()

        self.convert_values()

        try:
            self.func = _binary_ops_dict[op]
        except KeyError as err:
            # has to be made a list for python3
            keys = list(_binary_ops_dict.keys())
            raise ValueError(
                f"Invalid binary operator {repr(op)}, valid operators are {keys}"
            ) from err

    def __call__(self, env):
        """
        Recursively evaluate an expression in Python space.

        Parameters
        ----------
        env : Scope

        Returns
        -------
        object
            The result of an evaluated expression.
        """

        # recurse over the left/right nodes
        left = self.lhs(env)
        right = self.rhs(env)

        return self.func(left, right)

    def evaluate(self, env, engine: str, parser, term_type, eval_in_python):
        """
        Evaluate a binary operation *before* being passed to the engine.

        Parameters
        ----------
        env : Scope
        engine : str
        parser : str
        term_type : type
        eval_in_python : list

        Returns
        -------
        term_type
            The "pre-evaluated" expression as an instance of ``term_type``
        """
        if engine == "python":
            res = self(env)
        else:
            # recurse over the left/right nodes
            left = self.lhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )
            right = self.rhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            # base cases
            if self.op in eval_in_python:
                res = self.func(left.value, right.value)
            else:
                from pandas.core.computation.eval import eval

                res = eval(self, local_dict=env, engine=engine, parser=parser)

        # store the result in env and refer to it through a new term
        name = env.add_tmp(res)
        return term_type(name, env=env)

    def convert_values(self):
        """
        Make scalars comparable with datetime operands.

        When one side is a datetime-typed Term and the other a scalar Term,
        the scalar is converted to a ``Timestamp`` (converted to UTC when it
        is tz-aware). The lhs-datetime case is handled before the
        rhs-datetime case.
        """

        def stringify(value):
            if self.encoding is not None:
                encoder = partial(pprint_thing_encoded, encoding=self.encoding)
            else:
                encoder = pprint_thing
            return encoder(value)

        def to_timestamp(term):
            # Replace ``term``'s scalar value with the equivalent Timestamp.
            v = term.value
            if isinstance(v, (int, float)):
                # presumably so numbers such as 20130101 are parsed as date
                # strings rather than as epoch offsets
                v = stringify(v)
            v = Timestamp(_ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            term.update(v)

        lhs, rhs = self.lhs, self.rhs

        if is_term(lhs) and lhs.is_datetime and is_term(rhs) and rhs.is_scalar:
            to_timestamp(rhs)

        if is_term(rhs) and rhs.is_datetime and is_term(lhs) and lhs.is_scalar:
            to_timestamp(lhs)

    def _disallow_scalar_only_bool_ops(self):
        """
        Raise if a boolean operator has a scalar side and either side is not
        boolean.
        """
        lhs, rhs = self.lhs, self.rhs
        if not (lhs.is_scalar or rhs.is_scalar) or self.op not in _bool_ops_dict:
            return

        def scalar_type(term):
            # return_type may be a dtype *instance* (e.g. np.dtype("bool")),
            # which issubclass rejects with TypeError; unwrap it to its
            # scalar type. Plain classes (bool, np.bool_) pass through.
            rt = term.return_type
            return getattr(rt, "type", rt)

        bool_types = (bool, np.bool_)
        if not (
            issubclass(scalar_type(rhs), bool_types)
            and issubclass(scalar_type(lhs), bool_types)
        ):
            raise NotImplementedError("cannot evaluate scalar only bool ops")

488 

489 

def isnumeric(dtype) -> bool:
    """Return whether ``dtype``'s scalar type is a NumPy number type."""
    scalar_type = np.dtype(dtype).type
    return issubclass(scalar_type, np.number)

492 

493 

class Div(BinOp):
    """
    Div operator to special case casting.

    Both operands must have numeric return types. Operand terms whose type
    is neither float32 nor float64 are cast to float64 in place.

    Parameters
    ----------
    lhs, rhs : Term or Op
        The Terms or Ops in the ``/`` expression.
    **kwargs
        Forwarded to ``BinOp``.

    Raises
    ------
    TypeError
        If either operand's return type is not numeric.
    """

    def __init__(self, lhs, rhs, **kwargs):
        super().__init__("/", lhs, rhs, **kwargs)

        if not isnumeric(lhs.return_type) or not isnumeric(rhs.return_type):
            raise TypeError(
                f"unsupported operand type(s) for {self.op}: "
                f"'{lhs.return_type}' and '{rhs.return_type}'"
            )

        # do not upcast float32s to float64 un-necessarily.
        # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed
        # (on NumPy 1.x they are the same type object).
        acceptable_dtypes = [np.float32, np.float64]
        _cast_inplace(com.flatten(self), acceptable_dtypes, np.float64)

516 

517 

# Unary operator symbols and their implementations; ``not`` is evaluated
# with the same bitwise inversion as ``~``.
_unary_ops_syms = ("+", "-", "~", "not")
_unary_ops_funcs = (operator.pos, operator.neg, operator.invert, operator.invert)
_unary_ops_dict = {sym: fn for sym, fn in zip(_unary_ops_syms, _unary_ops_funcs)}

521 

522 

class UnaryOp(Op):
    """
    Hold a unary operator and its single operand.

    Parameters
    ----------
    op : str
        The operator token; must be a key of ``_unary_ops_dict``.
    operand : Term or Op
        The operand the operator is applied to.

    Raises
    ------
    ValueError
        * If no function associated with the passed operator token is found.
    """

    def __init__(self, op: str, operand):
        super().__init__(op, (operand,))
        self.operand = operand

        try:
            self.func = _unary_ops_dict[op]
        except KeyError:
            raise ValueError(
                f"Invalid unary operator {repr(op)}, "
                f"valid operators are {_unary_ops_syms}"
            )

    def __call__(self, env):
        # Evaluate the operand first, then apply the operator.
        return self.func(self.operand(env))

    def __repr__(self) -> str:
        return pprint_thing(f"{self.op}({self.operand})")

    @property
    def return_type(self) -> np.dtype:
        # bool if the operand is bool-typed or is itself a comparison/boolean
        # Op; otherwise int.
        # NOTE(review): non-bool operands (e.g. float) also report int here.
        # Confirm that callers rely on this.
        bool_dtype = np.dtype("bool")
        operand = self.operand
        yields_bool = operand.return_type == bool_dtype or (
            isinstance(operand, Op)
            and (operand.op in _cmp_ops_dict or operand.op in _bool_ops_dict)
        )
        return bool_dtype if yields_bool else np.dtype("int")

569 

570 

class MathCall(Op):
    """
    Call of a wrapped math function on one or more operands.

    ``func`` must expose ``name`` (used as the operator token) and ``func``
    (the callable applied), as FuncNode does.
    """

    def __init__(self, func, args):
        super().__init__(func.name, args)
        self.func = func

    def __call__(self, env):
        values = [arg(env) for arg in self.operands]
        # NumPy floating-point errors/warnings are suppressed for the call.
        with np.errstate(all="ignore"):
            return self.func.func(*values)

    def __repr__(self) -> str:
        joined = ",".join(str(arg) for arg in self.operands)
        return pprint_thing(f"{self.op}({joined})")

584 

585 

class FuncNode:
    """
    A named NumPy math function usable in eval expressions.

    Calling the node does not evaluate anything; it builds a ``MathCall``.

    Raises
    ------
    ValueError
        If ``name`` is not in ``_mathops``, or if it is ``floor``/``ceil``
        while an installed numexpr is older than 2.6.9.
    """

    def __init__(self, name: str):
        from pandas.core.computation.check import _NUMEXPR_INSTALLED, _NUMEXPR_VERSION

        if name not in _mathops:
            raise ValueError(f'"{name}" is not a supported function')
        # floor/ceil are rejected when the installed numexpr predates 2.6.9.
        if (
            _NUMEXPR_INSTALLED
            and _NUMEXPR_VERSION < LooseVersion("2.6.9")
            and name in ("floor", "ceil")
        ):
            raise ValueError(f'"{name}" is not a supported function')

        self.name = name
        self.func = getattr(np, name)

    def __call__(self, *args):
        return MathCall(self, args)