Skip to content

Commit 6b996cd

Browse files
committed
feat(ai-cache): wire semantic.top_k through L2 vector search
1 parent e665d2a commit 6b996cd

1 file changed

Lines changed: 29 additions & 25 deletions

File tree

apisix/plugins/ai-cache/semantic.lua

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,23 @@ function _M.search(conf, scope_hash, embedding_vec, threshold)
8989
end
9090

9191
local binary_vec = pack_vector(embedding_vec)
92+
local top_k = (conf.semantic and conf.semantic.top_k) or 1
93+
local top_k_str = tostring(top_k)
9294

9395
local query
9496
if scope_hash == "" then
95-
query = "*=>[KNN 1 @embedding $vec AS dist]"
97+
query = "*=>[KNN " .. top_k_str .. " @embedding $vec AS dist]"
9698
else
97-
query = "@scope:{" .. scope_hash .. "}=>[KNN 1 @embedding $vec AS dist]"
99+
query = "@scope:{" .. scope_hash .. "}=>[KNN " .. top_k_str
100+
.. " @embedding $vec AS dist]"
98101
end
99102

100103
local res, search_err = red["FT.SEARCH"](red,
101104
index_name(#embedding_vec),
102105
query,
103106
"PARAMS", "2", "vec", binary_vec,
104107
"SORTBY", "dist", "ASC",
105-
"LIMIT", "0", "1",
108+
"LIMIT", "0", top_k_str,
106109
"RETURN", "2", "response", "dist",
107110
"DIALECT", "2"
108111
)
@@ -116,31 +119,32 @@ function _M.search(conf, scope_hash, embedding_vec, threshold)
116119
return nil, nil, nil
117120
end
118121

119-
-- RESP2: {count, key, {field, val, field, val, ...}, ...}
120-
local fields = res[3]
121-
if type(fields) ~= "table" then
122-
return nil, nil, nil
123-
end
124-
125-
local response_text, dist
126-
for i = 1, #fields, 2 do
127-
if fields[i] == "response" then
128-
response_text = fields[i + 1]
129-
elseif fields[i] == "dist" then
130-
dist = tonumber(fields[i + 1])
122+
-- RESP2: {count, key1, fields1, key2, fields2, ...}
123+
-- Results are sorted by dist ASC. Iterate candidates and return the first
124+
-- one whose similarity meets the threshold; skip candidates with missing
125+
-- or corrupt fields.
126+
for i = 3, #res, 2 do
127+
local fields = res[i]
128+
if type(fields) == "table" then
129+
local response_text, dist
130+
for j = 1, #fields, 2 do
131+
if fields[j] == "response" then
132+
response_text = fields[j + 1]
133+
elseif fields[j] == "dist" then
134+
dist = tonumber(fields[j + 1])
135+
end
136+
end
137+
138+
if response_text and dist then
139+
local similarity = 1 - dist
140+
if similarity >= threshold then
141+
return response_text, similarity, nil
142+
end
143+
end
131144
end
132145
end
133146

134-
if not response_text or not dist then
135-
return nil, nil, nil
136-
end
137-
138-
local similarity = 1 - dist
139-
if similarity < threshold then
140-
return nil, nil, nil
141-
end
142-
143-
return response_text, similarity, nil
147+
return nil, nil, nil
144148
end
145149

146150

0 commit comments

Comments
 (0)