class TrieNode:
    """A single node of the suggestion trie.

    All fields are plain attributes that ``SearchTrie`` reads and writes
    directly.
    """

    def __init__(self):
        # Child nodes keyed by a single character of the (lower-cased) word.
        self.children: Dict[str, TrieNode] = {}
        # True when the path from the root to this node spells an inserted word.
        self.is_end: bool = False
        # Display form of the word ending here; None when no word ends here.
        self.word: Optional[str] = None
        # Accumulated frequency of the word ending at this node.
        self.frequency: int = 0
        # Cached (word, frequency) pairs for words under this prefix,
        # highest frequency first.
        self.top_results: List[tuple] = []
class SearchTrie:
    """Trie that serves search suggestions (autocomplete).

    Words are indexed case-insensitively: each word is ``str.lower()``-ed
    before walking the trie. The casing of the most recent insertion is kept
    for display. Every node on a word's path caches up to ``top_k``
    ``(word, frequency)`` pairs, so an exact-prefix lookup never walks the
    subtree.

    Caveats:
        * The root never receives cached results, so ``search("")`` returns
          ``[]``, while ``fuzzy_search("")`` ranks every word.
        * Cached lists are only exact while frequencies never decrease. A word
          evicted from a node's cache is not reconsidered if a cached word's
          frequency later drops (e.g. ``insert`` with a negative frequency).
    """

    def __init__(self, top_k: int = 10):
        """Create an empty trie.

        Args:
            top_k: number of suggestions cached per node, and the default
                number of results returned by ``search``/``fuzzy_search``.
        """
        self.root = TrieNode()
        self.top_k = top_k

    def insert(self, word: str, frequency: int = 1):
        """Insert ``word``, or add ``frequency`` to its accumulated count.

        Args:
            word: term to index. Matching is case-insensitive; the casing
                passed here becomes the displayed form. Empty strings are
                ignored.
            frequency: amount added to the word's accumulated frequency.
        """
        if not word:
            # An empty word would turn the root itself into a word end.
            return

        node = self.root
        path = []
        for char in word.lower():
            child = node.children.get(char)
            if child is None:
                child = node.children[char] = TrieNode()
            node = child
            path.append(node)

        node.is_end = True
        node.word = word
        node.frequency += frequency

        # Refresh the cached suggestions of every (non-empty) prefix of the word.
        for prefix_node in path:
            self._update_top_k(prefix_node, word, node.frequency)

    def search(self, prefix: str, limit: Optional[int] = None) -> List[str]:
        """Return up to ``limit`` cached suggestions for ``prefix``.

        Args:
            prefix: typed prefix; matched case-insensitively.
            limit: maximum number of results; ``None`` means ``self.top_k``.
                Each node caches at most ``top_k`` entries, so larger values
                cannot return more than that. Values <= 0 yield ``[]``.

        Returns:
            Matching words, highest frequency first. Returns ``[]`` when no
            indexed word starts with ``prefix``.
        """
        if limit is None:
            limit = self.top_k
        node = self._find_node(prefix.lower())
        if node is None:
            return []
        # top_results stores (word, frequency) pairs; callers get only the words.
        return [word for word, _ in node.top_results[:max(limit, 0)]]

    def _find_node(self, prefix: str) -> Optional[TrieNode]:
        """Return the node reached by following ``prefix``, or None if absent."""
        node = self.root
        for char in prefix:
            node = node.children.get(char)
            if node is None:
                return None
        return node

    def _update_top_k(self, node: TrieNode, word: str, frequency: int):
        """Record ``word`` at ``frequency`` in ``node``'s cached top-K list.

        Any previous entry for the same word is replaced, including one stored
        with different casing, because all casings share one trie path.
        """
        key = word.lower()
        results = [(w, f) for w, f in node.top_results if w.lower() != key]
        results.append((word, frequency))
        # Stable sort: among equal frequencies, existing entries stay first.
        results.sort(key=lambda item: item[1], reverse=True)
        node.top_results = results[:self.top_k]

    def fuzzy_search(self, prefix: str, max_edit_distance: int = 2,
                     limit: Optional[int] = None) -> List[str]:
        """Suggest words that have a prefix within ``max_edit_distance`` edits of ``prefix``.

        Edits are single-character substitutions, deletions and insertions.
        Each word is reported once, with the fewest edits found for it.

        Args:
            prefix: typed, possibly misspelled, prefix; matched
                case-insensitively.
            max_edit_distance: maximum number of edits tolerated.
            limit: maximum number of results; ``None`` means ``self.top_k``.

        Returns:
            Words ordered by fewest edits first, then by highest frequency.
        """
        if limit is None:
            limit = self.top_k
        candidates = {}
        self._dfs_fuzzy(self.root, '', prefix.lower(), 0, 0,
                        max_edit_distance, candidates, {})
        ranked = sorted(candidates.items(),
                        key=lambda item: (item[1][0], -item[1][1]))
        return [word for word, _ in ranked[:max(limit, 0)]]

    def _dfs_fuzzy(self, node: TrieNode, current_word: str, target: str,
                   pos: int, edits: int, max_edits: int, candidates: dict,
                   visited: Optional[dict] = None):
        """Explore trie paths within ``max_edits`` edits of ``target[pos:]``.

        Args:
            node: current trie node.
            current_word: characters spelled from the root to ``node``
                (informational only; not used for matching).
            target: the lower-cased query prefix.
            pos: index of the next unconsumed character of ``target``.
            edits: edits spent so far on this path.
            max_edits: edit budget.
            candidates: ``word -> (fewest edits, frequency)``; updated in place.
            visited: ``(id(node), pos) -> fewest edits`` already explored from
                that state. Revisits without a cheaper budget are skipped.
        """
        if edits > max_edits or pos > len(target):
            return

        if visited is not None:
            state = (id(node), pos)
            seen = visited.get(state)
            if seen is not None and seen <= edits:
                # Everything reachable from here was already found at least as cheaply.
                return
            visited[state] = edits

        if pos >= len(target):
            # The whole query is consumed: every word in this subtree matches.
            # _collect_subtree also includes this node's own word, if any.
            found = []
            self._collect_subtree(node, found)
            for word, frequency in found:
                best = candidates.get(word)
                if best is None or edits < best[0]:
                    candidates[word] = (edits, frequency)
            return

        char = target[pos]

        # Match (free): consume one query char and one trie char.
        exact = node.children.get(char)
        if exact is not None:
            self._dfs_fuzzy(exact, current_word + char, target, pos + 1,
                            edits, max_edits, candidates, visited)

        # Substitution (one edit): consume both, with differing characters.
        for child_char, child_node in node.children.items():
            if child_char != char:
                self._dfs_fuzzy(child_node, current_word + child_char, target,
                                pos + 1, edits + 1, max_edits, candidates,
                                visited)

        # Deletion (one edit): skip the query char without moving in the trie.
        self._dfs_fuzzy(node, current_word, target, pos + 1, edits + 1,
                        max_edits, candidates, visited)

        # Insertion (one edit): move down the trie without consuming a query char.
        for child_char, child_node in node.children.items():
            self._dfs_fuzzy(child_node, current_word + child_char, target,
                            pos, edits + 1, max_edits, candidates, visited)

    def _collect_subtree(self, node: TrieNode, candidates: List):
        """Append ``(word, frequency)`` for every word in ``node``'s subtree.

        Traverses in pre-order (children in insertion order) using an explicit
        stack, so very deep tries cannot exceed Python's recursion limit.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.is_end:
                candidates.append((current.word, current.frequency))
            # Push in reverse so children pop in insertion order.
            stack.extend(reversed(list(current.children.values())))
|