Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Copyright Anne M. Archibald 2008 

2# Released under the scipy license 

3import numpy as np 

4from heapq import heappush, heappop 

5import scipy.sparse 

6 

# Public names exported by this module.
__all__ = [
    'minkowski_distance_p',
    'minkowski_distance',
    'distance_matrix',
    'Rectangle',
    'KDTree',
]

10 

11 

def minkowski_distance_p(x, y, p=2):
    """
    Compute the pth power of the L**p distance between two arrays.

    For efficiency, this function computes the L**p distance but does
    not extract the pth root. If `p` is 1 or infinity, this is equal to
    the actual L**p distance.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Returns
    -------
    ndarray
        The pth power of the distance between corresponding points,
        reduced over the last axis.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance_p
    >>> minkowski_distance_p([[0,0],[0,0]], [[1,1],[0,1]])
    array([ 2.,  1.])

    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Find smallest common datatype with float64 (return type of this
    # function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype),
                                       'float64')

    # Make sure x and y are NumPy arrays of correct datatype.
    # copy=False avoids a redundant copy when the dtype already matches.
    x = x.astype(common_datatype, copy=False)
    y = y.astype(common_datatype, copy=False)

    if p == np.inf:
        return np.amax(np.abs(y-x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y-x), axis=-1)
    else:
        return np.sum(np.abs(y-x)**p, axis=-1)

53 

54 

def minkowski_distance(x, y, p=2):
    """
    Compute the L**p distance between two arrays.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance
    >>> minkowski_distance([[0,0],[0,0]], [[1,1],[0,1]])
    array([ 1.41421356,  1.        ])

    """
    x = np.asarray(x)
    y = np.asarray(y)
    # For p = 1 or p = infinity the pth-power distance already equals the
    # true distance, so no root extraction is required.
    if p == 1 or p == np.inf:
        return minkowski_distance_p(x, y, p)
    return minkowski_distance_p(x, y, p) ** (1. / p)

81 

82 

class Rectangle(object):
    """Hyperrectangle class.

    Represents a Cartesian product of intervals.
    """

    def __init__(self, maxes, mins):
        """Construct a hyperrectangle from two opposite corners."""
        # Normalize so that self.mins <= self.maxes componentwise, no
        # matter which order the two corners were supplied in.
        self.maxes = np.maximum(maxes, mins).astype(float)
        self.mins = np.minimum(maxes, mins).astype(float)
        self.m, = self.maxes.shape

    def __repr__(self):
        return "<Rectangle %s>" % list(zip(self.mins, self.maxes))

    def volume(self):
        """Total volume."""
        return np.prod(self.maxes - self.mins)

    def split(self, d, split):
        """
        Produce two hyperrectangles by splitting.

        In general, if you need to compute maximum and minimum
        distances to the children, it can be done more efficiently
        by updating the maximum and minimum distances to the parent.

        Parameters
        ----------
        d : int
            Axis to split hyperrectangle along.
        split : float
            Position along axis `d` to split at.

        """
        # Lower half: same mins, but capped at `split` along axis d.
        upper_corner = np.copy(self.maxes)
        upper_corner[d] = split
        less = Rectangle(self.mins, upper_corner)
        # Upper half: same maxes, but floored at `split` along axis d.
        lower_corner = np.copy(self.mins)
        lower_corner[d] = split
        greater = Rectangle(lower_corner, self.maxes)
        return less, greater

    def min_distance_point(self, x, p=2.):
        """
        Return the minimum distance between input and points in the hyperrectangle.

        Parameters
        ----------
        x : array_like
            Input.
        p : float, optional
            Input.

        """
        # Per axis, distance is how far x lies outside [mins, maxes]
        # (zero if inside); combine with the chosen p-norm.
        outside = np.maximum(0, np.maximum(self.mins - x, x - self.maxes))
        return minkowski_distance(0, outside, p)

    def max_distance_point(self, x, p=2.):
        """
        Return the maximum distance between input and points in the hyperrectangle.

        Parameters
        ----------
        x : array_like
            Input array.
        p : float, optional
            Input.

        """
        # Per axis, the farthest corner is whichever bound is further
        # from x.
        farthest = np.maximum(self.maxes - x, x - self.mins)
        return minkowski_distance(0, farthest, p)

    def min_distance_rectangle(self, other, p=2.):
        """
        Compute the minimum distance between points in the two hyperrectangles.

        Parameters
        ----------
        other : hyperrectangle
            Input.
        p : float
            Input.

        """
        # Per-axis gap between the two rectangles (zero where they
        # overlap on that axis).
        gap = np.maximum(0, np.maximum(self.mins - other.maxes,
                                       other.mins - self.maxes))
        return minkowski_distance(0, gap, p)

    def max_distance_rectangle(self, other, p=2.):
        """
        Compute the maximum distance between points in the two hyperrectangles.

        Parameters
        ----------
        other : hyperrectangle
            Input.
        p : float, optional
            Input.

        """
        # Per-axis span between the most distant pair of bounds.
        span = np.maximum(self.maxes - other.mins, other.maxes - self.mins)
        return minkowski_distance(0, span, p)

180 

181 

182class KDTree(object): 

183 """ 

