99from __future__ import annotations
1010
1111from collections .abc import Iterable , Iterator
12- from typing import Union
12+ from typing import Optional , Union
1313
1414
1515class Trie (Iterable [str ]):
@@ -46,64 +46,74 @@ class Trie(Iterable[str]):
4646 # output: 5
4747 """
4848
49- words : set [str ]
5049 root : Node
50+ _word_count : int
5151
5252 class Node :
5353 __slots__ : tuple [str , str ] = ("end" , "children" )
5454
5555 def __init__ (self ) -> None :
5656 self .end : bool = False
57- self .children : dict [str , Trie .Node ] = {}
57+ # Children dict is created on demand to reduce memory for leaf nodes.
58+ self .children : Optional [dict [str , Trie .Node ]] = None
5859
5960 def __init__ (self , words : Iterable [str ]) -> None :
60- self .words : set [ str ] = set ( words )
61+ self ._word_count : int = 0
6162 self .root : Trie .Node = Trie .Node ()
62-
6363 for word in words :
6464 self .add (word )
6565
6666 def add (self , word : str ) -> None :
6767 """Add a word to the trie.
6868 Spaces in front of and following the word will be removed.
6969
70- :param str text : a word
70+ :param str word : a word
7171 """
7272 word = word .strip ()
73- self .words .add (word )
7473 cur = self .root
7574 for ch in word :
75+ if cur .children is None :
76+ cur .children = {}
7677 child = cur .children .get (ch )
77- if not child :
78+ if child is None :
7879 child = Trie .Node ()
7980 cur .children [ch ] = child
8081 cur = child
81- cur .end = True
82+ if not cur .end :
83+ cur .end = True
84+ self ._word_count += 1
8285
8386 def remove (self , word : str ) -> None :
8487 """Remove a word from the trie.
8588 If the word is not found, do nothing.
8689
87- :param str text : a word
90+ :param str word : a word
8891 """
89- # remove from set first
90- if word not in self .words :
91- return
92- self .words .remove (word )
93- # then remove from nodes
94- parent = self .root
95- data = [] # track path to leaf
92+ # Navigate to the word's end node, recording the path.
93+ node = self .root
94+ path : list [tuple [Trie .Node , Trie .Node , str ]] = []
9695 for ch in word :
97- child = parent .children [ch ]
98- data .append ((parent , child , ch ))
99- parent = child
100- # remove the last one
101- child .end = False
102- # prune up the tree
103- for parent , child , ch in reversed (data ):
96+ if node .children is None :
97+ return # word not in trie
98+ child = node .children .get (ch )
99+ if child is None :
100+ return # word not in trie
101+ path .append ((node , child , ch ))
102+ node = child
103+ if not node .end :
104+ return # path exists but not a complete word
105+ node .end = False
106+ self ._word_count -= 1
107+ # Prune nodes that are now unused (not an end and no children).
108+ # parent.children is always non-None here because the path was
109+ # built by traversing through existing children dicts.
110+ for parent , child , ch in reversed (path ):
104111 if child .end or child .children :
105112 break
106- del parent .children [ch ] # remove from parent dict
113+ if parent .children is not None : # always true; narrows type
114+ del parent .children [ch ]
115+ if not parent .children :
116+ parent .children = None # free empty dict
107117
108118 def prefixes (self , text : str , start : int = 0 ) -> list [str ]:
109119 """List all possible words from first sequence of characters in a word.
@@ -118,8 +128,10 @@ def prefixes(self, text: str, start: int = 0) -> list[str]:
118128 i = start
119129 n = len (text )
120130 while i < n :
131+ if cur .children is None :
132+ break
121133 node = cur .children .get (text [i ])
122- if not node :
134+ if node is None :
123135 break
124136 if node .end :
125137 res .append (text [start : i + 1 ])
@@ -128,13 +140,33 @@ def prefixes(self, text: str, start: int = 0) -> list[str]:
128140 return res
129141
130142 def __contains__ (self , key : str ) -> bool :
131- return key in self .words
143+ cur = self .root
144+ for ch in key :
145+ if cur .children is None :
146+ return False
147+ node = cur .children .get (ch )
148+ if node is None :
149+ return False
150+ cur = node
151+ return cur .end
132152
133153 def __iter__ (self ) -> Iterator [str ]:
134- yield from self .words
154+ # DFS through the trie to yield all stored words.
155+ # A shared mutable prefix list is appended/popped to avoid
156+ # O(k²) list copies that a stack-based approach would incur.
157+ def _dfs (node : Trie .Node , prefix : list [str ]) -> Iterator [str ]:
158+ if node .end :
159+ yield "" .join (prefix )
160+ if node .children :
161+ for ch , child in node .children .items ():
162+ prefix .append (ch )
163+ yield from _dfs (child , prefix )
164+ prefix .pop ()
165+
166+ yield from _dfs (self .root , [])
135167
136168 def __len__ (self ) -> int :
137- return len ( self .words )
169+ return self ._word_count
138170
139171
140172def dict_trie (dict_source : Union [str , Iterable [str ], Trie ]) -> Trie :
0 commit comments