Coverage for src/extratools_core/trie.py: 89%

96 statements  

« prev     ^ index     » next       coverage.py v7.8.1, created at 2025-06-24 04:41 -0700

1from __future__ import annotations 

2 

3from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping 

4from typing import Any, override 

5 

6from .typing import SearchableMapping 

7 

8 

9class TrieDict[VT: Any](MutableMapping[str, VT], SearchableMapping[str, VT]): 

10 def __init__( 

11 self, 

12 initial_data: Mapping[str, VT] | Iterable[tuple[str, VT]] | None = None, 

13 ) -> None: 

14 self.root: dict[str, Any] = {} 

15 

16 self.__len: int = 0 

17 

18 if initial_data: 

19 for key, value in ( 

20 initial_data.items() if isinstance(initial_data, Mapping) 

21 else initial_data 

22 ): 

23 self.__setitem__(key, value) 

24 

25 def __len__(self) -> int: 

26 return self.__len 

27 

28 def __find(self, s: str, func: Callable[[dict[str, Any], str], Any]) -> Any: 

29 node: dict[str, Any] = self.root 

30 

31 while True: 

32 c: str = s[0] if s else "" 

33 rest: str = s[1:] if s else "" 

34 

35 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

36 if next_node is None: 

37 raise KeyError 

38 

39 if isinstance(next_node, dict): 

40 node = next_node 

41 s = rest 

42 continue 

43 

44 if rest == next_node[0]: 

45 return func(node, c) 

46 

47 raise KeyError 

48 

49 def __delitem__(self, s: str) -> None: 

50 def delitem(node: dict[str, Any], c: str) -> None: 

51 del node[c] 

52 self.__len -= 1 

53 

54 return self.__find(s, delitem) 

55 

56 def __getitem__(self, s: str) -> VT: 

57 def getitem(node: dict[str, Any], c: str) -> VT: 

58 return node[c][1] 

59 

60 return self.__find(s, getitem) 

61 

62 def __setitem__(self, s: str, v: VT) -> None: 

63 self.__set(s, v, self.root, is_new=True) 

64 

65 def __set(self, s: str, v: VT, node: dict[str, Any], *, is_new: bool) -> None: 

66 if not s: 

67 is_new = is_new and "" not in node 

68 node[""] = ("", v) 

69 if is_new: 

70 self.__len += 1 

71 

72 return 

73 

74 c: str = s[0] 

75 rest: str = s[1:] 

76 

77 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

78 if next_node is None: 

79 node[c] = (rest, v) 

80 if is_new: 

81 self.__len += 1 

82 elif isinstance(next_node, dict): 

83 self.__set(rest, v, next_node, is_new=is_new) 

84 else: 

85 other_rest: str 

86 other_value: VT 

87 other_rest, other_value = next_node 

88 

89 if rest == other_rest: 

90 node[c] = (rest, v) 

91 return 

92 

93 next_node = node[c] = {} 

94 

95 self.__set(other_rest, other_value, next_node, is_new=False) 

96 self.__set(rest, v, next_node, is_new=is_new) 

97 

98 def __iter__(self) -> Iterator[str]: 

99 for _, value in self.__prefixes("", self.root): 

100 yield value 

101 

102 def prefixes(self) -> Iterator[tuple[str, str]]: 

103 yield from self.__prefixes("", self.root) 

104 

105 def __prefixes(self, prefix: str, node: dict[str, Any]) -> Iterator[tuple[str, str]]: 

106 for key, next_node in node.items(): 

107 new_prefix = prefix + key 

108 if isinstance(next_node, dict): 

109 yield from self.__prefixes(new_prefix, next_node) 

110 else: 

111 yield (new_prefix, new_prefix + next_node[0]) 

112 

113 @override 

114 def search(self, filter_body: str | None = None) -> Iterator[str]: 

115 prefix: str = filter_body or "" 

116 

117 node: dict[str, Any] = self.root 

118 s: str = prefix 

119 

120 matched: str = "" 

121 

122 while s: 

123 c: str = s[0] 

124 rest: str = s[1:] 

125 matched += c 

126 

127 next_node: dict[str, Any] | tuple[str, VT] | None = node.get(c) 

128 if next_node is None: 

129 return 

130 

131 if isinstance(next_node, dict): 

132 node = next_node 

133 s = rest 

134 continue 

135 

136 other_rest: str = next_node[0] 

137 if other_rest.startswith(rest): 

138 yield matched + other_rest 

139 

140 return 

141 

142 for _, value in self.__prefixes(prefix, node): 

143 yield value