diff --git a/components/knowledge_engine/langrag.py b/components/knowledge_engine/langrag.py index c9b204d..fd0a7ab 100644 --- a/components/knowledge_engine/langrag.py +++ b/components/knowledge_engine/langrag.py @@ -1,4 +1,5 @@ import logging +import uuid from langbot_plugin.api.definition.components.knowledge_engine import ( KnowledgeEngine, @@ -33,6 +34,34 @@ def _query_log_ref(query: str | None) -> str: return f"hash={hash_text(query)}, length={len(query or '')}" +def _new_span_id() -> str: + return f"span-{uuid.uuid4().hex[:16]}" + + +def _trace_spans_from_stages( + stage_durations: dict[str, float], + parent_span_id: str | None, + base_attributes: dict | None = None, +) -> list[dict]: + spans: list[dict] = [] + for stage, duration_ms in stage_durations.items(): + spans.append( + { + "span_id": _new_span_id(), + "parent_span_id": parent_span_id, + "name": stage, + "kind": "rag.stage", + "status": "success", + "duration_ms": duration_ms, + "attributes": { + **(base_attributes or {}), + "stage": stage, + }, + } + ) + return spans + + class LangRAG(KnowledgeEngine): """Simple Knowledge Engine implementation using Plugin IPC. @@ -477,7 +506,23 @@ async def ingest(self, context: IngestionContext) -> IngestionResult: async def retrieve(self, context: RetrievalContext) -> RetrievalResponse: """Retrieve relevant content with support for vector, full-text, and hybrid search.""" started_at = telemetry.start_timer() - trace_id = telemetry.new_trace_id("retrieval") + host_trace_context = getattr(context, "trace_context", None) + host_trace_context_data = ( + host_trace_context.model_dump() + if hasattr(host_trace_context, "model_dump") + else (host_trace_context or {}) + ) + host_parent_span_id = ( + host_trace_context_data.get("parent_span_id") + if isinstance(host_trace_context_data, dict) + else None + ) + trace_id = ( + host_trace_context_data.get("trace_id") + if isinstance(host_trace_context_data, dict) + and host_trace_context_data.get("trace_id") + else telemetry.new_trace_id("retrieval") + ) stage_durations: dict[str, float] = {} query = context.query top_k = context.retrieval_settings.get("top_k", 5) @@ -785,8 +830,27 @@ async def retrieve(self, context: RetrievalContext) -> RetrievalResponse: total_found=len(entries), metadata={ "trace_id": trace_id, + "status": telemetry_status, + "duration_ms": telemetry.elapsed_ms(started_at), "raw_count": raw_count, + "result_count": result_count, + "reference_count": reference_count, + "distance_min": distance_min, + "distance_avg": distance_avg, + "distance_max": distance_max, "stage_durations_ms": stage_durations, + "trace_spans": _trace_spans_from_stages( + stage_durations, + parent_span_id=host_parent_span_id, + base_attributes={ + "query_hash": hash_text(query), + "query_length": len(query or ""), + "collection_id_hash": hash_text(collection_id), + "knowledge_base_id_hash": hash_text( + context.knowledge_base_id + ), + }, + ), }, ) except Exception as e: