import math
import pathlib
import itertools
import collections

from lingpy.sequence.sound_classes import class2tokens
from lingpy.settings import rc
from lingpy.align.sca import get_consensus, Alignments
from lingpy.util import pb
from lingpy import log
from lingpy import basictypes as bt

import networkx as nx

def consensus_pattern(patterns, missing="Ø"):
    """
    Return the consensus pattern of multiple patterns.

    :param patterns: list of patterns
    :param missing: the character used to represent missing values

    .. note:: This consensus method raises an error if the patterns contain
        incompatible columns (non-identical values apart from the missing data
        character in the same column).
    """
    out = []
    for i in range(len(patterns[0])):
        col = [line[i] for line in patterns]
        no_gaps = [x for x in col if x != missing]
        if len(set(no_gaps)) > 1:
            raise ValueError("Your patterns are incompatible")
        out += [no_gaps[0] if no_gaps else missing]
    return tuple(out)
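
# A hedged, minimal illustration (not part of the original module) of how
# consensus_pattern merges compatible columns; "Ø" marks missing data.
def _demo_consensus_pattern():
    # the two patterns disagree only where one of them has missing data
    assert consensus_pattern([("a", "Ø", "c"), ("a", "b", "Ø")]) == ("a", "b", "c")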

def incompatible_columns(patterns, missing="Ø"):
    """
    Mark the incompatible columns of a list of patterns.

    Returns a list with one entry per column, containing "*" if the column
    holds conflicting non-missing values and an empty string otherwise.
    """
    columns = []
    for i in range(len(patterns[0])):
        col = [
            patterns[j][i] for j in range(len(patterns)) if patterns[j][i] != missing
        ]
        columns.append("*" if len(set(col)) > 1 else "")
    return columns
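
# A hedged, minimal illustration: the second column below holds two
# conflicting non-missing values and is therefore marked with "*".
def _demo_incompatible_columns():
    assert incompatible_columns([("a", "b"), ("a", "c")]) == ["", "*"]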

def score_patterns(patterns, missing="Ø", mode="coverage"):
    """
    Give a score for the overall number of reflexes covered by a set of patterns.

    :param patterns: list of patterns
    :param missing: the character used to represent missing values
    :param mode: scoring mode, one of "ranked", "pairs", "squared", or "coverage"

    .. note:: In "coverage" mode, the score simply indicates to which degree
        the patterns are filled: it divides the number of cells not containing
        missing data by the number of cells in the matrix.
    """
    # return -1 if the patterns are not compatible
    for i in range(len(patterns[0])):
        if len(set([row[i] for row in patterns if row[i] != missing])) > 1:
            return -1
    if len(patterns) <= 1:
        return -1

    if mode not in ["ranked", "pairs", "squared", "coverage"]:
        raise ValueError("you must select an appropriate mode")

    # we rank the columns by sorting them first
    if mode == "ranked":
        cols = []
        for i in range(len(patterns[0])):
            cols += [sum([0 if row[i] == missing else 1 for row in patterns])]
        # sort the columns
        ranks, cols = list(range(1, len(cols) + 1))[::-1], sorted(cols, reverse=True)
        scores = []
        for rank, col in zip(ranks, cols):
            scores += [rank * col]
        return sum(scores) / sum(ranks) / len(patterns)

    if mode == "squared":
        psize = len(patterns[0])
        scores = [((psize - row.count(missing)) / psize) ** 2 for row in patterns]
        return sum(scores) / len(scores)

    if mode == "pairs":
        # count the number of pairs in the data
        pairs = 0
        covered = 0
        m, n = len(patterns[0]), len(patterns)
        for i in range(n):
            vals = m - patterns[i].count(missing)
            pairs += (vals**2 - vals) / 2
        for i in range(m):
            vals = n - [p[i] for p in patterns].count(missing)
            pairs += (vals**2 - vals) / 2
            if vals != 0:
                covered += 1
        return ((pairs / n) / covered) / m

    if mode == "coverage":
        cols = []
        for i in range(len(patterns[0])):
            col = [row[i] for row in patterns]
            cols += [len(patterns) - col.count(missing)]
        return (sum(cols) / len(patterns[0])) / len(patterns)
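
# A hedged, minimal illustration of the default "coverage" mode: five of the
# six cells below are filled, so the score is 5/6.
def _demo_score_patterns():
    score = score_patterns([("a", "b", "Ø"), ("a", "b", "c")], mode="coverage")
    assert abs(score - 5 / 6) < 1e-9
    # incompatible patterns always score -1
    assert score_patterns([("a",), ("b",)]) == -1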

def compatible_columns(colA, colB, missing="Ø", gap="-"):
    """Check for column compatibility.

    Parameters
    ----------
    colA, colB : list
        Lists (sequence type) containing a given pattern.
    missing : str (default="Ø")
        A gap in the sense of "missing data", that is, a cognate set for which
        a value in a given language is absent.
    gap : str (default="-")
        The character used to represent gaps in the alignment.

    Returns
    -------
    matches, mismatches : tuple
        The match score is zero if there is no conflict but also no match; the
        mismatch score behaves accordingly. So if you check for compatibility,
        a mismatch count greater than 0 means that the patterns are not
        compatible.
    """
    matches, mismatches = 0, 0
    for a, b in zip(colA, colB):
        if missing not in [a, b]:
            if a != b:
                mismatches += 1
            else:
                if a != gap:
                    matches += 1
    return matches, mismatches
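
# A hedged, minimal illustration: shared non-missing values count as matches,
# positions with missing data are ignored, and conflicting values count as
# mismatches (rendering two columns incompatible).
def _demo_compatible_columns():
    assert compatible_columns(("a", "Ø", "c"), ("a", "b", "Ø")) == (1, 0)
    assert compatible_columns(("a", "b"), ("a", "c")) == (1, 1)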

def density(wordlist, ref="cogid"):
    """Compute the density of a wordlist.

    Note
    ----
    We define the density of a wordlist by measuring how many words can be
    explained by the same cognate set.
    """
    scores = []
    for concept in wordlist.rows:
        idxs = wordlist.get_list(row=concept, flat=True)
        cogids = [wordlist[idx, ref] for idx in idxs]
        sums = [1 / cogids.count(cogid) for cogid in cogids]
        scores.append(sum(sums) / len(sums))
    return 1 - sum(scores) / len(scores)
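
# A hedged usage sketch (the file name "wordlist.tsv" and its "cogid" column
# are illustrative assumptions): a density close to 0 indicates that most
# concepts are coded as isolated cognate sets, while values closer to 1
# indicate that many words per concept share the same cognate set.
def _demo_density():
    from lingpy import Wordlist
    wordlist = Wordlist("wordlist.tsv")
    print("density: {0:.2f}".format(density(wordlist, ref="cogid")))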

class CoPaR(Alignments):
    """Correspondence Pattern Recognition class

    Parameters
    ----------
    wordlist : ~lingpy.basic.wordlist.Wordlist
        A wordlist object which should have a column for segments and a column
        for cognate sets. Since the class inherits from LingPy's
        Alignments class, the same kind of data should be submitted.
    ref : str (default="cogids")
        The column which stores the cognate sets.
    segments : str (default="tokens")
        The column which stores the segmented transcriptions.
    alignment : str (default="alignment")
        The column which stores the alignments (or will store the alignments if
        they have not yet been computed).

    Note
    ----
    This method was first introduced in List (2019).

    > List, J.-M. (2019): Automatic inference of sound correspondence patterns
    > across multiple languages. Computational Linguistics 45.1. 137-161. DOI:
    > http://doi.org/10.1162/coli_a_00344
    """

    def __init__(
        self,
        wordlist,
        minrefs=3,
        ref="cogids",
        structure="structure",
        missing="Ø",
        gap="-",
        irregular="!?",
        **keywords
    ):
        Alignments.__init__(self, wordlist, ref=ref, **keywords)
        self.ref = ref
        self._structure = structure
        self.minrefs = minrefs
        self.missing = missing
        self.gap = gap
        self.irregular = irregular
        if structure not in self.columns:
            raise ValueError("no column {0} for structure was found".format(structure))

    def positions_from_prostrings(self, cogid, indices, alignment, structures):
        """
        Return matching positions from an alignment and user-defined prosodic strings.
        """
        if self._mode == "fuzzy":
            strucs = []
            for idx, struc, alm in zip(indices, structures, alignment):
                pos_ = self[idx, self._ref].index(cogid)
                strucs += [class2tokens(struc.n[pos_], alm)]
        else:
            strucs = [
                class2tokens(struc, alm) for struc, alm in zip(structures, alignment)
            ]
        get_consensus(alignment, gaps=True)
        prostring = []
        for i in range(len(strucs[0])):
            row = [x[i] for x in strucs if x[i] != "-"]
            prostring += [row[0] if row else "+"]
        return [(i, p) for i, p in enumerate(prostring)]

    def reflexes_from_pos(
        self, position, taxa, current_taxa, alignment, missing, irregular
    ):
        """
        Return the reflexes of all taxa at a given alignment position. Missing
        taxa and irregular reflexes are represented by the missing marker;
        annotated reflexes of the form "x/y" are resolved to "y".
        """
        reflexes = []
        for t in taxa:
            if t not in current_taxa:
                reflexes += [missing]
            else:
                reflex = alignment[current_taxa.index(t)][position]
                if "/" in reflex:
                    reflex = reflex.split("/")[1]
                elif reflex[0] in irregular:
                    reflex = missing
                reflexes += [reflex]
        return reflexes

    def _check(self):
        """
        Check for problematic patterns in the data.
        """
        errors = []
        for idx, struc, alm in self.iter_rows(self._structure, self._alignment):
            self[idx, self._structure] = self._str_type(struc)
            self[idx, self._alignment] = self._str_type(alm)
            if not len(self[idx, self._structure]) == len(
                [x for x in self[idx, self._alignment] if x != "-"]
            ):
                print(
                    idx,
                    self[idx, self._structure],
                    "|",
                    self[idx, self._alignment],
                    "|",
                    self[idx, "tokens"],
                )
                log.warning("alignment and structure do not match in {0}".format(idx))
                errors += [idx]
        return errors

    def get_sites(self):
        """
        Retrieve the alignment sites of interest for the initial analysis.
        """
        sites, all_sites, taxa = (
            collections.OrderedDict(),
            collections.OrderedDict(),
            self.cols,
        )
        errors = self._check()
        if errors:
            raise ValueError("found {0} problems in the data".format(len(errors)))

        # iterate over all sites in the alignment
        visited = []
        for cogid, msa in pb(
            sorted(self.msa[self.ref].items()),
            desc="CoPaR: get_sites()",
            total=len(self.msa[self.ref]),
        ):
            # get essential data: taxa, alignment, etc.
            _taxa = [t for t in taxa if t in msa["taxa"]]
            _idxs = {t: msa["taxa"].index(t) for t in _taxa}
            _alms = [msa["alignment"][_idxs[t]] for t in _taxa]
            _wlid = [msa["ID"][_idxs[t]] for t in _taxa]

            # store visited entries
            visited += msa["ID"]
            if len(_taxa) >= self.minrefs:
                if self._mode == "fuzzy":
                    _strucs = []
                    for _widx in _wlid:
                        _these_strucs = self[_widx, self._structure]
                        _strucs += [_these_strucs]
                else:
                    _strucs = [self[idx, self._structure] for idx in _wlid]
                positions = self.positions_from_prostrings(cogid, _wlid, _alms, _strucs)
                for pidx, pos in positions:
                    reflexes = self.reflexes_from_pos(
                        pidx, taxa, _taxa, _alms, self.missing, self.irregular
                    )
                    sites[cogid, pidx] = [pos, tuple(reflexes)]
                for pidx in range(len(_alms[0])):
                    reflexes = self.reflexes_from_pos(
                        pidx, taxa, _taxa, _alms, self.missing, self.irregular
                    )
                    all_sites[cogid, pidx] = reflexes

        # add non-visited segments
        for idx in [i for i in self if i not in visited]:
            if self._mode == "fuzzy":
                for tt, ss, cogid in zip(
                    self[idx, self._segments].n,
                    self[idx, self._structure].n,
                    self[idx, self._ref],
                ):
                    for i, (t, s) in enumerate(zip(tt, ss)):
                        all_sites[cogid, i] = [
                            self.missing if tax != self[idx][self._colIdx] else t
                            for tax in self.cols
                        ]
            else:
                for i, (t, s) in enumerate(
                    zip(self[idx, self._segments], self[idx, self._structure])
                ):
                    all_sites[self[idx, self.ref], i] = [
                        self.missing if tax != self[idx][self._colIdx] else t
                        for tax in self.cols
                    ]

        self.sites = sites
        self.all_sites = all_sites
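
    # Hedged illustration (values invented): after get_sites(), ``self.sites``
    # maps a site key (cogid, position) to [prosodic_position, reflexes], e.g.
    #
    #     self.sites[14, 0] == ["C", ("p", "p", "Ø", "f")]
    #
    # while ``self.all_sites`` stores plain reflex lists, including those of
    # words that are not part of any alignment.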

    def cluster_sites(self, match_threshold=1, score_mode="pairs"):
        """Cluster alignment sites using a greedy clique cover.

        :param match_threshold: The threshold of matches for accepting two
            compatible columns.
        :param score_mode: select between "ranked", "pairs", "squared", and
            "coverage" (see score_patterns)

        .. note:: This algorithm follows the spirit of the Welsh-Powell algorithm for
            graph coloring. Since graph coloring is the inverse of clique
            partitioning, we can use the algorithm in the same spirit.
        """
        if not hasattr(self, "clusters"):
            self.clusters = collections.defaultdict(list)
            for (cogid, idx), (pos, ptn) in self.sites.items():
                self.clusters[pos, ptn] += [(cogid, idx)]
        clusters = self.clusters
        while True:
            prog = 0
            with pb(
                desc="CoPaR: cluster_sites()", total=len(self.clusters)
            ) as progress:
                sorted_clusters = sorted(
                    clusters.items(),
                    key=lambda x: (
                        score_patterns(
                            [self.sites[y][1] for y in x[1]], mode=score_mode
                        ),
                        len(x[1]),
                    ),
                    reverse=True,
                )
                out = []
                while sorted_clusters:
                    ((this_pos, this_cluster), these_vals), remaining_clusters = (
                        sorted_clusters[0],
                        sorted_clusters[1:],
                    )
                    queue = []
                    for (next_pos, next_cluster), next_vals in remaining_clusters:
                        match, mism = compatible_columns(
                            this_cluster,
                            next_cluster,
                            missing=self.missing,
                            gap=self.gap,
                        )
                        if (
                            this_pos == next_pos
                            and match >= match_threshold  # noqa: W503
                            and mism == 0  # noqa: W503
                        ):
                            this_cluster = consensus_pattern(
                                [this_cluster, next_cluster]
                            )
                            these_vals += next_vals
                        else:
                            queue += [((next_pos, next_cluster), next_vals)]
                    sorted_clusters = queue
                    out += [((this_pos, this_cluster), these_vals)]
                    progress.update(len(self.sites) - len(queue) - prog)
                    prog = len(self.sites) - len(queue)
            clusters = {tuple(a): b for a, b in out}
            alls = [c for c in clusters]
            match = 0
            for i, (_a, a) in enumerate(alls):
                for j, (_b, b) in enumerate(alls):
                    if i < j and _a == _b:
                        ma, mi = compatible_columns(
                            a, b, missing=self.missing, gap=self.gap
                        )
                        if ma and not mi:
                            match += 1
            if not match:
                break
            else:
                log.warning(
                    "iterating, since {0} clusters can further be merged".format(match)
                )
        self.clusters = clusters
        self.ordered_clusters = sorted(clusters, key=lambda x: len(x[1]))
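
    # Hedged note on the result (shape only, values invented): ``self.clusters``
    # maps a (position, consensus_pattern) pair to the list of alignment sites
    # assigned to it, e.g.
    #
    #     self.clusters["C", ("p", "p", "Ø", "f")] == [(14, 0), (27, 2)]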

    def sites_to_pattern(self, threshold=1):
        """Assign alignment sites to patterns.

        Notes
        -----
        We rank according to general compatibility.
        """
        asites = collections.defaultdict(list)
        for consensus in pb(
            self.clusters, desc="CoPaR: sites_to_pattern()", total=len(self.clusters)
        ):
            sites = self.clusters[consensus]
            for cog, pos in sites:
                struc, pattern = self.sites[cog, pos]
                for strucB, consensusB in self.clusters:
                    ma, mi = compatible_columns(pattern, consensusB)
                    if struc == strucB and not mi and ma >= threshold:
                        asites[cog, pos] += [(ma, struc, consensusB)]
        self.patterns = asites
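
    # Hedged note: ``self.patterns`` maps each site to *all* compatible
    # cluster patterns, not just the one it was assigned to; the average
    # length of these lists is what fuzziness() reports below.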

    def fuzziness(self):
        """Return the average number of compatible patterns per alignment site."""
        return sum([len(b) for a, b in self.patterns.items()]) / len(self.patterns)

    def irregular_patterns(self, accepted=2, matches=1, irregular_prefix="!"):
        """
        Try to assign irregular patterns to accepted patterns.

        Parameters
        ----------
        accepted : int (default=2)
            Minimal size of clusters that we regard as regular.
        matches : int (default=1)
            Threshold for the number of matches (and tolerated mismatches)
            when comparing an irregular pattern to a regular one.
        irregular_prefix : str (default="!")
            Prefix used to mark irregular reflexes in the alignments.
        """
        bad_clusters = [
            (clr, pts[0]) for clr, pts in self.clusters.items() if len(pts) == 1
        ]
        good_clusters = sorted(
            [(clr, pts) for clr, pts in self.clusters.items() if len(pts) >= accepted],
            key=lambda x: len(x[1]),
            reverse=True,
        )
        new_clusters = {clr: [] for clr, pts in good_clusters}
        irregular_patterns = []
        for clr, ptn in bad_clusters:
            if ptn.count(self.missing) <= 2:
                for clrB, pts in good_clusters:
                    match, mism = compatible_columns(clr[1], clrB[1])
                    if mism <= matches and match > matches:
                        new_clusters[clrB] += [clr]
                        irregular_patterns += [clr]
                        break
        # re-assign alignments to the data by adding the irregular character
        for key, value in sorted(
            new_clusters.items(), key=lambda x: len(x[1]), reverse=True
        ):
            if len(value) > 0:
                for i, pattern in enumerate(value):
                    pt = []
                    for lid, (a, b) in enumerate(zip(key[1], pattern[1])):
                        if a != b and self.missing not in [a, b]:
                            pt += [irregular_prefix + b]
                            # assign pattern to the corresponding alignments
                            for cogid, position in self.clusters[pattern]:
                                if self._mode == "fuzzy":
                                    word_indices = self.etd[self.ref][cogid][lid]
                                    if word_indices:
                                        for widx in word_indices:
                                            # get the position in the alignment
                                            alms = self[widx, self._alignment].n
                                            cog_pos = self[widx, self.ref].index(cogid)
                                            new_alm = alms[cog_pos]
                                            new_alm[position] = "{0}{1}/{2}".format(
                                                irregular_prefix, b, a
                                            )
                                            alms[cog_pos] = new_alm
                                            self[widx, self._alignment] = self._str_type(
                                                " + ".join(
                                                    [" ".join(x) for x in alms]
                                                ).split()
                                            )
                                else:
                                    word_indices = self.etd[self.ref][cogid][lid]
                                    if word_indices:
                                        for widx in word_indices:
                                            alm = self._str_type(
                                                self[widx, self._alignment]
                                            )
                                            alm[position] = "{0}{1}/{2}".format(
                                                irregular_prefix, b, a
                                            )
                                            self[widx, self._alignment] = self._str_type(
                                                " ".join(alm)
                                            )
                        else:
                            pt += [b]

        self.ipatterns = new_clusters
        for pattern, data in [
            (a, b) for a, b in bad_clusters if a not in irregular_patterns
        ]:
            cogid, position = data
            if self._mode == "fuzzy":
                for indices in [idx for idx in self.etd[self.ref][cogid] if idx]:
                    for widx in indices:
                        cog_pos = self[widx, self.ref].index(cogid)
                        alms = self[widx, self._alignment].n
                        new_alm = alms[cog_pos]
                        new_alm[position] = "{0}{1}".format(
                            irregular_prefix, new_alm[position]
                        )
                        alms[cog_pos] = new_alm
                        self[widx, self._alignment] = self._str_type(
                            " + ".join([" ".join(x) for x in alms]).split()
                        )

        return new_clusters
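
    # Hedged note: ``self.ipatterns`` maps each regular cluster to the list of
    # irregular (singleton) patterns that could be attached to it; irregular
    # reflexes are marked in the alignments with ``irregular_prefix``.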

    def load_patterns(self, patterns="patterns"):
        """Load previously assigned patterns from a column of the wordlist."""
        self.id2ptn = collections.OrderedDict()
        self.clusters = collections.OrderedDict()
        self.id2pos = collections.defaultdict(set)
        self.sites = collections.OrderedDict()
        # get the template
        template = [self.missing for m in self.cols]
        tidx = {self.cols[i]: i for i in range(self.width)}
        for idx, ptn, alm, struc, doc, cogs in self.iter_rows(
            patterns, self._alignment, self._structure, "doculect", self._ref
        ):
            if self._mode == "fuzzy":
                ptn = bt.lists(ptn)
                for i in range(len(alm.n)):
                    for j, (p, a) in enumerate(zip(ptn.n[i], alm.n[i])):
                        if not p == "0/n":
                            this_pattern = self.id2ptn.get(p, [t for t in template])
                            if this_pattern[tidx[doc]] == "Ø":
                                this_pattern[tidx[doc]] = a
                            self.id2ptn[p] = this_pattern
                            self.id2pos[p].add((cogs[i], j))
            else:
                for j, (p, a) in enumerate(zip(ptn, alm)):
                    if not p == "0/n":
                        this_pattern = self.id2ptn.get(p, [t for t in template])
                        if this_pattern[tidx[doc]] == "Ø":
                            this_pattern[tidx[doc]] = a
                        self.id2ptn[p] = this_pattern
                        self.id2pos[p].add((cogs, j))

        self.ptn2id = {tuple(v): k for k, v in self.id2ptn.items()}
        for k, v in self.id2ptn.items():
            self.clusters[tuple(v)] = list(self.id2pos[k])
            self.id2pos[k] = list(self.id2pos[k])
            for s in self.id2pos[k]:
                self.sites[s] = [(len(self.id2pos[k]), tuple(v))]

    def add_patterns(
        self, ref="patterns", irregular_patterns=False, proto=False, override=True
    ):
        """Assign patterns to a new column in the word list."""
        if not hasattr(self, "id2ptn"):
            self.id2ptn = {}
        if not hasattr(self, "ptn2id"):
            self.ptn2id = {}
        if proto:
            pidx = self.cols.index(proto)
        else:
            pidx = 0

        if irregular_patterns:
            new_clusters = collections.defaultdict(list)
            for reg, iregs in self.ipatterns.items():
                for cogid, position in self.clusters[reg]:
                    new_clusters[reg] += [(cogid, position)]
                for ireg in iregs:
                    for cogid, position in self.clusters[ireg]:
                        new_clusters[reg] += [(cogid, position)]
        else:
            new_clusters = self.clusters
        for pattern, rest in self.clusters.items():
            for cogid, position in rest:
                if (cogid, position) not in new_clusters[pattern]:
                    new_clusters[pattern] += [(cogid, position)]

        P = {
            idx: bt.lists(
                [
                    "0" if x not in rc("morpheme_separators") else "+"
                    for x in self[idx, self._alignment]
                ]
            )
            for idx in self
        }
        for i, ((struc, pattern), data) in enumerate(
            sorted(new_clusters.items(), key=lambda x: len(x[1]), reverse=True)
        ):
            pattern_id = "{0}".format(i + 1)
            self.id2ptn[pattern_id] = pattern
            self.ptn2id[pattern] = pattern_id
            for cogid, position in data:
                word_indices = [c for c in self.etd[self.ref][cogid] if c]
                for idxs in word_indices:
                    for idx in idxs:
                        if self._mode == "fuzzy":
                            pattern_position = self[idx, self.ref].index(cogid)
                            this_pattern = P[idx].n[pattern_position]
                            try:
                                this_pattern[position] = pattern_id
                                P[idx].change(pattern_position, this_pattern)
                            except:  # noqa: E722
                                log.warning("error in {0}".format(cogid))
                        else:
                            P[idx][position] = pattern_id
        self.add_entries(ref, P, lambda x: x, override=override)
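
    # Hedged note: after add_patterns(), each word's new "patterns" entry
    # parallels its alignment, with the numeric cluster id at positions that
    # belong to a pattern, "0" elsewhere, and "+" at morpheme boundaries.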

    def write_patterns(self, filename, proto=False, irregular_patterns=False):
        """Write the assigned patterns to a tab-separated file."""
        if proto:
            pidx = self.cols.index(proto)
        else:
            pidx = 0

        if not hasattr(self, "id2ptn"):
            raise ValueError("You should run CoPaR.add_patterns first!")

        if irregular_patterns:
            new_clusters = collections.defaultdict(list)
            for (pos, reg), iregs in self.ipatterns.items():
                for cogid, position in self.clusters[pos, reg]:
                    new_clusters[pos, reg] += [(cogid, position)]
                for _, ireg in iregs:
                    ireg_ = list(ireg)
                    for i, (a, b) in enumerate(zip(reg, ireg)):
                        if a != b and b != self.missing:
                            ireg_[i] = a + "/" + b
                    ireg_ = tuple(ireg_)
                    self.ptn2id[ireg_] = self.ptn2id[reg]
                    for cogid, position in self.clusters[pos, ireg]:
                        new_clusters[pos, ireg_] += [(cogid, position)]
        else:
            new_clusters = self.clusters
        for (struc, pattern), rest in self.clusters.items():
            for cogid, position in rest:
                if (cogid, position) not in new_clusters[struc, pattern]:
                    new_clusters[struc, pattern] += [(cogid, position)]
        text = "ID\tSTRUCTURE\tFREQUENCY\t{0}\t{1}\tCOGNATESETS\tCONCEPTS\n".format(
            self.cols[pidx], "\t".join([c for c in self.cols if c != self.cols[pidx]])
        )

        sound = ""
        idx = 0
        for (struc, pattern), entries in sorted(
            new_clusters.items(),
            key=lambda x: (x[0][0], x[0][1][pidx], len(x[1])),
            reverse=True,
        ):
            if sound != pattern[pidx]:
                sound = pattern[pidx]
                idx = 0
            concepts = []
            for x, y in entries:
                for entry in self.etd[self.ref][x]:
                    if entry:
                        for value in entry:
                            concepts += [self[value, "concept"]]
            concepts = " / ".join(sorted(set(concepts)))

            idx += 1
            text += "{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n".format(
                self.ptn2id[pattern].split("/")[0],
                struc,
                len(entries),
                pattern[pidx],
                "\t".join([p for i, p in enumerate(pattern) if i != pidx]),
                ", ".join(["{0}:{1}".format(x, y) for x, y in entries]),
                concepts,
            )
        pathlib.Path(filename).write_text(text, encoding="utf8")

    def purity(self):
        """
        Compute the purity of the cluster analysis.

        .. note:: The purity is here interpreted as the degree to which
            patterns are filled with non-missing values. In this sense, it
            indicates to which degree information is computed and to which
            degree information is already provided by the data itself.
        """

        def get_purity(patterns):
            all_sums = []
            for i in range(len(patterns[0])):
                col = [line[i] for line in patterns]
                subset = set(col)
                sums = []
                for itm in subset:
                    if itm != self.missing:
                        sums += [col.count(itm) ** 2]
                if sums:
                    sums = math.sqrt(sum(sums)) / len(col)
                else:
                    sums = 0
                all_sums += [sums]
            return sum(all_sums) / len(all_sums)

        graph = self.get_cluster_graph()
        purities = []
        for node, data in graph.nodes(data=True):
            patterns = []
            for neighbor in graph[node]:
                patterns += [graph.nodes[neighbor]["pattern"].split()]
            if patterns:
                purities += [get_purity(patterns)]
            else:
                purities += [0]
        return sum(purities) / len(purities)

    def get_cluster_graph(self):
        """
        Compute a graph of the clusters.

        .. note:: In the cluster graph, the sites in the alignments are the
            nodes and the edges are drawn between nodes assigned to the same
            pattern.
        """
        graph = nx.Graph()
        for (pos, ptn), sites in self.clusters.items():
            for site in sites:
                graph.add_node(
                    "{0[0]}-{0[1]}".format(site),
                    pattern=" ".join(ptn),
                    site=" ".join(self.sites[site][1]),
                )

        for ((s1, p1), ptn1), ((s2, p2), ptn2) in itertools.combinations(
            self.sites.items(), r=2
        ):
            if ptn1[0] == ptn2[0]:
                m, mm = compatible_columns(ptn1[1], ptn2[1])
                if m and not mm:
                    graph.add_edge("{0}-{1}".format(s1, p1), "{0}-{1}".format(s2, p2))
        return graph
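
    # Hedged usage note: the cluster graph is a plain networkx graph, so
    # standard tools apply; for instance, connected components group mutually
    # compatible alignment sites:
    #
    #     components = list(nx.connected_components(copar.get_cluster_graph()))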

    def upper_bound(self):
        """
        Compute upper bound for clique partitioning following Bhasker 1991.
        """
        degs = {s: 0 for s in self.sites}
        sings = {s: 0 for s in self.sites}
        for (nA, (posA, ptnA)), (nB, (posB, ptnB)) in itertools.combinations(
            self.sites.items(), r=2
        ):
            if posA == posB:
                m, n = compatible_columns(ptnA, ptnB)
                if n > 0:
                    degs[nA] += 1
                    degs[nB] += 1
                else:
                    sings[nA] += 1
                    sings[nB] += 1
            else:
                degs[nA] += 1
                degs[nB] += 1

        return max([b for a, b in degs.items() if sings[a] > 0])

    def predict_words(self, **kw):
        """
        Predict patterns for those cognate sets where we have missing data.

        .. note::

            Purity (one of the return values) measures how well a given sound
            for a given site is reflected by one single sound (rather than
            multiple patterns pointing to different sounds) for a given
            doculect. It may be seen as a control case for the purity of a given
            prediction: if there are many alternative possibilities, this means
            that there is more uncertainty regarding the reconstructions or
            predictions.
        """
        if not hasattr(self, "sites"):
            raise ValueError("You need to compute alignment sites first")

        minrefs = self.minrefs
        missing = self.missing
        samples = kw.get("samples", 3)

        # pre-analyse the data to get for each site the best patterns in ranked
        # form
        ranked_sites = {}
        ranked_clusters = sorted(
            [(s, p, len(f)) for (s, p), f in self.clusters.items()],
            key=lambda x: x[2],
            reverse=True,
        )
        for (cogid, pos), ptns in self.patterns.items():
            struc, ptn = self.sites[cogid, pos]
            missings = [i for i in range(self.width) if ptn[i] == missing]
            if (struc, ptn) in self.clusters:
                ranked_sites[cogid, pos] = [
                    (len(self.clusters[struc, ptn]), struc, ptn)
                ]
            else:
                ranked_sites[cogid, pos] = [(1, struc, ptn)]
            for strucB, ptnB, freq in ranked_clusters:
                m, mm = compatible_columns(ptn, ptnB)
                if struc == strucB and m >= 1 and mm == 0:
                    if len(missings) > len(
                        [ptnB[i] for i in missings if ptnB[i] == missing]
                    ):
                        ranked_sites[cogid, pos] += [(freq, strucB, ptnB)]

        purity = {site: {} for site in ranked_sites}

        preds = {}
        for cogid, msa in self.msa[self._ref].items():
            missings = [t for t in self.cols if t not in msa["taxa"]]
            if len(set(msa["taxa"])) >= minrefs:
                words = [bt.strings("") for m in missings]
                for i, m in enumerate(missings):
                    tidx = self.cols.index(m)
                    for j in range(len(msa["alignment"][0])):
                        segments = collections.defaultdict(int)
                        sidx = 0
                        if (cogid, j) in ranked_sites:
                            while True:
                                this_segment = ranked_sites[cogid, j][sidx][2][tidx]
                                score = ranked_sites[cogid, j][sidx][0]
                                if this_segment != missing:
                                    segments[this_segment] += score
                                sidx += 1
                                if sidx == len(ranked_sites[cogid, j]):
                                    break

                        if (cogid, j) not in purity:
                            purity[cogid, j] = {}

                        if not segments:
                            words[i] += ["Ø"]
                            purity[cogid, j][m] = 0
                        else:
                            purity[cogid, j][m] = math.sqrt(
                                sum(
                                    [
                                        (s / sum(segments.values())) ** 2
                                        for s in segments.values()
                                    ]
                                )
                            )
                            words[i] += [
                                "|".join(
                                    [
                                        s
                                        for s in sorted(
                                            segments,
                                            key=lambda x: segments[x],
                                            reverse=True,
                                        )
                                    ][:samples]
                                )
                            ]
                if words:
                    preds[cogid] = dict(zip(missings, words))

        doculect_purity = {doc: [] for doc in self.cols}
        for site, docs in purity.items():
            for doc in docs:
                doculect_purity[doc] += [purity[site][doc]]
        for doc, purs in doculect_purity.items():
            if purs:
                doculect_purity[doc] = sum(purs) / len(purs)
            else:
                doculect_purity[doc] = 0

        return preds, purity, doculect_purity
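
# A hedged end-to-end usage sketch (not part of the original module): the
# typical workflow described in List (2019), assuming a wordlist file
# "wordlist.tsv" with "cogids", "alignment", and "structure" columns; the
# input and output file names are illustrative only.
if __name__ == "__main__":
    copar = CoPaR("wordlist.tsv", ref="cogids", minrefs=3)
    copar.get_sites()
    copar.cluster_sites()
    copar.sites_to_pattern()
    copar.add_patterns()
    copar.write_patterns("patterns.tsv")
    print("fuzziness: {0:.2f}".format(copar.fuzziness()))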