Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion components/knowledge_engine/langrag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import uuid

from langbot_plugin.api.definition.components.knowledge_engine import (
KnowledgeEngine,
Expand Down Expand Up @@ -33,6 +34,34 @@ def _query_log_ref(query: str | None) -> str:
return f"hash={hash_text(query)}, length={len(query or '')}"


def _new_span_id() -> str:
return f"span-{uuid.uuid4().hex[:16]}"


def _trace_spans_from_stages(
stage_durations: dict[str, float],
parent_span_id: str | None,
base_attributes: dict | None = None,
) -> list[dict]:
spans: list[dict] = []
for stage, duration_ms in stage_durations.items():
spans.append(
{
"span_id": _new_span_id(),
"parent_span_id": parent_span_id,
"name": stage,
"kind": "rag.stage",
"status": "success",
"duration_ms": duration_ms,
"attributes": {
**(base_attributes or {}),
"stage": stage,
},
}
)
return spans


class LangRAG(KnowledgeEngine):
"""Simple Knowledge Engine implementation using Plugin IPC.

Expand Down Expand Up @@ -477,7 +506,23 @@ async def ingest(self, context: IngestionContext) -> IngestionResult:
async def retrieve(self, context: RetrievalContext) -> RetrievalResponse:
"""Retrieve relevant content with support for vector, full-text, and hybrid search."""
started_at = telemetry.start_timer()
trace_id = telemetry.new_trace_id("retrieval")
host_trace_context = getattr(context, "trace_context", None)
host_trace_context_data = (
host_trace_context.model_dump()
if hasattr(host_trace_context, "model_dump")
else (host_trace_context or {})
)
host_parent_span_id = (
host_trace_context_data.get("parent_span_id")
if isinstance(host_trace_context_data, dict)
else None
)
trace_id = (
host_trace_context_data.get("trace_id")
if isinstance(host_trace_context_data, dict)
and host_trace_context_data.get("trace_id")
else telemetry.new_trace_id("retrieval")
)
stage_durations: dict[str, float] = {}
query = context.query
top_k = context.retrieval_settings.get("top_k", 5)
Expand Down Expand Up @@ -785,8 +830,27 @@ async def retrieve(self, context: RetrievalContext) -> RetrievalResponse:
total_found=len(entries),
metadata={
"trace_id": trace_id,
"status": telemetry_status,
"duration_ms": telemetry.elapsed_ms(started_at),
"raw_count": raw_count,
"result_count": result_count,
"reference_count": reference_count,
"distance_min": distance_min,
"distance_avg": distance_avg,
"distance_max": distance_max,
"stage_durations_ms": stage_durations,
"trace_spans": _trace_spans_from_stages(
stage_durations,
parent_span_id=host_parent_span_id,
base_attributes={
"query_hash": hash_text(query),
"query_length": len(query or ""),
"collection_id_hash": hash_text(collection_id),
"knowledge_base_id_hash": hash_text(
context.knowledge_base_id
),
},
),
},
)
except Exception as e:
Expand Down