Coverage for src/extratools_core/trie.py: 89%
96 statements
« prev ^ index » next coverage.py v7.8.1, created at 2025-06-24 04:41 -0700
« prev ^ index » next coverage.py v7.8.1, created at 2025-06-24 04:41 -0700
1from __future__ import annotations
3from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
4from typing import Any, override
6from .typing import SearchableMapping
9class TrieDict[VT: Any](MutableMapping[str, VT], SearchableMapping[str, VT]):
10 def __init__(
11 self,
12 initial_data: Mapping[str, VT] | Iterable[tuple[str, VT]] | None = None,
13 ) -> None:
14 self.root: dict[str, Any] = {}
16 self.__len: int = 0
18 if initial_data:
19 for key, value in (
20 initial_data.items() if isinstance(initial_data, Mapping)
21 else initial_data
22 ):
23 self.__setitem__(key, value)
25 def __len__(self) -> int:
26 return self.__len
28 def __find(self, s: str, func: Callable[[dict[str, Any], str], Any]) -> Any:
29 node: dict[str, Any] = self.root
31 while True:
32 c: str = s[0] if s else ""
33 rest: str = s[1:] if s else ""
35 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
36 if next_node is None:
37 raise KeyError
39 if isinstance(next_node, dict):
40 node = next_node
41 s = rest
42 continue
44 if rest == next_node[0]:
45 return func(node, c)
47 raise KeyError
49 def __delitem__(self, s: str) -> None:
50 def delitem(node: dict[str, Any], c: str) -> None:
51 del node[c]
52 self.__len -= 1
54 return self.__find(s, delitem)
56 def __getitem__(self, s: str) -> VT:
57 def getitem(node: dict[str, Any], c: str) -> VT:
58 return node[c][1]
60 return self.__find(s, getitem)
62 def __setitem__(self, s: str, v: VT) -> None:
63 self.__set(s, v, self.root, is_new=True)
65 def __set(self, s: str, v: VT, node: dict[str, Any], *, is_new: bool) -> None:
66 if not s:
67 is_new = is_new and "" not in node
68 node[""] = ("", v)
69 if is_new:
70 self.__len += 1
72 return
74 c: str = s[0]
75 rest: str = s[1:]
77 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
78 if next_node is None:
79 node[c] = (rest, v)
80 if is_new:
81 self.__len += 1
82 elif isinstance(next_node, dict):
83 self.__set(rest, v, next_node, is_new=is_new)
84 else:
85 other_rest: str
86 other_value: VT
87 other_rest, other_value = next_node
89 if rest == other_rest:
90 node[c] = (rest, v)
91 return
93 next_node = node[c] = {}
95 self.__set(other_rest, other_value, next_node, is_new=False)
96 self.__set(rest, v, next_node, is_new=is_new)
98 def __iter__(self) -> Iterator[str]:
99 for _, value in self.__prefixes("", self.root):
100 yield value
102 def prefixes(self) -> Iterator[tuple[str, str]]:
103 yield from self.__prefixes("", self.root)
105 def __prefixes(self, prefix: str, node: dict[str, Any]) -> Iterator[tuple[str, str]]:
106 for key, next_node in node.items():
107 new_prefix = prefix + key
108 if isinstance(next_node, dict):
109 yield from self.__prefixes(new_prefix, next_node)
110 else:
111 yield (new_prefix, new_prefix + next_node[0])
113 @override
114 def search(self, filter_body: str | None = None) -> Iterator[str]:
115 prefix: str = filter_body or ""
117 node: dict[str, Any] = self.root
118 s: str = prefix
120 matched: str = ""
122 while s:
123 c: str = s[0]
124 rest: str = s[1:]
125 matched += c
127 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c)
128 if next_node is None:
129 return
131 if isinstance(next_node, dict):
132 node = next_node
133 s = rest
134 continue
136 other_rest: str = next_node[0]
137 if other_rest.startswith(rest):
138 yield matched + other_rest
140 return
142 for _, value in self.__prefixes(prefix, node):
143 yield value