
Commit ef81afc

Merge pull request #1323 from PyThaiNLP/copilot/fix-memory-usage-newmm-tokenization
fix: reduce Trie and newmm peak memory; add tcc_pos_array()
2 parents 621f8f7 + 7562b15 commit ef81afc

6 files changed

Lines changed: 114 additions & 35 deletions

File tree

CHANGELOG.md
pythainlp/tokenize/newmm.py
pythainlp/tokenize/tcc_p.py
pythainlp/util/trie.py
tests/core/test_tokenize.py
tests/core/test_util.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ The minimum requirement is now Python 3.9.
   downloaded (#1317)
 - `newmm` tokenization: Exponential-time explosion when text has
   many ambiguous breaking points (#1319)
+- `Trie`: Reduce memory usage and faster TCC boundary lookups (#1323)

 ### Security

pythainlp/tokenize/newmm.py

Lines changed: 5 additions & 5 deletions
@@ -27,7 +27,7 @@
 from pythainlp.util import Trie

 from pythainlp.tokenize import word_dict_trie
-from pythainlp.tokenize.tcc_p import tcc_pos
+from pythainlp.tokenize.tcc_p import tcc_pos_array

 # match non-Thai tokens
 # `|` is used as like "early return",
@@ -86,7 +86,7 @@ def _onecut(text: str, custom_dict: Trie) -> Generator[str, None, None]:

     graph_size = 0  # keep track of graph size, if too big, force cutoff

-    valid_poss = tcc_pos(text)  # breaking positions that are TCC-valid
+    valid_poss = tcc_pos_array(text)  # bytearray of valid TCC break positions

     len_text = len(text)
     pos_list = [0]  # priority queue of possible breaking positions
@@ -95,7 +95,7 @@ def _onecut(text: str, custom_dict: Trie) -> Generator[str, None, None]:
         begin_pos = heappop(pos_list)
         for word in custom_dict.prefixes(text, begin_pos):
             end_pos_candidate = begin_pos + len(word)
-            if end_pos_candidate in valid_poss:
+            if valid_poss[end_pos_candidate]:
                 graph[begin_pos].append(end_pos_candidate)
                 graph_size = graph_size + 1

@@ -121,12 +121,12 @@ def _onecut(text: str, custom_dict: Trie) -> Generator[str, None, None]:
                 end_pos = m.end()
             else:  # Thai token, find minimum skip
                 for pos in range(begin_pos + 1, len_text):
-                    if pos in valid_poss:
+                    if valid_poss[pos]:
                         words = [
                             word
                             for word in custom_dict.prefixes(text, pos)
                             if (
-                                (pos + len(word) in valid_poss)
+                                valid_poss[pos + len(word)]
                                 and not _PAT_THAI_TWOCHARS.match(word)
                             )
                         ]
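The hot-path change here is replacing a set-membership test with a bytearray index. A minimal standalone sketch of the two lookup styles, using a made-up boundary_set rather than the tokenizer's real data:

# Hypothetical data for illustration only; not the tokenizer's actual state.
text = "abcdef"
boundary_set = {3, 6}                    # old style: set[int] of positions
boundary_arr = bytearray(len(text) + 1)  # new style: one flag byte per position
for p in boundary_set:
    boundary_arr[p] = 1

# Same answer, but the bytearray needs no hashing and no per-int objects,
# so lookups are cheaper and the table is a single flat buffer.
pos = 3
assert (pos in boundary_set) == bool(boundary_arr[pos])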

pythainlp/tokenize/tcc_p.py

Lines changed: 25 additions & 1 deletion
@@ -88,7 +88,7 @@ def tcc_pos(text: str) -> set[int]:
     """TCC positions

     :param str text: text to be tokenized into character clusters
-    :return: list of the ending position of subwords
+    :return: set of the ending positions of character clusters
     :rtype: set[int]
     """
     if not text or not isinstance(text, str):
@@ -103,6 +103,30 @@ def tcc_pos(text: str) -> set[int]:
     return p_set


+def tcc_pos_array(text: str) -> bytearray:
+    """TCC positions as a bytearray.
+
+    Returns a bytearray of length ``len(text) + 1`` where index ``i``
+    is ``1`` if position ``i`` is a valid Thai Character Cluster boundary,
+    and ``0`` otherwise. Array-index lookup is faster and uses less
+    memory than set membership for large texts.
+
+    :param str text: text to be tokenized into character clusters
+    :return: bytearray of valid TCC boundary flags, indexed by position
+    :rtype: bytearray
+    """
+    if not text or not isinstance(text, str):
+        return bytearray(1)
+
+    arr = bytearray(len(text) + 1)
+    p = 0
+    for w in tcc(text):
+        p += len(w)
+        arr[p] = 1
+
+    return arr
+
+
 def segment(text: str) -> list[str]:
     """Subword segmentation

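A quick usage sketch of the new function next to the existing tcc_pos, assuming a PyThaiNLP build that includes this commit; the sample string is arbitrary:

from pythainlp.tokenize.tcc_p import tcc_pos, tcc_pos_array

text = "ประเทศไทย"
pos_set = tcc_pos(text)        # set of cluster-ending positions
pos_arr = tcc_pos_array(text)  # flag per position, len(text) + 1 entries

# Both encode the same boundaries; only the representation differs.
assert len(pos_arr) == len(text) + 1
assert {i for i, flag in enumerate(pos_arr) if flag} == pos_set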

pythainlp/util/trie.py

Lines changed: 61 additions & 29 deletions
@@ -9,7 +9,7 @@
 from __future__ import annotations

 from collections.abc import Iterable, Iterator
-from typing import Union
+from typing import Optional, Union


 class Trie(Iterable[str]):
@@ -46,64 +46,74 @@ class Trie(Iterable[str]):
     # output: 5
     """

-    words: set[str]
     root: Node
+    _word_count: int

     class Node:
         __slots__: tuple[str, str] = ("end", "children")

         def __init__(self) -> None:
             self.end: bool = False
-            self.children: dict[str, Trie.Node] = {}
+            # Children dict is created on demand to reduce memory for leaf nodes.
+            self.children: Optional[dict[str, Trie.Node]] = None

     def __init__(self, words: Iterable[str]) -> None:
-        self.words: set[str] = set(words)
+        self._word_count: int = 0
         self.root: Trie.Node = Trie.Node()
-
         for word in words:
             self.add(word)

     def add(self, word: str) -> None:
         """Add a word to the trie.
         Spaces in front of and following the word will be removed.

-        :param str text: a word
+        :param str word: a word
         """
         word = word.strip()
-        self.words.add(word)
         cur = self.root
         for ch in word:
+            if cur.children is None:
+                cur.children = {}
             child = cur.children.get(ch)
-            if not child:
+            if child is None:
                 child = Trie.Node()
                 cur.children[ch] = child
             cur = child
-        cur.end = True
+        if not cur.end:
+            cur.end = True
+            self._word_count += 1

     def remove(self, word: str) -> None:
         """Remove a word from the trie.
         If the word is not found, do nothing.

-        :param str text: a word
+        :param str word: a word
         """
-        # remove from set first
-        if word not in self.words:
-            return
-        self.words.remove(word)
-        # then remove from nodes
-        parent = self.root
-        data = []  # track path to leaf
+        # Navigate to the word's end node, recording the path.
+        node = self.root
+        path: list[tuple[Trie.Node, Trie.Node, str]] = []
         for ch in word:
-            child = parent.children[ch]
-            data.append((parent, child, ch))
-            parent = child
-        # remove the last one
-        child.end = False
-        # prune up the tree
-        for parent, child, ch in reversed(data):
+            if node.children is None:
+                return  # word not in trie
+            child = node.children.get(ch)
+            if child is None:
+                return  # word not in trie
+            path.append((node, child, ch))
+            node = child
+        if not node.end:
+            return  # path exists but not a complete word
+        node.end = False
+        self._word_count -= 1
+        # Prune nodes that are now unused (not an end and no children).
+        # parent.children is always non-None here because the path was
+        # built by traversing through existing children dicts.
+        for parent, child, ch in reversed(path):
             if child.end or child.children:
                 break
-            del parent.children[ch]  # remove from parent dict
+            if parent.children is not None:  # always true; narrows type
+                del parent.children[ch]
+                if not parent.children:
+                    parent.children = None  # free empty dict

     def prefixes(self, text: str, start: int = 0) -> list[str]:
         """List all possible words from first sequence of characters in a word.
@@ -118,8 +128,10 @@ def prefixes(self, text: str, start: int = 0) -> list[str]:
         i = start
         n = len(text)
         while i < n:
+            if cur.children is None:
+                break
             node = cur.children.get(text[i])
-            if not node:
+            if node is None:
                 break
             if node.end:
                 res.append(text[start : i + 1])
@@ -128,13 +140,33 @@ def prefixes(self, text: str, start: int = 0) -> list[str]:
         return res

     def __contains__(self, key: str) -> bool:
-        return key in self.words
+        cur = self.root
+        for ch in key:
+            if cur.children is None:
+                return False
+            node = cur.children.get(ch)
+            if node is None:
+                return False
+            cur = node
+        return cur.end

     def __iter__(self) -> Iterator[str]:
-        yield from self.words
+        # DFS through the trie to yield all stored words.
+        # A shared mutable prefix list is appended/popped to avoid
+        # O(k²) list copies that a stack-based approach would incur.
+        def _dfs(node: Trie.Node, prefix: list[str]) -> Iterator[str]:
+            if node.end:
+                yield "".join(prefix)
+            if node.children:
+                for ch, child in node.children.items():
+                    prefix.append(ch)
+                    yield from _dfs(child, prefix)
+                    prefix.pop()
+
+        yield from _dfs(self.root, [])

     def __len__(self) -> int:
-        return len(self.words)
+        return self._word_count


 def dict_trie(dict_source: Union[str, Iterable[str], Trie]) -> Trie:
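With the words set gone, membership, iteration, and length now come from the nodes themselves. A short sketch of the public behavior this diff preserves, using an arbitrary word list:

from pythainlp.util import Trie

trie = Trie(["มา", "มาก", "มาตรา"])
assert len(trie) == 3                    # backed by _word_count, not a set
assert "มาก" in trie                     # __contains__ walks the nodes
assert trie.prefixes("มากมาย") == ["มา", "มาก"]

trie.remove("มาก")                       # prunes the now-unused leaf node
assert "มาก" not in trie
assert "มาตรา" in trie                   # shared prefix nodes survive pruning
assert sorted(trie) == ["มา", "มาตรา"]   # __iter__ does a DFS over nodes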

tests/core/test_tokenize.py

Lines changed: 8 additions & 0 deletions
@@ -684,6 +684,14 @@ def test_tcc_p(self):
         # )
         self.assertEqual(list(tcc_p.tcc("")), [])
         self.assertEqual(tcc_p.tcc_pos(""), set())
+        # tcc_pos_array: edge cases
+        self.assertIsInstance(tcc_p.tcc_pos_array(""), bytearray)
+        self.assertIsInstance(tcc_p.tcc_pos_array(None), bytearray)
+        self.assertIsInstance(tcc_p.tcc_pos_array(42), bytearray)
+        # valid text: array length must equal len(text)+1 and mark boundaries
+        arr = tcc_p.tcc_pos_array("ประเทศ")
+        self.assertEqual(len(arr), len("ประเทศ") + 1)
+        self.assertEqual(arr[0], 0)  # position 0 is never a boundary

     def test_display_cell_tokenize(self):
         self.assertEqual(display_cell_tokenize(""), [])

tests/core/test_util.py

Lines changed: 14 additions & 0 deletions
@@ -516,6 +516,20 @@ def test_trie(self):
         trie.remove("ทด")
         self.assertEqual(len(trie), 2)

+        # _word_count must not double-count re-added words
+        trie2 = Trie(["ก", "ข", "ก"])
+        self.assertEqual(len(trie2), 2)
+        trie2.add("ก")  # already present – count must stay the same
+        self.assertEqual(len(trie2), 2)
+        trie2.add("ค")
+        self.assertEqual(len(trie2), 3)
+        trie2.remove("ข")
+        self.assertEqual(len(trie2), 2)
+        trie2.remove("ข")  # removing non-existent word must not change count
+        self.assertEqual(len(trie2), 2)
+        # All remaining words must be reachable via __iter__
+        self.assertEqual(sorted(trie2), ["ก", "ค"])
+
         trie = Trie([])
         self.assertEqual(len(trie), 0)
         trie.remove("หมด")