184 kd-tree for quick nearest-neighbor lookup 

185 

186 This class provides an index into a set of k-D points which 

187 can be used to rapidly look up the nearest neighbors of any point. 

188 

189 Parameters 

190 ---------- 

191 data : (N,K) array_like 

192 The data points to be indexed. This array is not copied, and 

193 so modifying this data will result in bogus results. 

194 leafsize : int, optional 

195 The number of points at which the algorithm switches over to 

196 brute-force. Has to be positive. 

197 

198 Raises 

199 ------ 

200 RuntimeError 

201 The maximum recursion limit can be exceeded for large data 

202 sets. If this happens, either increase the value for the `leafsize` 

203 parameter or increase the recursion limit by:: 

204 

205 >>> import sys 

206 >>> sys.setrecursionlimit(10000) 

207 

208 See Also 

209 -------- 

210 cKDTree : Implementation of `KDTree` in Cython 

211 

212 Notes 

213 ----- 

214 The algorithm used is described in Maneewongvatana and Mount 1999. 

215 The general idea is that the kd-tree is a binary tree, each of whose 

216 nodes represents an axis-aligned hyperrectangle. Each node specifies 

217 an axis and splits the set of points based on whether their coordinate 

218 along that axis is greater than or less than a particular value. 

219 

220 During construction, the axis and splitting point are chosen by the 

221 "sliding midpoint" rule, which ensures that the cells do not all 

222 become long and thin. 

223 

224 The tree can be queried for the r closest neighbors of any given point 

225 (optionally returning only those within some maximum distance of the 

226 point). It can also be queried, with a substantial gain in efficiency, 

227 for the r approximate closest neighbors. 

228 

229 For large dimensions (20 is already large) do not expect this to run 

230 significantly faster than brute force. High-dimensional nearest-neighbor 

231 queries are a substantial open problem in computer science. 

232 

233 The tree also supports all-neighbors queries, both with arrays of points 

234 and with other kd-trees. These do use a reasonably efficient algorithm, 

235 but the kd-tree is not necessarily the best data structure for this 

236 sort of calculation. 

237 

238 """ 

239 def __init__(self, data, leafsize=10): 

240 self.data = np.asarray(data) 

241 self.n, self.m = np.shape(self.data) 

242 self.leafsize = int(leafsize) 

243 if self.leafsize < 1: 

244 raise ValueError("leafsize must be at least 1") 

245 self.maxes = np.amax(self.data,axis=0) 

246 self.mins = np.amin(self.data,axis=0) 

247 

248 self.tree = self.__build(np.arange(self.n), self.maxes, self.mins) 

249 

250 class node(object): 

251 def __lt__(self, other): 

252 return id(self) < id(other) 

253 

254 def __gt__(self, other): 

255 return id(self) > id(other) 

256 

257 def __le__(self, other): 

258 return id(self) <= id(other) 

259 

260 def __ge__(self, other): 

261 return id(self) >= id(other) 

262 

263 def __eq__(self, other): 

264 return id(self) == id(other) 

265 

266 class leafnode(node): 

267 def __init__(self, idx): 

268 self.idx = idx 

269 self.children = len(idx) 

270 

271 class innernode(node): 

272 def __init__(self, split_dim, split, less, greater): 

273 self.split_dim = split_dim 

274 self.split = split 

275 self.less = less 

276 self.greater = greater 

277 self.children = less.children+greater.children 

278 

279 def __build(self, idx, maxes, mins): 

280 if len(idx) <= self.leafsize: 

281 return KDTree.leafnode(idx) 

282 else: 

283 data = self.data[idx] 

284 # maxes = np.amax(data,axis=0) 

285 # mins = np.amin(data,axis=0) 

286 d = np.argmax(maxes-mins) 

287 maxval = maxes[d] 

288 minval = mins[d] 

289 if maxval == minval: 

290 # all points are identical; warn user? 

291 return KDTree.leafnode(idx) 

292 data = data[:,d] 

293 

294 # sliding midpoint rule; see Maneewongvatana and Mount 1999 

295 # for arguments that this is a good idea. 

296 split = (maxval+minval)/2 

297 less_idx = np.nonzero(data <= split)[0] 

298 greater_idx = np.nonzero(data > split)[0] 

299 if len(less_idx) == 0: 

300 split = np.amin(data) 

301 less_idx = np.nonzero(data <= split)[0] 

302 greater_idx = np.nonzero(data > split)[0] 

303 if len(greater_idx) == 0: 

304 split = np.amax(data) 

305 less_idx = np.nonzero(data < split)[0] 

306 greater_idx = np.nonzero(data >= split)[0] 

307 if len(less_idx) == 0: 

308 # _still_ zero? all must have the same value 

309 if not np.all(data == data[0]): 

310 raise ValueError("Troublesome data array: %s" % data) 

311 split = data[0] 

312 less_idx = np.arange(len(data)-1) 

313 greater_idx = np.array([len(data)-1]) 

314 

315 lessmaxes = np.copy(maxes) 

316 lessmaxes[d] = split 

317 greatermins = np.copy(mins) 

318 greatermins[d] = split 

319 return KDTree.innernode(d, split, 

320 self.__build(idx[less_idx],lessmaxes,mins), 

321 self.__build(idx[greater_idx],maxes,greatermins)) 

322 

323 def __query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf): 

324 

325 side_distances = np.maximum(0,np.maximum(x-self.maxes,self.mins-x)) 

326 if p != np.inf: 

327 side_distances **= p 

328 min_distance = np.sum(side_distances) 

329 else: 

330 min_distance = np.amax(side_distances) 

331 

332 # priority queue for chasing nodes 

333 # entries are: 

334 # minimum distance between the cell and the target 

335 # distances between the nearest side of the cell and the target 

336 # the head node of the cell 

337 q = [(min_distance, 

338 tuple(side_distances), 

339 self.tree)] 

340 # priority queue for the nearest neighbors 

341 # furthest known neighbor first 

342 # entries are (-distance**p, i) 

343 neighbors = [] 

344 

345 if eps == 0: 

346 epsfac = 1 

347 elif p == np.inf: 

348 epsfac = 1/(1+eps) 

349 else: 

350 epsfac = 1/(1+eps)**p 

351 

352 if p != np.inf and distance_upper_bound != np.inf: 

353 distance_upper_bound = distance_upper_bound**p 

354 

355 while q: 

356 min_distance, side_distances, node = heappop(q) 

357 if isinstance(node, KDTree.leafnode): 

358 # brute-force 

359 data = self.data[node.idx] 

360 ds = minkowski_distance_p(data,x[np.newaxis,:],p) 

361 for i in range(len(ds)): 

362 if ds[i] < distance_upper_bound: 

363 if len(neighbors) == k: 

364 heappop(neighbors) 

365 heappush(neighbors, (-ds[i], node.idx[i])) 

366 if len(neighbors) == k: 

367 distance_upper_bound = -neighbors[0][0] 

368 else: 

369 # we don't push cells that are too far onto the queue at all, 

370 # but since the distance_upper_bound decreases, we might get 

371 # here even if the cell's too far 

372 if min_distance > distance_upper_bound*epsfac: 

373 # since this is the nearest cell, we're done, bail out 

374 break 

375 # compute minimum distances to the children and push them on 

376 if x[node.split_dim] < node.split: 

377 near, far = node.less, node.greater 

378 else: 

379 near, far = node.greater, node.less 

380 

381 # near child is at the same distance as the current node 

382 heappush(q,(min_distance, side_distances, near)) 

383 

384 # far child is further by an amount depending only 

385 # on the split value 

386 sd = list(side_distances) 

387 if p == np.inf: 

388 min_distance = max(min_distance, abs(node.split-x[node.split_dim])) 

389 elif p == 1: 

390 sd[node.split_dim] = np.abs(node.split-x[node.split_dim]) 

391 min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim] 

392 else: 

393 sd[node.split_dim] = np.abs(node.split-x[node.split_dim])**p 

394 min_distance = min_distance - side_distances[node.split_dim] + sd[node.split_dim] 

395 

396 # far child might be too far, if so, don't bother pushing it 

397 if min_distance <= distance_upper_bound*epsfac: 

398 heappush(q,(min_distance, tuple(sd), far)) 

399 

400 if p == np.inf: 

401 return sorted([(-d,i) for (d,i) in neighbors]) 

402 else: 

403 return sorted([((-d)**(1./p),i) for (d,i) in neighbors]) 

404 

405 def query(self, x, k=1, eps=0, p=2, distance_upper_bound=np.inf): 

406 """ 

407 Query the kd-tree for nearest neighbors 

408 

409 Parameters 

410 ---------- 

411 x : array_like, last dimension self.m 

412 An array of points to query. 

413 k : int, optional 

414 The number of nearest neighbors to return. 

415 eps : nonnegative float, optional 

416 Return approximate nearest neighbors; the kth returned value 

417 is guaranteed to be no further than (1+eps) times the 

418 distance to the real kth nearest neighbor. 

419 p : float, 1<=p<=infinity, optional 

420 Which Minkowski p-norm to use. 

421 1 is the sum-of-absolute-values "Manhattan" distance 

422 2 is the usual Euclidean distance 

423 infinity is the maximum-coordinate-difference distance 

424 distance_upper_bound : nonnegative float, optional 

425 Return only neighbors within this distance. This is used to prune 

426 tree searches, so if you are doing a series of nearest-neighbor 

427 queries, it may help to supply the distance to the nearest neighbor 

428 of the most recent point. 

429 

430 Returns 

431 ------- 

432 d : float or array of floats 

433 The distances to the nearest neighbors. 

434 If x has shape tuple+(self.m,), then d has shape tuple if 

435 k is one, or tuple+(k,) if k is larger than one. Missing 

436 neighbors (e.g. when k > n or distance_upper_bound is 

437 given) are indicated with infinite distances. If k is None, 

438 then d is an object array of shape tuple, containing lists 

439 of distances. In either case the hits are sorted by distance 

440 (nearest first). 

441 i : integer or array of integers 

442 The locations of the neighbors in self.data. i is the same 

443 shape as d. 

444 

445 Examples 

446 -------- 

447 >>> from scipy import spatial 

448 >>> x, y = np.mgrid[0:5, 2:8] 

449 >>> tree = spatial.KDTree(list(zip(x.ravel(), y.ravel()))) 

450 >>> tree.data 

451 array([[0, 2], 

452 [0, 3], 

453 [0, 4], 

454 [0, 5], 

455 [0, 6], 

456 [0, 7], 

457 [1, 2], 

458 [1, 3], 

459 [1, 4], 

460 [1, 5], 

461 [1, 6], 

462 [1, 7], 

463 [2, 2], 

464 [2, 3], 

465 [2, 4], 

466 [2, 5], 

467 [2, 6], 

468 [2, 7], 

469 [3, 2], 

470 [3, 3], 

471 [3, 4], 

472 [3, 5], 

473 [3, 6], 

474 [3, 7], 

475 [4, 2], 

476 [4, 3], 

477 [4, 4], 

478 [4, 5], 

479 [4, 6], 

480 [4, 7]]) 

481 >>> pts = np.array([[0, 0], [2.1, 2.9]]) 

482 >>> tree.query(pts) 

483 (array([ 2. , 0.14142136]), array([ 0, 13])) 

484 >>> tree.query(pts[0]) 

485 (2.0, 0) 

486 

487 """ 

488 x = np.asarray(x) 

489 if np.shape(x)[-1] != self.m: 

490 raise ValueError("x must consist of vectors of length %d but has shape %s" % (self.m, np.shape(x))) 

491 if p < 1: 

492 raise ValueError("Only p-norms with 1<=p<=infinity permitted") 

493 retshape = np.shape(x)[:-1] 

494 if retshape != (): 

495 if k is None: 

496 dd = np.empty(retshape,dtype=object) 

497 ii = np.empty(retshape,dtype=object) 

498 elif k > 1: 

499 dd = np.empty(retshape+(k,),dtype=float) 

500 dd.fill(np.inf) 

501 ii = np.empty(retshape+(k,),dtype=int) 

502 ii.fill(self.n) 

503 elif k == 1: 

504 dd = np.empty(retshape,dtype=float) 

505 dd.fill(np.inf) 

506 ii = np.empty(retshape,dtype=int) 

507 ii.fill(self.n) 

508 else: 

509 raise ValueError("Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None") 

510 for c in np.ndindex(retshape): 

511 hits = self.__query(x[c], k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound) 

512 if k is None: 

513 dd[c] = [d for (d,i) in hits] 

514 ii[c] = [i for (d,i) in hits] 

515 elif k > 1: 

516 for j in range(len(hits)): 

517 dd[c+(j,)], ii[c+(j,)] = hits[j] 

518 elif k == 1: 

519 if len(hits) > 0: 

520 dd[c], ii[c] = hits[0] 

521 else: 

522 dd[c] = np.inf 

523 ii[c] = self.n 

524 return dd, ii 

525 else: 

526 hits = self.__query(x, k=k, eps=eps, p=p, distance_upper_bound=distance_upper_bound) 

527 if k is None: 

528 return [d for (d,i) in hits], [i for (d,i) in hits] 

529 elif k == 1: 

530 if len(hits) > 0: 

531 return hits[0] 

532 else: 

533 return np.inf, self.n 

534 elif k > 1: 

535 dd = np.empty(k,dtype=float) 

536 dd.fill(np.inf) 

537 ii = np.empty(k,dtype=int) 

538 ii.fill(self.n) 

539 for j in range(len(hits)): 

540 dd[j], ii[j] = hits[j] 

541 return dd, ii 

542 else: 

543 raise ValueError("Requested %s nearest neighbors; acceptable numbers are integers greater than or equal to one, or None") 

544 

545 def __query_ball_point(self, x, r, p=2., eps=0): 

546 R = Rectangle(self.maxes, self.mins) 

547 

548 def traverse_checking(node, rect): 

549 if rect.min_distance_point(x, p) > r / (1. + eps): 

550 return [] 

551 elif rect.max_distance_point(x, p) < r * (1. + eps): 

552 return traverse_no_checking(node) 

553 elif isinstance(node, KDTree.leafnode): 

554 d = self.data[node.idx] 

555 return node.idx[minkowski_distance(d, x, p) <= r].tolist() 

556 else: 

557 less, greater = rect.split(node.split_dim, node.split) 

558 return traverse_checking(node.less, less) + \ 

559 traverse_checking(node.greater, greater) 

560 

561 def traverse_no_checking(node): 

562 if isinstance(node, KDTree.leafnode): 

563 return node.idx.tolist() 

564 else: 

565 return traverse_no_checking(node.less) + \ 

566 traverse_no_checking(node.greater) 

567 

568 return traverse_checking(self.tree, R) 

569 

570 def query_ball_point(self, x, r, p=2., eps=0): 

571 """Find all points within distance r of point(s) x. 

572 

573 Parameters 

574 ---------- 

575 x : array_like, shape tuple + (self.m,) 

576 The point or points to search for neighbors of. 

577 r : positive float 

578 The radius of points to return. 

579 p : float, optional 

580 Which Minkowski p-norm to use. Should be in the range [1, inf]. 

581 eps : nonnegative float, optional 

582 Approximate search. Branches of the tree are not explored if their 

583 nearest points are further than ``r / (1 + eps)``, and branches are 

584 added in bulk if their furthest points are nearer than 

585 ``r * (1 + eps)``. 

586 

587 Returns 

588 ------- 

589 results : list or array of lists 

590 If `x` is a single point, returns a list of the indices of the 

591 neighbors of `x`. If `x` is an array of points, returns an object 

592 array of shape tuple containing lists of neighbors. 

593 

594 Notes 

595 ----- 

596 If you have many points whose neighbors you want to find, you may save 

597 substantial amounts of time by putting them in a KDTree and using 

598 query_ball_tree. 

599 

600 Examples 

601 -------- 

602 >>> from scipy import spatial 

603 >>> x, y = np.mgrid[0:5, 0:5] 

604 >>> points = np.c_[x.ravel(), y.ravel()] 

605 >>> tree = spatial.KDTree(points) 

606 >>> tree.query_ball_point([2, 0], 1) 

607 [5, 10, 11, 15] 

608 

609 Query multiple points and plot the results: 

610 

611 >>> import matplotlib.pyplot as plt 

612 >>> points = np.asarray(points) 

613 >>> plt.plot(points[:,0], points[:,1], '.') 

614 >>> for results in tree.query_ball_point(([2, 0], [3, 3]), 1): 

615 ... nearby_points = points[results] 

616 ... plt.plot(nearby_points[:,0], nearby_points[:,1], 'o') 

617 >>> plt.margins(0.1, 0.1) 

618 >>> plt.show() 

619 

620 """ 

621 x = np.asarray(x) 

622 if x.shape[-1] != self.m: 

623 raise ValueError("Searching for a %d-dimensional point in a " 

624 "%d-dimensional KDTree" % (x.shape[-1], self.m)) 

625 if len(x.shape) == 1: 

626 return self.__query_ball_point(x, r, p, eps) 

627 else: 

628 retshape = x.shape[:-1] 

629 result = np.empty(retshape, dtype=object) 

630 for c in np.ndindex(retshape): 

631 result[c] = self.__query_ball_point(x[c], r, p=p, eps=eps) 

632 return result 

633 

634 def query_ball_tree(self, other, r, p=2., eps=0): 

635 """Find all pairs of points whose distance is at most r 

636 

637 Parameters 

638 ---------- 

639 other : KDTree instance 

640 The tree containing points to search against. 

641 r : float 

642 The maximum distance, has to be positive. 

643 p : float, optional 

644 Which Minkowski norm to use. `p` has to meet the condition 

645 ``1 <= p <= infinity``. 

646 eps : float, optional 

647 Approximate search. Branches of the tree are not explored 

648 if their nearest points are further than ``r/(1+eps)``, and 

649 branches are added in bulk if their furthest points are nearer 

650 than ``r * (1+eps)``. `eps` has to be non-negative. 

651 

652 Returns 

653 ------- 

654 results : list of lists 

655 For each element ``self.data[i]`` of this tree, ``results[i]`` is a 

656 list of the indices of its neighbors in ``other.data``. 

657 

658 """ 

659 results = [[] for i in range(self.n)] 

660 

661 def traverse_checking(node1, rect1, node2, rect2): 

662 if rect1.min_distance_rectangle(rect2, p) > r/(1.+eps): 

663 return 

664 elif rect1.max_distance_rectangle(rect2, p) < r*(1.+eps): 

665 traverse_no_checking(node1, node2) 

666 elif isinstance(node1, KDTree.leafnode): 

667 if isinstance(node2, KDTree.leafnode): 

668 d = other.data[node2.idx] 

669 for i in node1.idx: 

670 results[i] += node2.idx[minkowski_distance(d,self.data[i],p) <= r].tolist() 

671 else: 

672 less, greater = rect2.split(node2.split_dim, node2.split) 

673 traverse_checking(node1,rect1,node2.less,less) 

674 traverse_checking(node1,rect1,node2.greater,greater) 

675 elif isinstance(node2, KDTree.leafnode): 

676 less, greater = rect1.split(node1.split_dim, node1.split) 

677 traverse_checking(node1.less,less,node2,rect2) 

678 traverse_checking(node1.greater,greater,node2,rect2) 

679 else: 

680 less1, greater1 = rect1.split(node1.split_dim, node1.split) 

681 less2, greater2 = rect2.split(node2.split_dim, node2.split) 

682 traverse_checking(node1.less,less1,node2.less,less2) 

683 traverse_checking(node1.less,less1,node2.greater,greater2) 

684 traverse_checking(node1.greater,greater1,node2.less,less2) 

685 traverse_checking(node1.greater,greater1,node2.greater,greater2) 

686 

687 def traverse_no_checking(node1, node2): 

688 if isinstance(node1, KDTree.leafnode): 

689 if isinstance(node2, KDTree.leafnode): 

690 for i in node1.idx: 

691 results[i] += node2.idx.tolist() 

692 else: 

693 traverse_no_checking(node1, node2.less) 

694 traverse_no_checking(node1, node2.greater) 

695 else: 

696 traverse_no_checking(node1.less, node2) 

697 traverse_no_checking(node1.greater, node2) 

698 

699 traverse_checking(self.tree, Rectangle(self.maxes, self.mins), 

700 other.tree, Rectangle(other.maxes, other.mins)) 

701 return results 

702 

703 def query_pairs(self, r, p=2., eps=0): 

704 """ 

705 Find all pairs of points within a distance. 

706 

707 Parameters 

708 ---------- 

709 r : positive float 

710 The maximum distance. 

711 p : float, optional 

712 Which Minkowski norm to use. `p` has to meet the condition 

713 ``1 <= p <= infinity``. 

714 eps : float, optional 

715 Approximate search. Branches of the tree are not explored 

716 if their nearest points are further than ``r/(1+eps)``, and 

717 branches are added in bulk if their furthest points are nearer 

718 than ``r * (1+eps)``. `eps` has to be non-negative. 

719 

720 Returns 

721 ------- 

722 results : set 

723 Set of pairs ``(i,j)``, with ``i < j``, for which the corresponding 

724 positions are close. 

725 

726 """ 

727 results = set() 

728 

729 def traverse_checking(node1, rect1, node2, rect2): 

730 if rect1.min_distance_rectangle(rect2, p) > r/(1.+eps): 

731 return 

732 elif rect1.max_distance_rectangle(rect2, p) < r*(1.+eps): 

733 traverse_no_checking(node1, node2) 

734 elif isinstance(node1, KDTree.leafnode): 

735 if isinstance(node2, KDTree.leafnode): 

736 # Special care to avoid duplicate pairs 

737 if id(node1) == id(node2): 

738 d = self.data[node2.idx] 

739 for i in node1.idx: 

740 for j in node2.idx[minkowski_distance(d,self.data[i],p) <= r]: 

741 if i < j: 

742 results.add((i,j)) 

743 else: 

744 d = self.data[node2.idx] 

745 for i in node1.idx: 

746 for j in node2.idx[minkowski_distance(d,self.data[i],p) <= r]: 

747 if i < j: 

748 results.add((i,j)) 

749 elif j < i: 

750 results.add((j,i)) 

751 else: 

752 less, greater = rect2.split(node2.split_dim, node2.split) 

753 traverse_checking(node1,rect1,node2.less,less) 

754 traverse_checking(node1,rect1,node2.greater,greater) 

755 elif isinstance(node2, KDTree.leafnode): 

756 less, greater = rect1.split(node1.split_dim, node1.split) 

757 traverse_checking(node1.less,less,node2,rect2) 

758 traverse_checking(node1.greater,greater,node2,rect2) 

759 else: 

760 less1, greater1 = rect1.split(node1.split_dim, node1.split) 

761 less2, greater2 = rect2.split(node2.split_dim, node2.split) 

762 traverse_checking(node1.less,less1,node2.less,less2) 

763 traverse_checking(node1.less,less1,node2.greater,greater2) 

764 

765 # Avoid traversing (node1.less, node2.greater) and 

766 # (node1.greater, node2.less) (it's the same node pair twice 

767 # over, which is the source of the complication in the 

768 # original KDTree.query_pairs) 

769 if id(node1) != id(node2): 

770 traverse_checking(node1.greater,greater1,node2.less,less2) 

771 

772 traverse_checking(node1.greater,greater1,node2.greater,greater2) 

773 

774 def traverse_no_checking(node1, node2): 

775 if isinstance(node1, KDTree.leafnode): 

776 if isinstance(node2, KDTree.leafnode): 

777 # Special care to avoid duplicate pairs 

778 if id(node1) == id(node2): 

779 for i in node1.idx: 

780 for j in node2.idx: 

781 if i < j: 

782 results.add((i,j)) 

783 else: 

784 for i in node1.idx: 

785 for j in node2.idx: 

786 if i < j: 

787 results.add((i,j)) 

788 elif j < i: 

789 results.add((j,i)) 

790 else: 

791 traverse_no_checking(node1, node2.less) 

792 traverse_no_checking(node1, node2.greater) 

793 else: 

794 # Avoid traversing (node1.less, node2.greater) and 

795 # (node1.greater, node2.less) (it's the same node pair twice 

796 # over, which is the source of the complication in the 

797 # original KDTree.query_pairs) 

798 if id(node1) == id(node2): 

799 traverse_no_checking(node1.less, node2.less) 

800 traverse_no_checking(node1.less, node2.greater) 

801 traverse_no_checking(node1.greater, node2.greater) 

802 else: 

803 traverse_no_checking(node1.less, node2) 

804 traverse_no_checking(node1.greater, node2) 

805 

806 traverse_checking(self.tree, Rectangle(self.maxes, self.mins), 

807 self.tree, Rectangle(self.maxes, self.mins)) 

808 return results 

809 

810 def count_neighbors(self, other, r, p=2.): 

811 """ 

812 Count how many nearby pairs can be formed. 

813 

814 Count the number of pairs (x1,x2) can be formed, with x1 drawn 

815 from self and x2 drawn from ``other``, and where 

816 ``distance(x1, x2, p) <= r``. 

817 This is the "two-point correlation" described in Gray and Moore 2000, 

818 "N-body problems in statistical learning", and the code here is based 

819 on their algorithm. 

820 

821 Parameters 

822 ---------- 

823 other : KDTree instance 

824 The other tree to draw points from. 

825 r : float or one-dimensional array of floats 

826 The radius to produce a count for. Multiple radii are searched with 

827 a single tree traversal. 

828 p : float, 1<=p<=infinity, optional 

829 Which Minkowski p-norm to use 

830 

831 Returns 

832 ------- 

833 result : int or 1-D array of ints 

834 The number of pairs. Note that this is internally stored in a numpy 

835 int, and so may overflow if very large (2e9). 

836 

837 """ 

838 def traverse(node1, rect1, node2, rect2, idx): 

839 min_r = rect1.min_distance_rectangle(rect2,p) 

840 max_r = rect1.max_distance_rectangle(rect2,p) 

841 c_greater = r[idx] > max_r 

842 result[idx[c_greater]] += node1.children*node2.children 

843 idx = idx[(min_r <= r[idx]) & (r[idx] <= max_r)] 

844 if len(idx) == 0: 

845 return 

846 

847 if isinstance(node1,KDTree.leafnode): 

848 if isinstance(node2,KDTree.leafnode): 

849 ds = minkowski_distance(self.data[node1.idx][:,np.newaxis,:], 

850 other.data[node2.idx][np.newaxis,:,:], 

851 p).ravel() 

852 ds.sort() 

853 result[idx] += np.searchsorted(ds,r[idx],side='right') 

854 else: 

855 less, greater = rect2.split(node2.split_dim, node2.split) 

856 traverse(node1, rect1, node2.less, less, idx) 

857 traverse(node1, rect1, node2.greater, greater, idx) 

858 else: 

859 if isinstance(node2,KDTree.leafnode): 

860 less, greater = rect1.split(node1.split_dim, node1.split) 

861 traverse(node1.less, less, node2, rect2, idx) 

862 traverse(node1.greater, greater, node2, rect2, idx) 

863 else: 

864 less1, greater1 = rect1.split(node1.split_dim, node1.split) 

865 less2, greater2 = rect2.split(node2.split_dim, node2.split) 

866 traverse(node1.less,less1,node2.less,less2,idx) 

867 traverse(node1.less,less1,node2.greater,greater2,idx) 

868 traverse(node1.greater,greater1,node2.less,less2,idx) 

869 traverse(node1.greater,greater1,node2.greater,greater2,idx) 

870 

871 R1 = Rectangle(self.maxes, self.mins) 

872 R2 = Rectangle(other.maxes, other.mins) 

873 if np.shape(r) == (): 

874 r = np.array([r]) 

875 result = np.zeros(1,dtype=int) 

876 traverse(self.tree, R1, other.tree, R2, np.arange(1)) 

877 return result[0] 

878 elif len(np.shape(r)) == 1: 

879 r = np.asarray(r) 

880 n, = r.shape 

881 result = np.zeros(n,dtype=int) 

882 traverse(self.tree, R1, other.tree, R2, np.arange(n)) 

883 return result 

884 else: 

885 raise ValueError("r must be either a single value or a one-dimensional array of values") 

886 

887 def sparse_distance_matrix(self, other, max_distance, p=2.): 

888 """ 

889 Compute a sparse distance matrix 

890 

891 Computes a distance matrix between two KDTrees, leaving as zero 

892 any distance greater than max_distance. 

893 

894 Parameters 

895 ---------- 

896 other : KDTree 

897 

898 max_distance : positive float 

899 

900 p : float, optional 

901 

902 Returns 

903 ------- 

904 result : dok_matrix 

905 Sparse matrix representing the results in "dictionary of keys" format. 

906 

907 """ 

908 result = scipy.sparse.dok_matrix((self.n,other.n)) 

909 

910 def traverse(node1, rect1, node2, rect2): 

911 if rect1.min_distance_rectangle(rect2, p) > max_distance: 

912 return 

913 elif isinstance(node1, KDTree.leafnode): 

914 if isinstance(node2, KDTree.leafnode): 

915 for i in node1.idx: 

916 for j in node2.idx: 

917 d = minkowski_distance(self.data[i],other.data[j],p) 

918 if d <= max_distance: 

919 result[i,j] = d 

920 else: 

921 less, greater = rect2.split(node2.split_dim, node2.split) 

922 traverse(node1,rect1,node2.less,less) 

923 traverse(node1,rect1,node2.greater,greater) 

924 elif isinstance(node2, KDTree.leafnode): 

925 less, greater = rect1.split(node1.split_dim, node1.split) 

926 traverse(node1.less,less,node2,rect2) 

927 traverse(node1.greater,greater,node2,rect2) 

928 else: 

929 less1, greater1 = rect1.split(node1.split_dim, node1.split) 

930 less2, greater2 = rect2.split(node2.split_dim, node2.split) 

931 traverse(node1.less,less1,node2.less,less2) 

932 traverse(node1.less,less1,node2.greater,greater2) 

933 traverse(node1.greater,greater1,node2.less,less2) 

934 traverse(node1.greater,greater1,node2.greater,greater2) 

935 traverse(self.tree, Rectangle(self.maxes, self.mins), 

936 other.tree, Rectangle(other.maxes, other.mins)) 

937 

938 return result 

939 

940 

def distance_matrix(x, y, p=2, threshold=1000000):
    """
    Compute the distance matrix.

    Returns the matrix of all pair-wise distances.

    Parameters
    ----------
    x : (M, K) array_like
        Matrix of M vectors in K dimensions.
    y : (N, K) array_like
        Matrix of N vectors in K dimensions.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.
    threshold : positive int
        If ``M * N * K`` > `threshold`, algorithm uses a Python loop instead
        of large temporary arrays.

    Returns
    -------
    result : (M, N) ndarray
        Matrix containing the distance from every vector in `x` to every vector
        in `y`.

    Examples
    --------
    >>> from scipy.spatial import distance_matrix
    >>> distance_matrix([[0,0],[0,1]], [[1,0],[1,1]])
    array([[ 1.        ,  1.41421356],
           [ 1.41421356,  1.        ]])

    """

    x = np.asarray(x)
    y = np.asarray(y)
    m, k = x.shape
    n, kk = y.shape

    if k != kk:
        raise ValueError("x contains %d-dimensional vectors but y contains %d-dimensional vectors" % (k, kk))

    if m*n*k <= threshold:
        # Small enough: one broadcasted call builds the full (M, N, K)
        # temporary and reduces it in a single pass.
        return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)

    # Too large for one temporary array: fill the result row-by-row or
    # column-by-column, looping over whichever axis is shorter.
    result = np.empty((m,n),dtype=float)  # FIXME: figure out the best dtype
    if m < n:
        for row in range(m):
            result[row,:] = minkowski_distance(x[row],y,p)
    else:
        for col in range(n):
            result[:,col] = minkowski_distance(x,y[col],p)
    return result