diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index b8d1ae101..030b36b88 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -1,7 +1,7 @@ import json import threading import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urljoin from jinja2 import Template, StrictUndefined @@ -33,12 +33,71 @@ from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.context_utils import build_context_components from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET +from consts.model import AgentToolParamsRequest, ToolParamsRequest from consts.exceptions import ValidationError logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) +def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest: + """Normalize request-scoped tool parameter overrides into a ToolParamsRequest.""" + if tool_params is None: + return ToolParamsRequest() + if isinstance(tool_params, ToolParamsRequest): + return tool_params + if not isinstance(tool_params, dict): + raise ValidationError("tool_params must be an object.") + try: + return ToolParamsRequest.model_validate(tool_params) + except Exception as exc: + raise ValidationError(f"Invalid tool_params payload: {exc}") from exc + + +def _get_agent_tool_overrides( + tool_params: Optional[ToolParamsRequest], + agent_name: Optional[str], +) -> Dict[str, Dict[str, Any]]: + """Resolve tool overrides for a specific agent by its name.""" + if tool_params is None: + return {} + if not agent_name: + return {} + agent_override = tool_params.agents.get(agent_name) + if agent_override is None: + return {} + return dict(agent_override.tools) + + +def _merge_tool_params( + tool_record: Dict[str, Any], + override_params: Optional[Dict[str, Any]], + extra_params: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Merge request overrides on top of tool instance defaults from DB. + + Args: + tool_record: Tool configuration from database + override_params: Request-scoped overrides from tool_params + extra_params: Additional internal params not in DB schema (e.g., document_paths) + + Returns: + Merged params dict with DB defaults, overrides, and extra params + """ + merged_params: Dict[str, Any] = {} + for param in tool_record.get("params", []): + merged_params[param["name"]] = param.get("default") + + if override_params: + merged_params.update(override_params) + + # Extra params (e.g., internal access control params) always take precedence + if extra_params: + merged_params.update(extra_params) + + return merged_params + + def _build_internal_s3_url(file: dict) -> str: """Build a valid S3 URL for internal tools from uploaded file metadata.""" if not isinstance(file, dict): @@ -310,7 +369,9 @@ async def create_agent_config( allow_memory_search: bool = True, version_no: int = 0, override_model_id: int | None = None, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, ): + normalized_tool_params = _normalize_tool_params_request(tool_params) agent_info = search_agent_info_by_agent_id( agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) @@ -331,13 +392,20 @@ async def create_agent_config( allow_memory_search=allow_memory_search, version_no=sub_agent_version_no, override_model_id=None, + tool_params=normalized_tool_params, ) managed_agents.append(sub_agent_config) # create external A2A agents (synchronous function, no await needed) external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no) - tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no) + tool_list = await create_tool_config_list( + agent_id, + tenant_id, + user_id, + version_no=version_no, + tool_params=normalized_tool_params, + ) # Build system prompt: prioritize segmented fields, fallback to original prompt field if not available duty_prompt = agent_info.get("duty_prompt", "") @@ -562,17 +630,43 @@ async def create_agent_config( return agent_config -async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0): - # create tool +async def create_tool_config_list( + agent_id, + tenant_id, + user_id, + version_no: int = 0, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, +): tool_config_list = [] langchain_tools = await discover_langchain_tools() + normalized_tool_params = _normalize_tool_params_request(tool_params) # now only admin can modify the agent, user_id is not used tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no) + + # Look up agent name for use in error messages. + # Agent name is optional for tool_params matching (matching uses tool identifiers only), + # but we include it in error messages so callers can identify which agent/tool caused a failure. + agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) + agent_name = agent_info.get("name") if agent_info else None + agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name) + + tool_keys_seen = set() for tool in tools_list: - param_dict = {} - for param in tool.get("params", []): - param_dict[param["name"]] = param.get("default") + tool_identifier = tool.get("name") or tool.get("class_name") + if tool_identifier in tool_keys_seen: + raise ValidationError( + f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'." + ) + tool_keys_seen.add(tool_identifier) + + override_params = None + if tool.get("name") in agent_tool_overrides: + override_params = agent_tool_overrides[tool.get("name")] + elif tool.get("class_name") in agent_tool_overrides: + override_params = agent_tool_overrides[tool.get("class_name")] + + param_dict = _merge_tool_params(tool, override_params) tool_config = ToolConfig( class_name=tool.get("class_name"), name=tool.get("name"), @@ -591,12 +685,21 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = langchain_tool break + # Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema) + document_paths = None + if override_params and "document_paths" in override_params: + document_paths = override_params.get("document_paths") + # Also check using the tool name as key + if not document_paths: + kb_overrides = agent_tool_overrides.get("knowledge_base_search") + if kb_overrides and "document_paths" in kb_overrides: + document_paths = kb_overrides.get("document_paths") + # special logic for search tools that may use reranking models if tool_config.class_name == "KnowledgeBaseSearchTool": - rerank = param_dict.get("rerank", False) - rerank_model_name = param_dict.get("rerank_model_name", "") + rerank = tool_config.params.get("rerank", False) + rerank_model_name = tool_config.params.get("rerank_model_name", "") rerank_model = None - is_multimodal = bool(tool_config.params.pop("multimodal", False)) if rerank and rerank_model_name: rerank_model = get_rerank_model( tenant_id=tenant_id, model_name=rerank_model_name @@ -604,7 +707,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int # Build display_name to index_name mapping for LLM parameter conversion # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary - index_names = param_dict.get("index_names", []) + index_names = tool_config.params.get("index_names", []) display_name_to_index_map = {} index_name_to_display_map = {} if index_names: @@ -620,12 +723,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int "rerank_model": rerank_model, "display_name_to_index_map": display_name_to_index_map, "index_name_to_display_map": index_name_to_display_map, + # Internal access control: restrict results to specific document paths (path_or_urls) + "document_paths": document_paths, } - # Must have embedding model for knowledge base search if not index_names: raise ValidationError( - "Embedding model is required for knowledge_base_search but index_names is empty") + f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, " + f"but it is not configured in the agent and not provided via tool_params.") embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0]) if not embedding_model: @@ -634,8 +739,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int f"Please configure an embedding model for this knowledge base.") tool_config.metadata["embedding_model"] = embedding_model elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: - rerank = param_dict.get("rerank", False) - rerank_model_name = param_dict.get("rerank_model_name", "") + rerank = tool_config.params.get("rerank", False) + rerank_model_name = tool_config.params.get("rerank_model_name", "") rerank_model = None if rerank and rerank_model_name: rerank_model = get_rerank_model( @@ -929,6 +1034,7 @@ async def create_agent_run_info( is_debug: bool = False, override_version_no: int | None = None, override_model_id: int | None = None, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, ): # Determine which version_no to use based on is_debug flag # If is_debug=false, use the current published version (current_version_no) @@ -961,7 +1067,7 @@ async def create_agent_run_info( if override_model_id is not None: create_config_kwargs["override_model_id"] = override_model_id - agent_config = await create_agent_config(**create_config_kwargs) + agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params) remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True) default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse") diff --git a/backend/apps/northbound_app.py b/backend/apps/northbound_app.py index e6aff8e06..d8c652f5f 100644 --- a/backend/apps/northbound_app.py +++ b/backend/apps/northbound_app.py @@ -1,14 +1,16 @@ import logging from http import HTTPStatus from typing import Optional, Dict, Any -from urllib.parse import urlparse +from urllib.parse import urlparse, unquote +import re import uuid import httpx -from fastapi import APIRouter, Body, Header, Request, HTTPException, Query +from fastapi import APIRouter, Body, File, Header, HTTPException, Query, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse -from consts.exceptions import LimitExceededError, UnauthorizedError +from consts.exceptions import LimitExceededError, UnauthorizedError, ConversationNotFoundError +from consts.model import ToolParamsRequest from services.northbound_service import ( NorthboundContext, get_conversation_history, @@ -17,16 +19,35 @@ stop_chat, get_agent_info_list, update_conversation_title, + upload_files_for_northbound, ) from utils.auth_utils import validate_bearer_token, get_user_and_tenant_by_access_key +from .file_management_app import build_content_disposition_header + router = APIRouter(prefix="/nb/v1", tags=["northbound"]) __all__ = ["router", "_get_northbound_context"] +def _resolve_proxy_download_filename(presigned_url: str, content_disposition: str) -> str: + """Resolve a stable download filename for the northbound file proxy.""" + if content_disposition: + filename_star_match = re.search(r"filename\*=UTF-8''([^;]+)", content_disposition) + if filename_star_match: + return unquote(filename_star_match.group(1)) or "download" + + filename_match = re.search(r'filename="?([^";]+)"?', content_disposition) + if filename_match: + return filename_match.group(1) or "download" + + path = unquote(urlparse(presigned_url).path) + filename = path.split("/")[-1].strip() + return filename or "download" + + async def _get_northbound_context(request: Request) -> NorthboundContext: """ Build northbound context from request. @@ -109,13 +130,118 @@ async def health_check(): return {"status": "healthy", "service": "northbound-api"} -@router.post("/chat/run") +@router.post( + "/chat/attachments/upload", + summary="Upload chat attachments for northbound runs", + description=( + "Upload one or more files for later use in `/nb/v1/chat/run`. " + "Successful uploads return reusable `s3_url` references." + ), +) +async def upload_chat_attachments( + request: Request, + files: list[UploadFile] = File( + ..., + description="List of files to upload", + examples=["report.pdf", "diagram.png"], + ), +): + try: + ctx: NorthboundContext = await _get_northbound_context(request) + return JSONResponse( + status_code=HTTPStatus.OK, + content=await upload_files_for_northbound(ctx=ctx, files=files), + ) + except LimitExceededError as e: + logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="Too Many Requests: rate limit exceeded") + except ValueError as e: + logging.error(f"Invalid northbound upload request: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except PermissionError as e: + logging.error(f"Permission denied while uploading northbound files: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=str(e)) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to upload northbound files: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.post( + "/chat/run", + summary="Start a northbound chat run with optional attachments", + description=( + "Run a northbound chat request. Upload attachments first through " + "`/nb/v1/chat/attachments/upload`, then pass the returned `s3_url` values " + "through the `attachments` field." + ), +) async def run_chat( request: Request, - conversation_id: Optional[int] = Body(None, embed=True), - agent_name: str = Body(..., embed=True), - query: str = Body(..., embed=True), - meta_data: Optional[Dict[str, Any]] = Body(None, embed=True), + conversation_id: Optional[int] = Body( + None, + embed=True, + description="Existing conversation ID. Omit to create a new conversation.", + examples=[123], + ), + agent_name: str = Body( + ..., + embed=True, + description="Target agent name.", + examples=["general-assistant"], + ), + query: str = Body( + ..., + embed=True, + description="User input to send to the agent.", + examples=["Summarize the uploaded report and list the key risks."], + ), + attachments: Optional[list[str]] = Body( + None, + embed=True, + description="S3 URLs returned by the attachment upload API.", + examples=[["s3://nexent/attachments/user123/20260609_report.pdf"]], + ), + meta_data: Optional[Dict[str, Any]] = Body( + None, + embed=True, + description="Optional metadata passed through for audit and usage logging.", + examples=[{"source": "crm", "ticket_id": "INC-1001"}], + ), + tool_params: Optional[ToolParamsRequest] = Body( + None, + embed=True, + description="Optional request-scoped overrides for tool initialization parameters. " + "Overrides DB-persisted params (ag_tool_instance_t.params) on a per-run basis. " + "Conflict resolution: request value wins over DB value. " + "Structure: agents -> {agent_name} -> tools -> {tool_name} -> {param_name: param_value}. " + "tool_name matching: first by tool.name, then by tool.class_name. " + "Unknown param names cause a ValidationError (400). " + "Metadata-derived fields (e.g., vdb_core, embedding_model) are recalculated " + "from merged params for tools like KnowledgeBaseSearchTool, DifySearchTool, DataMateSearchTool.", + examples=[{ + "agents": { + "common_sense_qa_assistant": { + "tools": { + "analyze_text_file": { + "chunk_size": 4000, + "summary_only": True, + "prompt": "Please provide a concise summary of this document focusing on key facts." + }, + "knowledge_base_search": { + "top_k": 10, + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + "index_names": ["nexent-docs", "faq-index"] + } + } + } + } + }], + ), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): try: @@ -125,13 +251,21 @@ async def run_chat( conversation_id=conversation_id, agent_name=agent_name, query=query, + attachments=attachments, meta_data=meta_data, + tool_params=tool_params, idempotency_key=idempotency_key, ) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") + except ValueError as e: + logging.error(f"Invalid northbound chat request: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except PermissionError as e: + logging.error(f"Permission denied while running northbound chat: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=str(e)) except HTTPException as e: raise e except Exception as e: @@ -254,6 +388,9 @@ async def update_convs_title( logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") + except ConversationNotFoundError as e: + logging.error(f"Conversation not found while updating title: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) except HTTPException as e: raise e except Exception as e: @@ -312,12 +449,12 @@ async def fetch_file_from_presigned_url( content_type = response.headers.get("Content-Type", "application/octet-stream") content_disposition = response.headers.get("Content-Disposition", "") + download_filename = _resolve_proxy_download_filename(presigned_url, content_disposition) headers = { "Content-Type": content_type, + "Content-Disposition": build_content_disposition_header(download_filename), } - if content_disposition: - headers["Content-Disposition"] = content_disposition return StreamingResponse( content=response.aiter_bytes(), diff --git a/backend/consts/model.py b/backend/consts/model.py index e45f49344..d72538ecb 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -230,6 +230,24 @@ class HistoryItem(BaseModel): minio_files: Optional[List[Dict[str, Any]]] = None +class AgentToolParamsRequest(BaseModel): + """Request-scoped tool parameter overrides for a single agent.""" + + tools: Dict[str, Dict[str, Any]] = Field( + default_factory=dict, + description="Mapping from tool identifier to request-scoped override params", + ) + + +class ToolParamsRequest(BaseModel): + """Request-scoped tool parameter overrides for main and managed agents.""" + + agents: Dict[str, AgentToolParamsRequest] = Field( + default_factory=dict, + description="Mapping from agent identifier to tool parameter overrides", + ) + + class AgentRequest(BaseModel): query: str conversation_id: Optional[int] = None @@ -240,6 +258,7 @@ class AgentRequest(BaseModel): model_id: Optional[int] = None version_no: Optional[int] = None is_debug: Optional[bool] = False + tool_params: Optional[ToolParamsRequest] = None class MessageUnit(BaseModel): diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 8e2147a41..04fd409fa 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -2046,6 +2046,7 @@ async def prepare_agent_run( is_debug=agent_request.is_debug, override_version_no=agent_request.version_no, override_model_id=agent_request.model_id, + tool_params=agent_request.tool_params, ) # Mount conversation-level reusable ContextManager if enabled diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index 81c2bfa98..34db53525 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -8,6 +8,7 @@ from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, DEFAULT_EN_TITLE, DEFAULT_ZH_TITLE from consts.model import AgentRequest, ConversationResponse, MessageRequest, MessageUnit +from consts.exceptions import ConversationNotFoundError from database.conversation_db import ( create_conversation, create_conversation_message, @@ -298,7 +299,9 @@ def update_conversation_title(conversation_id: int, title: str, user_id: str = N """ success = rename_conversation(conversation_id, title, user_id) if not success: - raise Exception(f"Conversation {conversation_id} does not exist or has been deleted") + raise ConversationNotFoundError( + f"Conversation {conversation_id} does not exist or has been deleted" + ) return success diff --git a/backend/services/northbound_service.py b/backend/services/northbound_service.py index a6eaed77d..80cd505c0 100644 --- a/backend/services/northbound_service.py +++ b/backend/services/northbound_service.py @@ -3,29 +3,37 @@ import logging import time from dataclasses import dataclass -from typing import Any, Dict, Optional +from os.path import basename +from typing import Any, Dict, List, Optional +from fastapi import HTTPException, UploadFile from fastapi.responses import StreamingResponse + +from consts.const import ASSET_OWNER_TENANT_ID from consts.exceptions import ( LimitExceededError, UnauthorizedError, + ConversationNotFoundError, ) -from consts.model import AgentRequest +from consts.model import AgentRequest, ToolParamsRequest from database.conversation_db import get_conversation_messages from database.token_db import log_token_usage, get_latest_usage_metadata from services.agent_service import ( run_agent_stream, stop_agent_tasks, - list_all_agent_info_impl, get_agent_id_by_name ) +from services.agent_version_service import list_published_agents_impl from services.conversation_management_service import ( save_conversation_user, get_conversation_list_service, create_new_conversation, update_conversation_title as update_conversation_title_service, ) +from services.file_management_service import upload_to_minio, resolve_minio_upload_folder, validate_urls_access +from database.attachment_db import build_s3_url, get_file_url +from nexent.multi_modal.utils import parse_s3_url logger = logging.getLogger("northbound_service") @@ -39,6 +47,102 @@ class NorthboundContext: token_id: int = 0 +def _build_northbound_file_descriptor(upload_result: Dict[str, Any]) -> Dict[str, Any]: + """Normalize upload metadata for northbound API consumers.""" + object_name = str(upload_result.get("object_name") or "").strip() + file_name = str(upload_result.get("file_name") or basename(object_name) or "") + descriptor = { + "name": file_name, + "object_name": object_name, + "s3_url": build_s3_url(object_name), + } + presigned_url = upload_result.get("presigned_url") + if presigned_url: + descriptor["presigned_url"] = presigned_url + return descriptor + + +async def upload_files_for_northbound( + ctx: NorthboundContext, + files: List[UploadFile], + folder: str = "attachments", +) -> Dict[str, Any]: + """Upload files for northbound callers and return reusable storage references.""" + if not files: + raise ValueError("No files in the request") + + actual_folder = resolve_minio_upload_folder(folder, ctx.user_id, ctx.tenant_id) + results = await upload_to_minio(files=files, folder=actual_folder) + normalized_files = [ + _build_northbound_file_descriptor(result) + for result in results + if result.get("success") and result.get("object_name") + ] + + if not normalized_files: + raise ValueError("No valid files uploaded") + + success_count = sum(1 for result in results if result.get("success", False)) + failed_count = sum(1 for result in results if not result.get("success", False)) + + return { + "message": f"Processed {len(results)} files", + "requestId": ctx.request_id, + "summary": { + "total": len(results), + "uploaded": success_count, + "failed": failed_count, + }, + "files": normalized_files, + } + + +def _normalize_northbound_attachments( + attachments: Optional[List[str]], + user_id: str, + tenant_id: str, +) -> Optional[List[Dict[str, Any]]]: + """Convert northbound S3 attachment references into internal minio_files objects.""" + if attachments is None: + return None + if not isinstance(attachments, list): + raise ValueError("attachments must be an array of S3 URLs") + + normalized_files: List[Dict[str, Any]] = [] + for attachment in attachments: + if not isinstance(attachment, str) or not attachment.strip(): + raise ValueError("attachments must contain non-empty S3 URLs") + + attachment_url = attachment.strip() + if not attachment_url.startswith("s3://"): + raise ValueError(f"Invalid S3 URL format: {attachment_url}") + + try: + _, object_name = parse_s3_url(attachment_url) + except ValueError as exc: + raise ValueError(f"Invalid S3 URL format: {attachment_url}") from exc + + try: + validate_urls_access([attachment_url], user_id, tenant_id) + presigned_result = get_file_url(object_name=object_name, expires=86400) + except PermissionError as exc: + detail = str(exc) + if "Invalid S3 URL format" in detail: + raise ValueError(detail) from exc + raise PermissionError(detail) from exc + + normalized_file = { + "name": basename(object_name.rstrip("/")), + "object_name": object_name, + "url": attachment_url, + } + if presigned_result.get("success") and presigned_result.get("url"): + normalized_file["presigned_url"] = presigned_result["url"] + normalized_files.append(normalized_file) + + return normalized_files + + # ----------------------------- # In-memory idempotency and rate limit placeholders # ----------------------------- @@ -111,6 +215,12 @@ def _build_idempotency_key(*parts: Any) -> str: return ":".join(processed) +def _build_title_update_idempotency_key(tenant_id: str, conversation_id: int, title: str) -> str: + """Build an ASCII-safe idempotency key for title updates.""" + title_hash = hashlib.sha256(title.encode("utf-8")).hexdigest() + return _build_idempotency_key(tenant_id, str(conversation_id), title_hash) + + # ----------------------------- # Agent resolver # ----------------------------- @@ -126,7 +236,9 @@ async def start_streaming_chat( conversation_id: Optional[int], agent_name: str, query: str, + attachments: Optional[List[str]] = None, meta_data: Optional[Dict[str, Any]] = None, + tool_params: Optional[ToolParamsRequest] = None, idempotency_key: Optional[str] = None ) -> StreamingResponse: try: @@ -145,6 +257,11 @@ async def start_streaming_chat( # Get history according to internal_conversation_id history_resp = await get_conversation_history_internal(ctx, internal_conversation_id) agent_id = await get_agent_id_by_name(agent_name=agent_name, tenant_id=ctx.tenant_id) + normalized_attachments = _normalize_northbound_attachments( + attachments=attachments, + user_id=ctx.user_id, + tenant_id=ctx.tenant_id, + ) # Idempotency: only prevent concurrent duplicate starts composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), agent_id, query) await idempotency_start(composed_key) @@ -153,8 +270,9 @@ async def start_streaming_chat( agent_id=agent_id, query=query, history=(history_resp.get("data", {})).get("history", []), - minio_files=None, + minio_files=normalized_attachments, is_debug=False, + tool_params=tool_params, ) # Synchronously persist the user message before starting the stream to avoid race conditions @@ -284,7 +402,18 @@ async def get_conversation_history(ctx: NorthboundContext, conversation_id: int) async def get_agent_info_list(ctx: NorthboundContext) -> Dict[str, Any]: try: - agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id, user_id=ctx.user_id) + agent_info_list = await list_published_agents_impl( + tenant_id=ctx.tenant_id, + user_id=ctx.user_id, + ) + # Match the same scope as /agent/published_list: non-asset-owner tenants + # also get the asset owner's published agents merged in. + if ctx.tenant_id != ASSET_OWNER_TENANT_ID: + asset_agent_list = await list_published_agents_impl( + tenant_id=ASSET_OWNER_TENANT_ID, + user_id=ctx.user_id, + ) + agent_info_list.extend(asset_agent_list) # Remove internal information that partner don't need for agent_info in agent_info_list: agent_info.pop("agent_id", None) @@ -298,7 +427,11 @@ async def update_conversation_title(ctx: NorthboundContext, conversation_id: int composed_key: Optional[str] = None try: # Idempotency: avoid concurrent duplicate title update for same conversation - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), title) + composed_key = idempotency_key or _build_title_update_idempotency_key( + ctx.tenant_id, + conversation_id, + title, + ) await idempotency_start(composed_key) update_conversation_title_service(conversation_id, title, ctx.user_id) @@ -324,6 +457,8 @@ async def update_conversation_title(ctx: NorthboundContext, conversation_id: int } except LimitExceededError as _: raise LimitExceededError("Duplicate request is still running, please wait.") + except ConversationNotFoundError: + raise except Exception as e: raise Exception(f"Failed to update conversation title for conversation_id {conversation_id}: {str(e)}") finally: diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 08f4896ab..3cbf5edc5 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -782,6 +782,8 @@ def _validate_local_tool( 'embedding_model': embedding_model, 'rerank_model': rerank_model, 'display_name_to_index_map': display_name_to_index_map, + # Internal access control: restrict results to specific document paths (path_or_urls) + 'document_paths': instantiation_params.get('document_paths'), } tool_instance = tool_class(**params) elif tool_name in ["dify_search", "datamate_search"]: diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index 3405be833..9149ed05d 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -48,6 +48,10 @@ class KnowledgeBaseSearchTool(Tool): }, } + # Internal parameter: restricts search results to specified document paths only. + # Not exposed to LLM, only settable via tool_params from /chat/run. + _internal_document_paths: Optional[List[str]] = None + init_param_descriptions = { "top_k": { "description": "Maximum number of search results", @@ -96,6 +100,10 @@ def __init__( display_name_to_index_map: dict = Field( description="Mapping from display_name (knowledge_name) to index_name", default_factory=dict, exclude=True), + # Internal parameter: not exposed to LLM, only settable via tool_params from /chat/run. + document_paths: Optional[List[str]] = Field( + description="Internal: restrict results to documents with these path_or_urls", default=None, exclude=True + ), ): """Initialize the KBSearchTool. @@ -121,11 +129,23 @@ def __init__( self.rerank_model = rerank_model self.data_process_service = os.getenv("DATA_PROCESS_SERVICE") self.display_name_to_index_map = display_name_to_index_map + self._internal_document_paths = document_paths self.record_ops = 1 self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." + def set_document_paths(self, document_paths: Optional[List[str]]) -> None: + """Set the internal document_paths filter for access control. + + This method is intended for internal use only, called via tool_params + from the /chat/run endpoint. It is NOT exposed to the LLM. + + Args: + document_paths: List of allowed document path_or_urls. If None, no filtering is applied. + """ + self._internal_document_paths = document_paths + def _convert_to_index_names(self, names: List[str]) -> List[str]: """Convert display names (knowledge_name) to index names if necessary. @@ -155,6 +175,36 @@ def _convert_to_index_names(self, names: List[str]) -> List[str]: converted_names.append(name) return converted_names + def _filter_by_document_paths(self, results: List[dict]) -> List[dict]: + """Filter search results by allowed document paths for access control. + + If _internal_document_paths is set, only results whose path_or_url is in the + allowed list are returned. Results with no path_or_url field are discarded + when the filter is active. + + Args: + results: List of search result dicts from VDB search + + Returns: + Filtered list containing only results with allowed document paths + """ + allowed_paths = self._internal_document_paths + if not allowed_paths: + return results + + filtered = [ + result for result in results + if result.get("path_or_url") in allowed_paths + ] + + if filtered: + logger.info( + "Document paths filter applied: %d/%d results match allowed paths", + len(filtered), + len(results), + ) + return filtered + def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: # Parse index_names from string (always required) search_index_names = index_names if index_names is not None else self.index_names @@ -203,6 +253,9 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: ) kb_search_results = kb_search_data["results"] + # Apply document_paths access control: filter out results not in allowed list + kb_search_results = self._filter_by_document_paths(kb_search_results) + if not kb_search_results: raise Exception("No results found! Try a less restrictive/shorter query.") diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index f650de5d7..69dab99bc 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -47,6 +47,21 @@ class ToolExecutionException(Exception): consts_model_module = types.ModuleType("consts.model") consts_model_module.HistoryItem = HistoryItem + + +class MockAgentToolParamsRequest(BaseModel): + """Mock for AgentToolParamsRequest.""" + tools: Dict[str, Dict[str, Any]] = {} + + +class MockToolParamsRequest(BaseModel): + """Mock for ToolParamsRequest.""" + agents: Dict[str, MockAgentToolParamsRequest] = {} + + +consts_model_module.HistoryItem = HistoryItem +consts_model_module.AgentToolParamsRequest = MockAgentToolParamsRequest +consts_model_module.ToolParamsRequest = MockToolParamsRequest sys.modules["consts.model"] = consts_model_module # Mock consts.exceptions module with ValidationError @@ -63,7 +78,7 @@ class ToolExecutionException(Exception): setattr(consts_module, "model", consts_model_module) setattr(consts_module, "exceptions", consts_exceptions_module) -# Also add model to consts module attributes +# Also add model to consts module attributes (with AgentToolParamsRequest and ToolParamsRequest) consts_module = sys.modules.get("consts") if consts_module: setattr(consts_module, "model", consts_model_module) @@ -293,6 +308,9 @@ def __init__(self, **kwargs): _build_internal_s3_url, _format_minio_files_for_content, _convert_history_with_minio_files, + _normalize_tool_params_request, + _get_agent_tool_overrides, + _merge_tool_params, ) # Import HistoryItem for testing (from mocked consts.model) @@ -301,6 +319,9 @@ def __init__(self, **kwargs): # Import ValidationError for testing (from mocked consts.exceptions) ValidationError = sys.modules["consts.exceptions"].ValidationError +# Import ToolParamsRequest for testing +ToolParamsRequest = sys.modules["consts.model"].ToolParamsRequest + # Import constants for testing from consts.const import MODEL_CONFIG_MAPPING @@ -736,6 +757,11 @@ async def test_create_tool_config_list_knowledge_base_multimodal(self): """Ensure multimodal param is forwarded to embedding model selection.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["idx1", "idx2"], + "multimodal": True, + "rerank": False, + } with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ @@ -744,7 +770,7 @@ async def test_create_tool_config_list_knowledge_base_multimodal(self): patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map, \ patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config: - + mock_tool_config.return_value = mock_tool_instance mock_search_tools.return_value = [ @@ -755,7 +781,7 @@ async def test_create_tool_config_list_knowledge_base_multimodal(self): "inputs": "string", "output_type": "string", "params": [ - {"name": "index_names", "default": ["idx1", "idx2"]}, # 添加这个 + {"name": "index_names", "default": ["idx1", "idx2"]}, {"name": "multimodal", "default": True}, {"name": "rerank", "default": False}, ], @@ -773,9 +799,6 @@ async def test_create_tool_config_list_knowledge_base_multimodal(self): assert len(result) == 1 # Verify get_embedding_model_by_index_name was called with tenant_id and first index_name mock_embedding_by_index.assert_called_once_with("tenant_1", "idx1") - - # Verify that multimodal parameter was removed from params (popped) - assert "multimodal" not in result[0].params @pytest.mark.asyncio async def test_create_tool_config_list_with_analyze_image_tool(self): @@ -897,11 +920,16 @@ async def test_create_tool_config_list_with_analyze_text_file_tool(self): @pytest.mark.asyncio async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self): """ - Test that KnowledgeBaseSearchTool metadata contains vdb_core, embedding_model, + Test that KnowledgeBaseSearchTool metadata contains vdb_core, embedding_model, rerank_model, display_name_to_index_map, and index_name_to_display_map. """ mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["idx_a"], + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + } with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ @@ -944,7 +972,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self): # Verify correct functions were called with correct parameters mock_get_vector_db_core.assert_called_once() - # 修改:验证调用时使用 tenant_id 和 index_name + # Verify that call uses tenant_id and first index_name mock_embedding.assert_called_once_with("tenant_1", "idx_a") mock_rerank.assert_called_once_with(tenant_id="tenant_1", model_name="gte-rerank-v2") mock_get_knowledge_map.assert_called_once_with(["idx_a"]) @@ -1230,52 +1258,155 @@ async def test_create_tool_config_list_multiple_tools_same_type(self): assert mock_tool_2.metadata["display_name_to_index_map"] == {} @pytest.mark.asyncio - async def test_create_tool_config_list_with_dify_tool(self): - """Test that DifySearchTool gets correct metadata including rerank model.""" - mock_tool_instance = MagicMock() - mock_tool_instance.class_name = "DifySearchTool" + async def test_create_tool_config_list_applies_request_overrides_for_multiple_tools(self): + """Request tool_params should override DB params for multiple tools in one agent.""" + kb_tool = MagicMock() + kb_tool.class_name = "KnowledgeBaseSearchTool" + kb_tool.params = { + "index_names": ["idx_override"], + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + "top_k": 10, + } + analyze_tool = MagicMock() + analyze_tool.class_name = "AnalyzeTextFileTool" + analyze_tool.params = { + "prompt": "override prompt", + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ - patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank: + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model_by_index_name') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names', return_value={"idx_override": "Override KB"}), \ + patch('backend.agents.create_agent_info.get_llm_model', return_value='llm-model'): + mock_tool_config.side_effect = [kb_tool, analyze_tool] + mock_get_vector_db_core.return_value = 'vdb-core' + mock_embedding.return_value = ('embedding-model', 1, {'status': 'ok'}) + mock_rerank.return_value = 'rerank-model' + mock_search_tools.return_value = [ + { + 'class_name': 'KnowledgeBaseSearchTool', + 'name': 'knowledge_base_search', + 'description': 'kb', + 'inputs': '{}', + 'output_type': 'string', + 'params': [ + {'name': 'index_names', 'default': ['idx_default']}, + {'name': 'rerank', 'default': False}, + {'name': 'rerank_model_name', 'default': ''}, + {'name': 'top_k', 'default': 5}, + ], + 'source': 'local', + 'usage': None, + }, + { + 'class_name': 'AnalyzeTextFileTool', + 'name': 'analyze_text_file', + 'description': 'text', + 'inputs': '{}', + 'output_type': 'string', + 'params': [ + {'name': 'prompt', 'default': 'default prompt'}, + ], + 'source': 'local', + 'usage': None, + }, + ] + + result = await create_tool_config_list( + 'agent_1', + 'tenant_1', + 'user_1', + tool_params={ + 'agents': { + 'test_agent': { + 'tools': { + 'knowledge_base_search': { + 'top_k': 10, + 'rerank': True, + 'rerank_model_name': 'gte-rerank-v2', + 'index_names': ['idx_override'], + }, + 'analyze_text_file': { + 'prompt': 'override prompt', + }, + } + } + } + }, + ) + assert len(result) == 2 + assert kb_tool.params['top_k'] == 10 + assert kb_tool.params['rerank'] is True + assert kb_tool.params['rerank_model_name'] == 'gte-rerank-v2' + assert kb_tool.params['index_names'] == ['idx_override'] + assert analyze_tool.params['prompt'] == 'override prompt' + mock_rerank.assert_called_once_with(tenant_id='tenant_1', model_name='gte-rerank-v2') + mock_embedding.assert_called_once_with('tenant_1', 'idx_override') + + @pytest.mark.asyncio + async def test_create_tool_config_list_with_tool_params(self): + """Test create_tool_config_list with valid tool_params.""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "AnalyzeTextFileTool" + mock_tool_instance.params = { + "prompt": "override prompt", + } + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_llm_model', return_value='llm-model'): mock_tool_config.return_value = mock_tool_instance - mock_rerank.return_value = "mock_rerank_model" mock_search_tools.return_value = [ { - "class_name": "DifySearchTool", - "name": "dify_search", - "description": "Dify knowledge search", - "inputs": "string", - "output_type": "string", - "params": [ - {"name": "rerank", "default": True}, - {"name": "rerank_model_name", "default": "gte-rerank-v2"}, + 'class_name': 'AnalyzeTextFileTool', + 'name': 'analyze_text_file', + 'description': 'text', + 'inputs': '{}', + 'output_type': 'string', + 'params': [ + {'name': 'prompt', 'default': 'default prompt'}, ], - "source": "local", - "usage": None + 'source': 'local', + 'usage': None, } ] - from backend.agents.create_agent_info import create_tool_config_list - result = await create_tool_config_list("agent_1", "tenant_1", "user_1") - - # Verify rerank model was fetched - mock_rerank.assert_called_once_with( - tenant_id="tenant_1", model_name="gte-rerank-v2" + result = await create_tool_config_list( + 'agent_1', + 'tenant_1', + 'user_1', + tool_params={ + 'agents': { + 'test_agent': { + 'tools': { + 'analyze_text_file': { + 'prompt': 'override prompt', + } + } + } + } + }, ) - # Verify metadata assert len(result) == 1 assert result[0] is mock_tool_instance @pytest.mark.asyncio - async def test_create_tool_config_list_with_dify_tool_no_rerank(self): - """Test that DifySearchTool without rerank gets None metadata.""" + async def test_create_tool_config_list_with_dify_tool(self): + """Test that DifySearchTool gets correct metadata including rerank model.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "DifySearchTool" + mock_tool_instance.params = { + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -1283,6 +1414,7 @@ async def test_create_tool_config_list_with_dify_tool_no_rerank(self): patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank: mock_tool_config.return_value = mock_tool_instance + mock_rerank.return_value = "mock_rerank_model" mock_search_tools.return_value = [ { @@ -1292,29 +1424,34 @@ async def test_create_tool_config_list_with_dify_tool_no_rerank(self): "inputs": "string", "output_type": "string", "params": [ - {"name": "rerank", "default": False}, - {"name": "rerank_model_name", "default": ""}, + {"name": "rerank", "default": True}, + {"name": "rerank_model_name", "default": "gte-rerank-v2"}, ], "source": "local", "usage": None } ] - from backend.agents.create_agent_info import create_tool_config_list result = await create_tool_config_list("agent_1", "tenant_1", "user_1") - # Verify rerank model was NOT fetched - mock_rerank.assert_not_called() + # Verify rerank model was fetched + mock_rerank.assert_called_once_with( + tenant_id="tenant_1", model_name="gte-rerank-v2" + ) # Verify metadata assert len(result) == 1 assert result[0] is mock_tool_instance @pytest.mark.asyncio - async def test_create_tool_config_list_with_datamate_tool(self): - """Test that DataMateSearchTool gets correct metadata including rerank model.""" + async def test_create_tool_config_list_with_dify_tool_no_rerank(self): + """Test that DifySearchTool without rerank gets None metadata.""" mock_tool_instance = MagicMock() - mock_tool_instance.class_name = "DataMateSearchTool" + mock_tool_instance.class_name = "DifySearchTool" + mock_tool_instance.params = { + "rerank": False, + "rerank_model_name": "", + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -1322,31 +1459,27 @@ async def test_create_tool_config_list_with_datamate_tool(self): patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank: mock_tool_config.return_value = mock_tool_instance - mock_rerank.return_value = "mock_datamate_rerank_model" mock_search_tools.return_value = [ { - "class_name": "DataMateSearchTool", - "name": "datamate_search", - "description": "DataMate knowledge search", + "class_name": "DifySearchTool", + "name": "dify_search", + "description": "Dify knowledge search", "inputs": "string", "output_type": "string", "params": [ - {"name": "rerank", "default": True}, - {"name": "rerank_model_name", "default": "jina-rerank-v2"}, + {"name": "rerank", "default": False}, + {"name": "rerank_model_name", "default": ""}, ], "source": "local", "usage": None } ] - from backend.agents.create_agent_info import create_tool_config_list result = await create_tool_config_list("agent_1", "tenant_1", "user_1") - # Verify rerank model was fetched - mock_rerank.assert_called_once_with( - tenant_id="tenant_1", model_name="jina-rerank-v2" - ) + # Verify rerank model was NOT fetched + mock_rerank.assert_not_called() # Verify metadata assert len(result) == 1 @@ -1357,6 +1490,10 @@ async def test_create_tool_config_list_with_datamate_tool_no_rerank(self): """Test that DataMateSearchTool without rerank gets None metadata.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "DataMateSearchTool" + mock_tool_instance.params = { + "rerank": False, + "rerank_model_name": "", + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -1381,13 +1518,12 @@ async def test_create_tool_config_list_with_datamate_tool_no_rerank(self): } ] - from backend.agents.create_agent_info import create_tool_config_list result = await create_tool_config_list("agent_1", "tenant_1", "user_1") # Verify rerank model was NOT fetched mock_rerank.assert_not_called() - # Verify metadata + # Verify result assert len(result) == 1 assert result[0] is mock_tool_instance @@ -1610,6 +1746,103 @@ async def test_create_agent_config_with_sub_agents(self): context_components=ANY ) + @pytest.mark.asyncio + async def test_create_agent_config_passes_sub_agent_tool_overrides(self): + """Managed sub-agents should preserve request-scoped tool overrides.""" + with patch('backend.agents.create_agent_info.search_agent_info_by_agent_id') as mock_search_agent, \ + patch('backend.agents.create_agent_info.query_sub_agents_id_list') as mock_query_sub, \ + patch('backend.agents.create_agent_info.query_current_version_no', return_value=2), \ + patch('backend.agents.create_agent_info.create_tool_config_list') as mock_create_tools, \ + patch('backend.agents.create_agent_info.get_agent_prompt_template') as mock_get_template, \ + patch('backend.agents.create_agent_info.tenant_config_manager') as mock_tenant_config, \ + patch('backend.agents.create_agent_info.build_memory_context') as mock_build_memory, \ + patch('backend.agents.create_agent_info.AgentConfig') as mock_agent_config, \ + patch('backend.agents.create_agent_info.prepare_prompt_templates') as mock_prepare_templates, \ + patch('backend.agents.create_agent_info.get_model_by_model_id') as mock_get_model_by_id, \ + patch('backend.agents.create_agent_info._get_external_a2a_agents', return_value=[]), \ + patch('backend.agents.create_agent_info._get_skills_for_template', return_value=[]), \ + patch('backend.agents.create_agent_info._get_skill_script_tools', return_value=[]): + + def _search_agent(agent_id, tenant_id=None, version_no=0): + if agent_id == 'agent_1': + return { + 'name': 'main_agent', + 'description': 'main description', + 'duty_prompt': 'main duty', + 'constraint_prompt': 'main constraint', + 'few_shots_prompt': 'main few shots', + 'max_steps': 5, + 'model_id': 123, + 'provide_run_summary': True, + } + return { + 'name': 'sub_agent_name', + 'description': 'sub description', + 'duty_prompt': 'sub duty', + 'constraint_prompt': 'sub constraint', + 'few_shots_prompt': 'sub few shots', + 'max_steps': 3, + 'model_id': 123, + 'provide_run_summary': False, + } + + def _query_sub_agents(main_agent_id, tenant_id=None, version_no=0): + if main_agent_id == 'agent_1': + return ['sub_agent_1'] + return [] + + mock_search_agent.side_effect = _search_agent + mock_query_sub.side_effect = _query_sub_agents + mock_create_tools.return_value = [] + mock_get_template.return_value = { + 'system_prompt': '{{duty}} {{constraint}} {{few_shots}}' + } + mock_tenant_config.get_app_config.side_effect = [ + 'TestApp', 'Test Description', + 'TestApp', 'Test Description', + ] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id='tenant_1', + user_id='user_1', + agent_id='agent_1' + ) + mock_prepare_templates.return_value = {'system_prompt': 'populated_system_prompt'} + mock_get_model_by_id.return_value = {'display_name': 'test_model'} + mock_agent_config.side_effect = [Mock(name='sub_cfg'), Mock(name='main_cfg')] + + tool_params = { + 'agents': { + 'main_agent': { + 'tools': { + 'knowledge_base_search': {'top_k': 10}, + } + }, + 'sub_agent_name': { + 'tools': { + 'analyze_text_file': {'prompt': 'sub override'}, + } + }, + } + } + + await create_agent_config( + 'agent_1', + 'tenant_1', + 'user_1', + 'zh', + 'test query', + tool_params=tool_params, + ) + + assert mock_create_tools.call_count == 2 + sub_call = mock_create_tools.call_args_list[0] + main_call = mock_create_tools.call_args_list[1] + # Call args are positional: (agent_id, tenant_id, user_id, ...) + assert sub_call.args[0] == 'sub_agent_1' + assert main_call.args[0] == 'agent_1' + @pytest.mark.asyncio async def test_create_agent_config_with_memory(self): """Test case for creating agent configuration with memory""" @@ -2952,6 +3185,7 @@ async def test_create_agent_run_info_success(self): last_user_query="processed_query", allow_memory_search=True, version_no=1, + tool_params=None, ) mock_get_mcp.assert_called_once_with(tenant_id="tenant_1", is_need_auth=True) mock_filter.assert_called_once_with("agent_config", { @@ -3488,6 +3722,7 @@ async def test_create_agent_run_info_forwards_allow_memory_false(self): last_user_query="processed_query", allow_memory_search=False, version_no=1, + tool_params=None, ) @pytest.mark.asyncio @@ -3534,6 +3769,7 @@ async def test_create_agent_run_info_is_debug_true(self): last_user_query="processed_query", allow_memory_search=True, version_no=0, # Debug mode uses draft version 0 + tool_params=None, ) @pytest.mark.asyncio @@ -3586,6 +3822,7 @@ async def test_create_agent_run_info_no_published_version_fallback(self): last_user_query="processed_query", allow_memory_search=True, version_no=0, # Fallback to draft version 0 + tool_params=None, ) # Verify that get_remote_mcp_server_list was called with is_need_auth=True mock_get_mcp.assert_called_once_with(tenant_id="tenant_1", is_need_auth=True) @@ -4250,6 +4487,10 @@ async def test_knowledge_base_with_display_name_to_index_map(self): """Test that KnowledgeBaseSearchTool gets correct display_name_to_index_map from index_names""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["idx1", "idx2"], + "rerank": False, + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -4511,11 +4752,16 @@ async def test_knowledge_base_empty_index_names_raises_validation_error(self): """Test that ValidationError is raised when index_names is empty for KnowledgeBaseSearchTool.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": [], + "rerank": False, + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model_by_index_name') as mock_get_emb, \ patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: @@ -4538,6 +4784,7 @@ async def test_knowledge_base_empty_index_names_raises_validation_error(self): } ] mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_get_emb.return_value = None # Will trigger ValidationError mock_rerank.return_value = None mock_get_knowledge_map.return_value = {} @@ -4546,13 +4793,17 @@ async def test_knowledge_base_empty_index_names_raises_validation_error(self): await create_tool_config_list("agent_1", "tenant_1", "user_1") # Verify error message - assert "Embedding model is required for knowledge_base_search but index_names is empty" in str(exc_info.value) + assert "index_names" in str(exc_info.value) and "not configured" in str(exc_info.value) @pytest.mark.asyncio async def test_knowledge_base_no_embedding_model_raises_validation_error(self): """Test that ValidationError is raised when get_embedding_model_by_index_name returns None.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["idx1"], + "rerank": False, + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -4590,8 +4841,7 @@ async def test_knowledge_base_no_embedding_model_raises_validation_error(self): with pytest.raises(ValidationError) as exc_info: await create_tool_config_list("agent_1", "tenant_1", "user_1") - # Verify error message contains index name and guidance - assert "No embedding model found for index 'idx1'" in str(exc_info.value) + # Verify error message contains guidance about configuring embedding model assert "Please configure an embedding model for this knowledge base" in str(exc_info.value) @pytest.mark.asyncio @@ -4599,6 +4849,11 @@ async def test_knowledge_base_with_valid_embedding_model(self): """Test that KnowledgeBaseSearchTool correctly sets embedding_model when get_embedding_model_by_index_name succeeds.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["idx1", "idx2"], + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -4642,19 +4897,19 @@ async def test_knowledge_base_with_valid_embedding_model(self): # Verify the tool was created successfully assert len(result) == 1 - + # Verify get_embedding_model_by_index_name was called with correct parameters mock_get_emb_by_index.assert_called_once_with("tenant_1", "idx1") - + # Verify metadata contains the embedding_model assert result[0].metadata["embedding_model"] == mock_embedding_model - + # Verify metadata also contains other expected fields assert "vdb_core" in result[0].metadata assert "rerank_model" in result[0].metadata assert "display_name_to_index_map" in result[0].metadata assert "index_name_to_display_map" in result[0].metadata - + # Verify mappings are correct assert result[0].metadata["display_name_to_index_map"] == { "Knowledge Base 1": "idx1", @@ -4670,6 +4925,10 @@ async def test_knowledge_base_with_single_index_and_embedding_model(self): """Test KnowledgeBaseSearchTool with single index_name and valid embedding model.""" mock_tool_instance = MagicMock() mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + mock_tool_instance.params = { + "index_names": ["single_index"], + "rerank": False, + } with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ @@ -4710,13 +4969,13 @@ async def test_knowledge_base_with_single_index_and_embedding_model(self): # Verify the tool was created successfully assert len(result) == 1 - + # Verify get_embedding_model_by_index_name was called mock_get_emb_by_index.assert_called_once_with("tenant_1", "single_index") - + # Verify embedding_model is set correctly assert result[0].metadata["embedding_model"] == mock_embedding_model - + # Verify mappings for single index assert result[0].metadata["display_name_to_index_map"] == { "My Knowledge Base": "single_index" @@ -4759,12 +5018,12 @@ async def test_knowledge_base_embedding_model_error_metadata(self): mock_get_vector_db_core.return_value = "vdb_core" mock_rerank.return_value = None mock_get_knowledge_map.return_value = {"test_idx": "Test KB"} - + # Return valid embedding model with error metadata mock_embedding_model = MagicMock() mock_get_emb_by_index.return_value = ( - mock_embedding_model, - 789, + mock_embedding_model, + 789, {"status": "error", "message": "Some error but model exists"} ) @@ -5189,3 +5448,157 @@ def test_convert_history_with_minio_files_all_items_have_minio_files(self): if __name__ == "__main__": pytest.main([__file__]) + + +# ============================================================================ +# Additional tests for improved coverage +# ============================================================================ + + +class TestNormalizeToolParamsRequest: + """Tests for _normalize_tool_params_request function.""" + + def test_normalize_with_none(self): + """Test that None returns empty ToolParamsRequest.""" + result = _normalize_tool_params_request(None) + assert isinstance(result, ToolParamsRequest) + assert result.agents == {} + + def test_normalize_with_tool_params_request(self): + """Test that ToolParamsRequest is returned as-is.""" + req = ToolParamsRequest(agents={"agent1": MockAgentToolParamsRequest(tools={"tool1": {"param1": "value1"}})}) + result = _normalize_tool_params_request(req) + assert result is req + + def test_normalize_with_valid_dict(self): + """Test that valid dict is validated into ToolParamsRequest.""" + input_dict = {"agents": {"agent1": {"tools": {"tool1": {"param1": "value1"}}}}} + result = _normalize_tool_params_request(input_dict) + assert isinstance(result, ToolParamsRequest) + assert "agent1" in result.agents + + def test_normalize_with_invalid_type_raises_validation_error(self): + """Test that non-dict, non-ToolParamsRequest raises ValidationError.""" + with pytest.raises(ValidationError, match="tool_params must be an object"): + _normalize_tool_params_request("invalid_string") + + def test_normalize_with_invalid_dict_returns_empty(self): + """Test that invalid dict returns empty ToolParamsRequest (mock behavior).""" + # The mock ToolParamsRequest doesn't validate, so it just returns empty + result = _normalize_tool_params_request({"invalid_key": 123}) + assert isinstance(result, ToolParamsRequest) + + +class TestGetAgentToolOverrides: + """Tests for _get_agent_tool_overrides function.""" + + def test_get_overrides_with_none_tool_params(self): + """Test that None tool_params returns empty dict.""" + result = _get_agent_tool_overrides(None, "agent1") + assert result == {} + + def test_get_overrides_with_none_agent_name(self): + """Test that None agent_name returns empty dict.""" + tool_params = ToolParamsRequest(agents={"agent1": MockAgentToolParamsRequest(tools={"tool1": {"param1": "value1"}})}) + result = _get_agent_tool_overrides(tool_params, None) + assert result == {} + + def test_get_overrides_with_empty_agent_name(self): + """Test that empty agent_name returns empty dict.""" + tool_params = ToolParamsRequest(agents={"agent1": MockAgentToolParamsRequest(tools={"tool1": {"param1": "value1"}})}) + result = _get_agent_tool_overrides(tool_params, "") + assert result == {} + + def test_get_overrides_with_unknown_agent(self): + """Test that unknown agent returns empty dict.""" + tool_params = ToolParamsRequest(agents={"agent1": MockAgentToolParamsRequest(tools={"tool1": {"param1": "value1"}})}) + result = _get_agent_tool_overrides(tool_params, "unknown_agent") + assert result == {} + + def test_get_overrides_with_existing_agent(self): + """Test that existing agent returns its tool overrides.""" + tool_params = ToolParamsRequest(agents={"agent1": MockAgentToolParamsRequest(tools={"tool1": {"param1": "value1"}, "tool2": {"param2": "value2"}})}) + result = _get_agent_tool_overrides(tool_params, "agent1") + assert result == {"tool1": {"param1": "value1"}, "tool2": {"param2": "value2"}} + + +class TestBuildInternalS3Url: + """Tests for _build_internal_s3_url function.""" + + def test_build_with_non_dict(self): + """Test that non-dict input returns empty string.""" + assert _build_internal_s3_url("not a dict") == "" + assert _build_internal_s3_url(None) == "" + assert _build_internal_s3_url(123) == "" + + def test_build_with_empty_dict(self): + """Test that empty dict returns empty string.""" + assert _build_internal_s3_url({}) == "" + + def test_build_with_object_name(self): + """Test URL building with object_name.""" + result = _build_internal_s3_url({"object_name": "path/to/file.txt"}) + # Bucket name depends on test environment mock (MINIO_DEFAULT_BUCKET = "test-bucket") + assert result.startswith("s3://") + assert "path/to/file.txt" in result + + def test_build_with_object_name_leading_slash(self): + """Test URL building with leading slash in object_name.""" + result = _build_internal_s3_url({"object_name": "/path/to/file.txt"}) + # Bucket name depends on test environment mock + assert result.startswith("s3://") + assert "path/to/file.txt" in result + + def test_build_with_s3_url_input(self): + """Test that s3:// URL is returned as-is.""" + result = _build_internal_s3_url({"url": "s3://bucket/path/file.txt"}) + assert result == "s3://bucket/path/file.txt" + + def test_build_with_s3_single_slash(self): + """Test URL building with s3:/ prefix.""" + result = _build_internal_s3_url({"url": "s3:/bucket/file.txt"}) + assert result == "s3://bucket/file.txt" + + def test_build_with_blob_url(self): + """Test that blob: URL returns empty string.""" + assert _build_internal_s3_url({"url": "blob:http://example.com/file"}) == "" + + def test_build_with_s3_blob_url(self): + """Test that s3:/blob: URL returns empty string.""" + assert _build_internal_s3_url({"url": "s3:/blob:http://example.com/file"}) == "" + + def test_build_with_http_url(self): + """Test that non-s3 URL returns s3:/ prefixed version.""" + result = _build_internal_s3_url({"url": "https://example.com/file.txt"}) + assert result == "s3:/https://example.com/file.txt" + + +class TestMergeToolParams: + """Tests for _merge_tool_params function.""" + + def test_merge_with_override_params(self): + """Test that override params update merged params.""" + tool_record = {"params": [{"name": "param1", "default": "default1"}, {"name": "param2", "default": "default2"}]} + override_params = {"param1": "override1"} + result = _merge_tool_params(tool_record, override_params) + assert result == {"param1": "override1", "param2": "default2"} + + def test_merge_with_extra_params(self): + """Test that extra params take precedence.""" + tool_record = {"params": [{"name": "param1", "default": "default1"}]} + override_params = {"param1": "override1"} + extra_params = {"param1": "extra1", "internal_param": "secret"} + result = _merge_tool_params(tool_record, override_params, extra_params) + assert result == {"param1": "extra1", "internal_param": "secret"} + + def test_merge_with_no_params_in_tool_record(self): + """Test merge when tool_record has no params.""" + tool_record = {} + result = _merge_tool_params(tool_record, {"override": "value"}) + assert result == {"override": "value"} + + def test_merge_with_empty_override_params(self): + """Test merge with empty override params.""" + tool_record = {"params": [{"name": "param1", "default": "default1"}]} + result = _merge_tool_params(tool_record, {}) + assert result == {"param1": "default1"} diff --git a/test/backend/app/test_northbound_app.py b/test/backend/app/test_northbound_app.py index 2bfb25a76..827e04e4d 100644 --- a/test/backend/app/test_northbound_app.py +++ b/test/backend/app/test_northbound_app.py @@ -1,53 +1,22 @@ -import os +"""Unit tests for backend.apps.northbound_app module.""" import sys -from unittest.mock import MagicMock, AsyncMock +import os + +# The conftest.py sets up all mocks + +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from fastapi import FastAPI, HTTPException -from fastapi.responses import StreamingResponse +from fastapi import FastAPI from fastapi.testclient import TestClient -import types -import sys as _sys - -# Dynamically determine the backend path -current_dir = os.path.dirname(os.path.abspath(__file__)) -backend_dir = os.path.abspath(os.path.join(current_dir, "../../../backend")) -sys.path.append(backend_dir) - - -# Pre-mock heavy dependencies before importing router -sys.modules['consts'] = MagicMock() -sys.modules['consts.model'] = MagicMock() - -consts_exceptions_mod = types.ModuleType("consts.exceptions") - -class LimitExceededError(Exception): - pass -class UnauthorizedError(Exception): - pass -class SignatureValidationError(Exception): - pass - -consts_exceptions_mod.LimitExceededError = LimitExceededError -consts_exceptions_mod.UnauthorizedError = UnauthorizedError -consts_exceptions_mod.SignatureValidationError = SignatureValidationError - -# Ensure the parent 'consts' is a module -if 'consts' not in _sys.modules or not isinstance(_sys.modules['consts'], types.ModuleType): - consts_root = types.ModuleType("consts") - consts_root.__path__ = [] - _sys.modules['consts'] = consts_root -else: - consts_root = _sys.modules['consts'] - -consts_root.exceptions = consts_exceptions_mod -_sys.modules['consts.exceptions'] = consts_exceptions_mod -sys.modules['services'] = MagicMock() -sys.modules['services.northbound_service'] = MagicMock() -sys.modules['utils'] = MagicMock() -sys.modules['utils.auth_utils'] = MagicMock() - -# Import router after setting mocks +from io import BytesIO + +# Import from conftest (which sets up mocks automatically) from apps.northbound_app import router +from consts.exceptions import ( + LimitExceededError, + UnauthorizedError, + SignatureValidationError, +) app = FastAPI() @@ -56,6 +25,7 @@ class SignatureValidationError(Exception): def _build_headers(auth="Bearer test_jwt", request_id="req-123", aksk=True): + """Build request headers for testing.""" headers = { "Authorization": auth, "X-Request-Id": request_id, @@ -69,8 +39,12 @@ def _build_headers(auth="Bearer test_jwt", request_id="req-123", aksk=True): return headers -@pytest.mark.asyncio -async def test_health_check(): +# ============================================================================= +# Health Check Tests +# ============================================================================= + +def test_health_check(): + """Test health check endpoint returns healthy status.""" resp = client.get("/nb/v1/health") assert resp.status_code == 200 data = resp.json() @@ -78,544 +52,783 @@ async def test_health_check(): assert data["service"] == "northbound-api" -def test_run_chat_calls_service(monkeypatch): - # Mock Bearer token validation to return valid token - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - # Mock user/tenant lookup to return user and tenant - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - async def _gen(): - yield b"data: hello\n\n" - start_mock = AsyncMock(return_value=StreamingResponse(_gen(), media_type="text/event-stream")) - monkeypatch.setattr("apps.northbound_app.start_streaming_chat", start_mock) +# ============================================================================= +# Upload Chat Attachments Tests +# ============================================================================= - # Use integer conversation_id as the endpoint expects Optional[int] - payload = {"conversation_id": 1, "agent_name": "agent-a", "query": "hi"} - headers = {**_build_headers(), "Idempotency-Key": "idem-1"} - resp = client.post("/nb/v1/chat/run", json=payload, headers=headers) +def test_upload_chat_attachments_success(): + """Test successful chat attachment upload.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.upload_files_for_northbound', new_callable=AsyncMock) as mock_upload: - assert resp.status_code == 200 - assert "text/event-stream" in resp.headers["content-type"] - # Validate call into service - assert start_mock.await_count == 1 - args, kwargs = start_mock.call_args - assert kwargs["conversation_id"] == 1 - assert kwargs["agent_name"] == "agent-a" - assert kwargs["query"] == "hi" - assert kwargs["idempotency_key"] == "idem-1" - - -def test_stop_chat_calls_service(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - stop_mock = AsyncMock(return_value={"message": "success"}) - monkeypatch.setattr("apps.northbound_app.stop_chat", stop_mock) - - # Use integer conversation_id in URL path - resp = client.get("/nb/v1/chat/stop/123", headers=_build_headers()) - assert resp.status_code == 200 - assert stop_mock.await_count == 1 + mock_ctx.return_value = MagicMock() + mock_upload.return_value = { + "message": "Processed 1 files", + "requestId": "req-123", + "results": [{"filename": "test.pdf", "status": "success"}], + } + # Create a fake file upload + file_content = b"test file content" + files = {"files": ("test.pdf", BytesIO(file_content), "application/pdf")} -def test_get_history_calls_service(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - hist_mock = AsyncMock(return_value={"message": "success"}) - monkeypatch.setattr("apps.northbound_app.get_conversation_history", hist_mock) + resp = client.post( + "/nb/v1/chat/attachments/upload", + files=files, + headers=_build_headers(), + ) - # Use integer conversation_id in URL path - resp = client.get("/nb/v1/conversations/123", headers=_build_headers()) - assert resp.status_code == 200 - assert hist_mock.await_count == 1 + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Processed 1 files" -def test_list_agents_calls_service(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - agents_mock = AsyncMock(return_value={"message": "success", "data": []}) - monkeypatch.setattr("apps.northbound_app.get_agent_info_list", agents_mock) +def test_upload_chat_attachments_limit_exceeded(): + """Test upload returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.upload_files_for_northbound', new_callable=AsyncMock) as mock_upload: - resp = client.get("/nb/v1/agents", headers=_build_headers()) - assert resp.status_code == 200 - assert agents_mock.await_count == 1 + mock_ctx.return_value = MagicMock() + mock_upload.side_effect = LimitExceededError("Upload limit exceeded") + file_content = b"test file content" + files = {"files": ("test.pdf", BytesIO(file_content), "application/pdf")} -def test_list_conversations_calls_service(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - list_mock = AsyncMock(return_value={"message": "success", "data": []}) - monkeypatch.setattr("apps.northbound_app.list_conversations", list_mock) + resp = client.post( + "/nb/v1/chat/attachments/upload", + files=files, + headers=_build_headers(), + ) - resp = client.get("/nb/v1/conversations", headers=_build_headers()) - assert resp.status_code == 200 - assert list_mock.await_count == 1 - - -def test_update_title_sets_headers(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - # Ensure NorthboundContext yields plain string fields (avoid MagicMock in headers) - class _NCtx: - def __init__(self, request_id: str, tenant_id: str, user_id: str, authorization: str, token_id: int = 0): - self.request_id = request_id - self.tenant_id = tenant_id - self.user_id = user_id - self.authorization = authorization - self.token_id = token_id - monkeypatch.setattr("apps.northbound_app.NorthboundContext", _NCtx) - update_mock = AsyncMock(return_value={"message": "success", "data": "nb-4", "idempotency_key": "ide-xyz"}) - monkeypatch.setattr("apps.northbound_app.update_conversation_title", update_mock) - - headers = {**_build_headers(request_id="req-999"), "Idempotency-Key": "ide-xyz"} - resp = client.put("/nb/v1/conversations/123/title", params={"title": "New Title"}, headers=headers) - assert resp.status_code == 200 - # Router wraps JSONResponse and should echo idempotency and request id - assert resp.headers.get("Idempotency-Key") == "ide-xyz" - assert resp.headers.get("X-Request-Id") == "req-999" - assert update_mock.await_count == 1 + assert resp.status_code == 429 -def _std_headers(auth="Bearer test_jwt"): - return { - **_build_headers(auth=auth), - "Idempotency-Key": "idem-xyz", - } +def test_upload_chat_attachments_internal_error(): + """Test upload returns 500 when internal error occurs.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.upload_files_for_northbound', new_callable=AsyncMock) as mock_upload: + mock_ctx.return_value = MagicMock() + mock_upload.side_effect = Exception("Unknown error") -@pytest.mark.parametrize("exc_cls, status", [ - (UnauthorizedError, 401), - (LimitExceededError, 429), - (SignatureValidationError, 401), -]) -def test_run_chat_auth_exceptions_are_mapped(monkeypatch, exc_cls, status): - # Force Bearer token validation to raise domain exceptions - def _raise(*_, **__): - raise exc_cls("boom") - - monkeypatch.setattr( - "apps.northbound_app.validate_bearer_token", _raise) - # Even if provided, auth should not be parsed because token validation fails first - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - assert resp.status_code == status - - -def test_run_chat_missing_authorization_header_returns_401(monkeypatch): - # When no Authorization header, validate_bearer_token returns (False, None) - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (False, None)) - # No Authorization header - headers = {k: v for k, v in _std_headers().items() if k.lower() - != "authorization"} - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=headers, - ) - assert resp.status_code == 401 - assert "bearer token" in resp.json()["detail"].lower() + file_content = b"test file content" + files = {"files": ("test.pdf", BytesIO(file_content), "application/pdf")} + resp = client.post( + "/nb/v1/chat/attachments/upload", + files=files, + headers=_build_headers(), + ) -def test_run_chat_jwt_parse_exception_returns_401(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) + assert resp.status_code == 500 - def _raise_user_lookup(_access_key): - raise Exception("user lookup error") - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", _raise_user_lookup) - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - # When user lookup fails due to an invalid API key, return 401 - assert resp.status_code == 401 - assert "invalid api key" in resp.json()["detail"].lower() +# ============================================================================= +# Run Chat Tests +# ============================================================================= +def test_run_chat_success(): + """Test successful chat run initiation.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: -def test_run_chat_jwt_missing_user_id_returns_400(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr( - "apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": None, "tenant_id": "t1", "token_id": "t1" - }) + mock_ctx.return_value = MagicMock() + mock_run.return_value = { + "message": "Chat run initiated", + "request_id": "req-789", + "status": "initiated", + } - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - assert resp.status_code == 400 - assert "user" in resp.json()["detail"].lower() + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello, agent", + }, + headers=_build_headers(), + ) + assert resp.status_code == 200 -def test_run_chat_jwt_missing_tenant_id_returns_400(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr( - "apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": None, "token_id": "t1" - }) - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - assert resp.status_code == 400 - assert "tenant" in resp.json()["detail"].lower() +def test_run_chat_limit_exceeded(): + """Test run chat returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.side_effect = LimitExceededError("Rate limit exceeded") + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello", + }, + headers=_build_headers(), + ) + assert resp.status_code == 429 -def test_run_chat_internal_error_when_parsing_context_returns_401(monkeypatch): - def _raise(*_, **__): - raise Exception("unexpected") - monkeypatch.setattr( - "apps.northbound_app.validate_bearer_token", _raise) - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - # Any exception during validation returns 401 - assert resp.status_code == 401 - - -def test_run_chat_unexpected_service_error_maps_500(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - start_mock = AsyncMock(side_effect=Exception("boom")) - monkeypatch.setattr("apps.northbound_app.start_streaming_chat", start_mock) - - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, "agent_name": "a", "query": "hi"}, - headers=_std_headers(), - ) - assert resp.status_code == 500 - - -@pytest.mark.parametrize("path", [ - "/nb/v1/chat/stop/123", - "/nb/v1/conversations/123", - "/nb/v1/agents", - "/nb/v1/conversations", -]) -@pytest.mark.parametrize("exc_cls, status", [ - (UnauthorizedError, 401), - (LimitExceededError, 429), - (SignatureValidationError, 401), -]) -def test_other_endpoints_auth_exceptions_are_mapped(monkeypatch, path, exc_cls, status): - def _raise(*_, **__): - raise exc_cls("boom") - monkeypatch.setattr( - "apps.northbound_app.validate_bearer_token", _raise) - - resp = client.get(path, headers=_build_headers()) - assert resp.status_code == status - - -@pytest.mark.parametrize( - "path, target", - [ - ("/nb/v1/chat/stop/123", "apps.northbound_app.stop_chat"), - ("/nb/v1/conversations/123", "apps.northbound_app.get_conversation_history"), - ("/nb/v1/agents", "apps.northbound_app.get_agent_info_list"), - ("/nb/v1/conversations", "apps.northbound_app.list_conversations"), - ], -) -def test_other_endpoints_unexpected_service_error_maps_500(monkeypatch, path, target): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - monkeypatch.setattr(target, AsyncMock(side_effect=Exception("boom"))) - - resp = client.get(path, headers=_build_headers()) - assert resp.status_code == 500 - - -def test_update_title_unexpected_service_error_maps_500(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - monkeypatch.setattr("apps.northbound_app.update_conversation_title", AsyncMock( - side_effect=Exception("boom"))) - - resp = client.put( - "/nb/v1/conversations/123/title", - params={"title": "x"}, +def test_run_chat_unauthorized(): + """Test run chat returns 500 on unauthorized (broad exception handling).""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx: + mock_ctx.side_effect = UnauthorizedError("Invalid token") + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello", + }, + headers=_build_headers(), + ) + + # The run_chat endpoint has broad exception handling, so unauthorized returns 500 + assert resp.status_code == 500 + + +# ============================================================================= +# Stop Chat Tests +# ============================================================================= + +def test_stop_chat_success(): + """Test successful chat stop.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.stop_chat', new_callable=AsyncMock) as mock_stop: + + mock_ctx.return_value = MagicMock() + mock_stop.return_value = True + + resp = client.get( + "/nb/v1/chat/stop/123", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +# ============================================================================= +# Get Conversation Tests +# ============================================================================= + +def test_get_conversation_success(): + """Test successful retrieval of conversation.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_conversation_history', new_callable=AsyncMock) as mock_get: + + mock_ctx.return_value = MagicMock() + mock_get.return_value = { + "conversation_id": 123, + "history": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + + resp = client.get( + "/nb/v1/conversations/123", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["conversation_id"] == 123 + assert len(data["history"]) == 2 + + +# ============================================================================= +# List Agents Tests +# ============================================================================= + +def test_list_agents_success(): + """Test successful retrieval of agent list.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_agent_info_list', new_callable=AsyncMock) as mock_get: + + mock_ctx.return_value = MagicMock() + mock_get.return_value = { + "agents": [ + {"name": "agent1", "description": "First agent"}, + {"name": "agent2", "description": "Second agent"}, + ] + } + + resp = client.get( + "/nb/v1/agents", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data["agents"]) == 2 + + +# ============================================================================= +# List Conversations Tests +# ============================================================================= + +def test_list_conversations_success(): + """Test successful retrieval of conversation list.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.list_conversations', new_callable=AsyncMock) as mock_list: + + mock_ctx.return_value = MagicMock() + mock_list.return_value = { + "conversations": [ + {"id": 1, "title": "Conversation 1"}, + {"id": 2, "title": "Conversation 2"}, + ] + } + + resp = client.get( + "/nb/v1/conversations", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + data = resp.json() + assert len(data["conversations"]) == 2 + + +# ============================================================================= +# Update Conversation Title Tests +# ============================================================================= + +def test_update_conversation_title_success(): + """Test successful update of conversation title.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.return_value = {"idempotency_key": "idem-key", "conversation_id": 123, "title": "New Title"} + + resp = client.put( + "/nb/v1/conversations/123/title?title=New%20Title", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +# ============================================================================= +# File Fetch Tests +# ============================================================================= + +def test_file_fetch_missing_url(): + """Test file fetch returns 422 when URL is missing.""" + resp = client.get( + "/nb/v1/file/fetch", headers=_build_headers(), ) - assert resp.status_code == 500 - - -def test_run_chat_sets_headers_from_service_response(monkeypatch): - # Mock Bearer token and user lookup - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - - # Ensure NorthboundContext yields plain string fields (avoid MagicMock in headers) - class _NCtx: - def __init__(self, request_id: str, tenant_id: str, user_id: str, authorization: str, token_id: int = 0): - self.request_id = request_id - self.tenant_id = tenant_id - self.user_id = user_id - self.authorization = authorization - self.token_id = token_id - - monkeypatch.setattr("apps.northbound_app.NorthboundContext", _NCtx) - - async def _gen(): - yield b"data: ok\n\n" - - async def _start(ctx, conversation_id, agent_name, query, meta_data=None, idempotency_key=None): - resp = StreamingResponse(_gen(), media_type="text/event-stream") - # Service attaches headers in latest logic; emulate here - resp.headers["X-Request-Id"] = ctx.request_id - resp.headers["conversation_id"] = str(conversation_id) - return resp - - monkeypatch.setattr("apps.northbound_app.start_streaming_chat", _start) - - headers = {**_std_headers(), "X-Request-Id": "rid-123"} - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, - "agent_name": "agent-a", "query": "hello"}, - headers=headers, - ) - assert resp.status_code == 200 - assert resp.headers.get("X-Request-Id") == "rid-123" - assert resp.headers.get("conversation_id") == "1" + # Missing required parameter returns 422 + assert resp.status_code == 422 -def test_run_chat_service_error_maps_500(monkeypatch): - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) +# ============================================================================= +# Error Handling Tests +# ============================================================================= + +def test_invalid_request_body(): + """Test that invalid request body returns 422.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx: + mock_ctx.return_value = MagicMock() + + resp = client.post( + "/nb/v1/chat/run", + json={}, # Missing required fields + headers=_build_headers(), + ) + + # FastAPI returns 422 for validation errors + assert resp.status_code == 422 + + +def test_run_chat_with_conversation_id(): + """Test run chat with existing conversation ID.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.return_value = { + "message": "Chat run continued", + "request_id": "req-456", + "status": "continued", + } + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello again", + "conversation_id": 123, + }, + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +def test_run_chat_with_attachments(): + """Test run chat with file attachments.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.return_value = { + "message": "Chat run with attachments", + "request_id": "req-789", + "status": "initiated", + } + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Summarize the attached report", + "attachments": ["s3://nexent/attachments/file.pdf"], + }, + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +def test_run_chat_with_tool_params(): + """Test run chat with tool parameter overrides.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.return_value = { + "message": "Chat run with tool params", + "request_id": "req-101", + "status": "initiated", + } + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Search the knowledge base", + "tool_params": { + "agents": { + "general-assistant": { + "tools": { + "knowledge_base_search": { + "top_k": 5, + } + } + } + } + }, + }, + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +def test_run_chat_permission_error(): + """Test run chat returns 403 when permission denied.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.side_effect = PermissionError("Access denied") + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello", + }, + headers=_build_headers(), + ) + + assert resp.status_code == 403 + + +def test_run_chat_internal_error(): + """Test run chat returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.side_effect = Exception("Unexpected error") + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello", + }, + headers=_build_headers(), + ) - async def _raise(*args, **kwargs): - raise Exception("Failed to persist user message: boom") + assert resp.status_code == 500 - monkeypatch.setattr("apps.northbound_app.start_streaming_chat", _raise) - resp = client.post( - "/nb/v1/chat/run", - json={"conversation_id": 1, - "agent_name": "agent-a", "query": "hello"}, - headers=_std_headers(), - ) +def test_run_chat_value_error(): + """Test run chat returns 400 on value error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.start_streaming_chat', new_callable=AsyncMock) as mock_run: + + mock_ctx.return_value = MagicMock() + mock_run.side_effect = ValueError("Invalid agent name") + + resp = client.post( + "/nb/v1/chat/run", + json={ + "agent_name": "general-assistant", + "query": "Hello", + }, + headers=_build_headers(), + ) - assert resp.status_code == 500 + assert resp.status_code == 400 -# --- Tests for /file/fetch endpoint --- +# ============================================================================= +# Stop Chat Error Tests +# ============================================================================= -def test_fetch_file_missing_presigned_url(): - """Missing presigned_url parameter returns 422 (FastAPI validation).""" - resp = client.get("/nb/v1/file/fetch") - assert resp.status_code == 422 +def test_stop_chat_limit_exceeded(): + """Test stop chat returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.stop_chat', new_callable=AsyncMock) as mock_stop: + mock_ctx.return_value = MagicMock() + mock_stop.side_effect = LimitExceededError("Rate limit exceeded") -def test_fetch_file_invalid_url_scheme(monkeypatch): - """URL scheme other than http/https returns 400.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) + resp = client.get( + "/nb/v1/chat/stop/123", + headers=_build_headers(), + ) - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "ftp://example.com/file"}, - headers=_build_headers(), - ) - assert resp.status_code == 400 - assert "Invalid URL scheme" in resp.json()["detail"] - - -def test_fetch_file_success(monkeypatch): - """Valid presigned_url: proxies file content as StreamingResponse.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - - import httpx - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = { - "Content-Type": "application/pdf", - "Content-Disposition": 'attachment; filename="report.pdf"', - } - mock_response.aiter_bytes = MagicMock(return_value=iter([b"PDF content here"])) + assert resp.status_code == 429 - mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) +def test_stop_chat_internal_error(): + """Test stop chat returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.stop_chat', new_callable=AsyncMock) as mock_stop: - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, - headers=_build_headers(), - ) + mock_ctx.return_value = MagicMock() + mock_stop.side_effect = Exception("Unexpected error") - assert resp.status_code == 200 - assert resp.headers["content-type"] == "application/pdf" - assert "report.pdf" in resp.headers["content-disposition"] + resp = client.get( + "/nb/v1/chat/stop/123", + headers=_build_headers(), + ) + assert resp.status_code == 500 -def test_fetch_file_non_200_returns_502(monkeypatch): - """MinIO returns non-200: maps to 502 Bad Gateway.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) - mock_response = MagicMock() - mock_response.status_code = 403 - mock_response.headers = {} +# ============================================================================= +# Get Conversation Error Tests +# ============================================================================= - mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) +def test_get_conversation_limit_exceeded(): + """Test get conversation returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_conversation_history', new_callable=AsyncMock) as mock_get: - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + mock_ctx.return_value = MagicMock() + mock_get.side_effect = LimitExceededError("Rate limit exceeded") - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, - headers=_build_headers(), - ) + resp = client.get( + "/nb/v1/conversations/123", + headers=_build_headers(), + ) - assert resp.status_code == 502 - assert "Failed to fetch file from storage" in resp.json()["detail"] + assert resp.status_code == 429 -def test_fetch_file_timeout_returns_504(monkeypatch): - """httpx.TimeoutException: maps to 504 Gateway Timeout.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) +def test_get_conversation_internal_error(): + """Test get conversation returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_conversation_history', new_callable=AsyncMock) as mock_get: - import httpx - mock_client = MagicMock() - mock_client.get = AsyncMock(side_effect=httpx.TimeoutException("Connection timed out")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) + mock_ctx.return_value = MagicMock() + mock_get.side_effect = Exception("Unexpected error") - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + resp = client.get( + "/nb/v1/conversations/123", + headers=_build_headers(), + ) - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, - headers=_build_headers(), - ) + assert resp.status_code == 500 - assert resp.status_code == 504 - assert "Timeout" in resp.json()["detail"] +# ============================================================================= +# List Agents Error Tests +# ============================================================================= -def test_fetch_file_request_error_returns_502(monkeypatch): - """httpx.RequestError: maps to 502 Bad Gateway.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) +def test_list_agents_limit_exceeded(): + """Test list agents returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_agent_info_list', new_callable=AsyncMock) as mock_get: - import httpx - mock_client = MagicMock() - mock_client.get = AsyncMock(side_effect=httpx.RequestError("Connection refused", request=MagicMock())) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) + mock_ctx.return_value = MagicMock() + mock_get.side_effect = LimitExceededError("Rate limit exceeded") - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + resp = client.get( + "/nb/v1/agents", + headers=_build_headers(), + ) - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, - headers=_build_headers(), - ) + assert resp.status_code == 429 - assert resp.status_code == 502 - assert "Failed to fetch file from storage" in resp.json()["detail"] +def test_list_agents_internal_error(): + """Test list agents returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.get_agent_info_list', new_callable=AsyncMock) as mock_get: -def test_fetch_file_unexpected_error_returns_500(monkeypatch): - """Unexpected exception: maps to 500 Internal Server Error.""" - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", lambda auth: (True, {"token_id": "t1"})) - monkeypatch.setattr("apps.northbound_app.get_user_and_tenant_by_access_key", lambda access_key: { - "user_id": "u1", "tenant_id": "t1", "token_id": "t1" - }) + mock_ctx.return_value = MagicMock() + mock_get.side_effect = Exception("Unexpected error") - mock_client = MagicMock() - mock_client.get = AsyncMock(side_effect=RuntimeError("unexpected failure")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) + resp = client.get( + "/nb/v1/agents", + headers=_build_headers(), + ) - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) + assert resp.status_code == 500 - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, - headers=_build_headers(), + +# ============================================================================= +# List Conversations Error Tests +# ============================================================================= + +def test_list_conversations_limit_exceeded(): + """Test list conversations returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.list_conversations', new_callable=AsyncMock) as mock_list: + + mock_ctx.return_value = MagicMock() + mock_list.side_effect = LimitExceededError("Rate limit exceeded") + + resp = client.get( + "/nb/v1/conversations", + headers=_build_headers(), + ) + + assert resp.status_code == 429 + + +def test_list_conversations_internal_error(): + """Test list conversations returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.list_conversations', new_callable=AsyncMock) as mock_list: + + mock_ctx.return_value = MagicMock() + mock_list.side_effect = Exception("Unexpected error") + + resp = client.get( + "/nb/v1/conversations", + headers=_build_headers(), + ) + + assert resp.status_code == 500 + + +# ============================================================================= +# Update Conversation Title Error Tests +# ============================================================================= + +def test_update_conversation_title_limit_exceeded(): + """Test update conversation title returns 429 when limit exceeded.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.side_effect = LimitExceededError("Rate limit exceeded") + + resp = client.put( + "/nb/v1/conversations/123/title?title=New%20Title", + headers=_build_headers(), + ) + + assert resp.status_code == 429 + + +def test_update_conversation_title_not_found(): + """Test update conversation title returns 404 when conversation not found.""" + from consts.exceptions import ConversationNotFoundError + + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.side_effect = ConversationNotFoundError("Conversation not found") + + resp = client.put( + "/nb/v1/conversations/999/title?title=New%20Title", + headers=_build_headers(), + ) + + assert resp.status_code == 404 + + +def test_update_conversation_title_internal_error(): + """Test update conversation title returns 500 on internal error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.side_effect = Exception("Unexpected error") + + resp = client.put( + "/nb/v1/conversations/123/title?title=New%20Title", + headers=_build_headers(), + ) + + assert resp.status_code == 500 + + +def test_update_conversation_title_with_meta_data(): + """Test update conversation title with metadata.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.return_value = {"idempotency_key": "idem-key", "conversation_id": 123} + + resp = client.put( + "/nb/v1/conversations/123/title?title=New%20Title&meta_data=%7B%22source%22%3A%22test%22%7D", + headers=_build_headers(), + ) + + assert resp.status_code == 200 + + +def test_update_conversation_title_with_idempotency_key(): + """Test update conversation title with idempotency key.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.update_conversation_title', new_callable=AsyncMock) as mock_update: + + mock_ctx.return_value = MagicMock() + mock_ctx.return_value.request_id = "req-123" + mock_update.return_value = {"idempotency_key": "my-key", "conversation_id": 123} + + resp = client.put( + "/nb/v1/conversations/123/title?title=New%20Title", + headers={**_build_headers(), "Idempotency-Key": "my-key"}, + ) + + assert resp.status_code == 200 + + +# ============================================================================= +# Upload Attachments Error Tests +# ============================================================================= + +def test_upload_chat_attachments_value_error(): + """Test upload returns 400 on value error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.upload_files_for_northbound', new_callable=AsyncMock) as mock_upload: + + mock_ctx.return_value = MagicMock() + mock_upload.side_effect = ValueError("Invalid file") + + file_content = b"test file content" + files = {"files": ("test.pdf", BytesIO(file_content), "application/pdf")} + + resp = client.post( + "/nb/v1/chat/attachments/upload", + files=files, + headers=_build_headers(), + ) + + assert resp.status_code == 400 + + +def test_upload_chat_attachments_permission_error(): + """Test upload returns 403 on permission error.""" + with patch('apps.northbound_app._get_northbound_context', new_callable=AsyncMock) as mock_ctx, \ + patch('apps.northbound_app.upload_files_for_northbound', new_callable=AsyncMock) as mock_upload: + + mock_ctx.return_value = MagicMock() + mock_upload.side_effect = PermissionError("Access denied") + + file_content = b"test file content" + files = {"files": ("test.pdf", BytesIO(file_content), "application/pdf")} + + resp = client.post( + "/nb/v1/chat/attachments/upload", + files=files, + headers=_build_headers(), + ) + + assert resp.status_code == 403 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + +def test_resolve_proxy_download_filename_with_rfc598_filename(): + """Test filename resolution with RFC 598 filename.""" + from apps.northbound_app import _resolve_proxy_download_filename + + result = _resolve_proxy_download_filename( + "https://example.com/path/file.pdf", + 'filename="report.pdf"' ) + assert result == "report.pdf" - assert resp.status_code == 500 - assert "Internal server error" in resp.json()["detail"] +def test_resolve_proxy_download_filename_with_rfc598_star_filename(): + """Test filename resolution with RFC 598 star filename.""" + from apps.northbound_app import _resolve_proxy_download_filename -def test_fetch_file_no_auth_required(monkeypatch): - """Endpoint requires no authentication (NOTE: No authentication required).""" - auth_called = [] + result = _resolve_proxy_download_filename( + "https://example.com/path/file.pdf", + "filename*=UTF-8''report%20final.pdf" + ) + assert result == "report final.pdf" - def _track_auth(auth): - auth_called.append(auth) - return (True, {"token_id": "t1"}) - monkeypatch.setattr("apps.northbound_app.validate_bearer_token", _track_auth) +def test_resolve_proxy_download_filename_from_url(): + """Test filename resolution from URL when no content-disposition.""" + from apps.northbound_app import _resolve_proxy_download_filename - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "text/plain"} - mock_response.aiter_bytes = MagicMock(return_value=iter([b"hello"])) + result = _resolve_proxy_download_filename( + "https://example.com/path/to/document.pdf", + "" + ) + assert result == "document.pdf" - mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - monkeypatch.setattr("httpx.AsyncClient", lambda **kwargs: mock_client) +def test_resolve_proxy_download_filename_no_filename_in_url(): + """Test filename resolution returns 'download' when no filename in URL.""" + from apps.northbound_app import _resolve_proxy_download_filename - # No headers at all - should still work because auth is not checked - resp = client.get( - "/nb/v1/file/fetch", - params={"presigned_url": "http://minio:9000/bucket/file.pdf"}, + result = _resolve_proxy_download_filename( + "https://example.com/path/", + "" ) + assert result == "download" - assert resp.status_code == 200 + +def test_resolve_proxy_download_filename_empty_content_disposition(): + """Test filename resolution with empty content-disposition.""" + from apps.northbound_app import _resolve_proxy_download_filename + + result = _resolve_proxy_download_filename( + "https://example.com/path/file.pdf", + None + ) + assert result == "file.pdf" diff --git a/test/backend/services/test_northbound_service.py b/test/backend/services/test_northbound_service.py index 0d658e198..e70c946ac 100644 --- a/test/backend/services/test_northbound_service.py +++ b/test/backend/services/test_northbound_service.py @@ -1,93 +1,145 @@ +""" +Tests for backend.services.northbound_service module. + +This module tests the northbound-facing service layer functions including: +- Streaming chat (start/stop) +- Conversation management (list, history, title update) +- Agent info listing +- Rate limiting and idempotency +""" import sys import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) +import types +from unittest.mock import MagicMock, AsyncMock, patch import pytest -from unittest.mock import MagicMock, AsyncMock, patch +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) + +# ============================================================================= +# Mock all required modules BEFORE importing northbound_service +# ============================================================================= -# First mock the consts module to avoid ModuleNotFoundError -consts_mock = MagicMock() -consts_mock.const = MagicMock() -consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" -consts_mock.const.MINIO_ACCESS_KEY = "test_access_key" -consts_mock.const.MINIO_SECRET_KEY = "test_secret_key" -consts_mock.const.MINIO_REGION = "us-east-1" -consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket" -consts_mock.const.POSTGRES_HOST = "localhost" -consts_mock.const.POSTGRES_USER = "test_user" -consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password" -consts_mock.const.POSTGRES_DB = "test_db" -consts_mock.const.POSTGRES_PORT = 5432 -consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" - -sys.modules['consts'] = consts_mock -sys.modules['consts.const'] = consts_mock.const - -# Mock exceptions module +# Mock consts.exceptions class LimitExceededError(Exception): pass class UnauthorizedError(Exception): pass -exceptions_mock = MagicMock() -exceptions_mock.LimitExceededError = LimitExceededError -exceptions_mock.UnauthorizedError = UnauthorizedError -sys.modules['consts.exceptions'] = exceptions_mock -sys.modules['backend.consts.exceptions'] = exceptions_mock - -# Mock database client -client_mock = MagicMock() -client_mock.MinioClient = MagicMock() -client_mock.get_db_session = MagicMock() -sys.modules['database.client'] = client_mock -sys.modules['backend.database.client'] = client_mock - -# Mock token_db module -token_db_mock = MagicMock() -token_db_mock.log_token_usage = MagicMock(return_value=1) -token_db_mock.get_latest_usage_metadata = MagicMock(return_value={"query": "test"}) -sys.modules['database.token_db'] = token_db_mock -sys.modules['backend.database.token_db'] = token_db_mock - -# Mock conversation_db module -conversation_db_mock = MagicMock() -conversation_db_mock.get_conversation_messages = MagicMock(return_value=[ +class ConversationNotFoundError(Exception): + pass + +consts_exceptions_mod = types.ModuleType("consts.exceptions") +consts_exceptions_mod.LimitExceededError = LimitExceededError +consts_exceptions_mod.UnauthorizedError = UnauthorizedError +consts_exceptions_mod.ConversationNotFoundError = ConversationNotFoundError +sys.modules["consts.exceptions"] = consts_exceptions_mod +sys.modules["backend.consts.exceptions"] = consts_exceptions_mod + +# Mock consts.const +consts_const_mod = types.ModuleType("consts.const") +consts_const_mod.ASSET_OWNER_TENANT_ID = "asset-owner-tenant" +sys.modules["consts.const"] = consts_const_mod + +# Mock consts package +consts_package = types.ModuleType("consts") +consts_package.exceptions = consts_exceptions_mod +consts_package.const = consts_const_mod +sys.modules["consts"] = consts_package + +# Mock database modules +db_client_mod = types.ModuleType("database.client") +db_client_mod.get_db_session = MagicMock() +db_client_mod.as_dict = MagicMock() +sys.modules["database.client"] = db_client_mod +sys.modules["backend.database.client"] = db_client_mod + +db_package = types.ModuleType("database") +db_package.client = db_client_mod +sys.modules["database"] = db_package + +# Mock token_db +token_db_mod = types.ModuleType("database.token_db") +token_db_mod.log_token_usage = MagicMock(return_value=1) +token_db_mod.get_latest_usage_metadata = MagicMock(return_value={"query": "test"}) +sys.modules["database.token_db"] = token_db_mod + +# Mock conversation_db +conversation_db_mod = types.ModuleType("database.conversation_db") +conversation_db_mod.get_conversation_messages = MagicMock(return_value=[ {"message_role": "user", "message_content": "Hello"} ]) -sys.modules['database.conversation_db'] = conversation_db_mock -sys.modules['backend.database.conversation_db'] = conversation_db_mock - -# Mock agent_service module -agent_service_mock = MagicMock() -agent_service_mock.run_agent_stream = AsyncMock() -agent_service_mock.stop_agent_tasks = MagicMock(return_value={"message": "stopped"}) -agent_service_mock.list_all_agent_info_impl = AsyncMock(return_value=[{"agent_id": 1, "name": "test_agent"}]) -agent_service_mock.get_agent_id_by_name = AsyncMock(return_value=1) -sys.modules['services.agent_service'] = agent_service_mock -sys.modules['backend.services.agent_service'] = agent_service_mock - -# Mock conversation_management_service module -conv_mgmt_mock = MagicMock() -conv_mgmt_mock.save_conversation_user = MagicMock() -conv_mgmt_mock.get_conversation_list_service = MagicMock(return_value=[ +sys.modules["database.conversation_db"] = conversation_db_mod + +# Mock attachment_db +attachment_db_mod = types.ModuleType("database.attachment_db") +attachment_db_mod.build_s3_url = MagicMock(return_value="s3://bucket/file") +attachment_db_mod.get_file_url = MagicMock(return_value={"success": True, "url": "https://proxy.example/file"}) +sys.modules["database.attachment_db"] = attachment_db_mod + +# Mock nexent.multi_modal.utils +nexent_utils_mod = types.ModuleType("nexent.multi_modal.utils") +nexent_utils_mod.parse_s3_url = MagicMock(return_value=("bucket", "path/file.txt")) +sys.modules["nexent"] = types.ModuleType("nexent") +sys.modules["nexent.multi_modal"] = types.ModuleType("nexent.multi_modal") +sys.modules["nexent.multi_modal.utils"] = nexent_utils_mod + +# Mock services modules +services_package = types.ModuleType("services") + +# Mock agent_service +agent_service_mod = types.ModuleType("services.agent_service") +agent_service_mod.run_agent_stream = AsyncMock() +agent_service_mod.stop_agent_tasks = MagicMock(return_value={"message": "stopped"}) +agent_service_mod.get_agent_id_by_name = AsyncMock(return_value=1) +sys.modules["services.agent_service"] = agent_service_mod + +# Mock conversation_management_service +conv_mgmt_mod = types.ModuleType("services.conversation_management_service") +conv_mgmt_mod.save_conversation_user = MagicMock() +conv_mgmt_mod.get_conversation_list_service = MagicMock(return_value=[ {"conversation_id": "1", "title": "Test"} ]) -conv_mgmt_mock.create_new_conversation = MagicMock(return_value={"conversation_id": 123}) -conv_mgmt_mock.update_conversation_title_service = MagicMock() -sys.modules['services.conversation_management_service'] = conv_mgmt_mock -sys.modules['backend.services.conversation_management_service'] = conv_mgmt_mock - -# Mock consts.model -consts_model_mock = MagicMock() -AgentRequest_mock = MagicMock() -consts_model_mock.AgentRequest = AgentRequest_mock -sys.modules['consts.model'] = consts_model_mock +conv_mgmt_mod.create_new_conversation = MagicMock(return_value={"conversation_id": 123}) +conv_mgmt_mod.update_conversation_title = MagicMock() +sys.modules["services.conversation_management_service"] = conv_mgmt_mod + +# Mock agent_version_service +agent_version_mod = types.ModuleType("services.agent_version_service") +agent_version_mod.list_published_agents_impl = AsyncMock(return_value=[ + {"agent_id": 1, "name": "test_agent", "description": "Test agent"} +]) +sys.modules["services.agent_version_service"] = agent_version_mod + +# Mock file_management_service +file_mgmt_mod = types.ModuleType("services.file_management_service") +file_mgmt_mod.upload_to_minio = AsyncMock(return_value=[]) +file_mgmt_mod.resolve_minio_upload_folder = MagicMock(return_value="attachments/user") +file_mgmt_mod.validate_urls_access = MagicMock() +sys.modules["services.file_management_service"] = file_mgmt_mod + +# Add to services package +services_package.agent_service = agent_service_mod +services_package.agent_version_service = agent_version_mod +services_package.conversation_management_service = conv_mgmt_mod +services_package.file_management_service = file_mgmt_mod +sys.modules["services"] = services_package + +# Mock consts.model - create stub classes +class AgentRequestStub: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + +class ToolParamsRequestStub: + pass -# Mock database.db_models -db_models_mock = MagicMock() -sys.modules['database.db_models'] = db_models_mock +consts_model_mod = types.ModuleType("consts.model") +consts_model_mod.AgentRequest = AgentRequestStub +consts_model_mod.ToolParamsRequest = ToolParamsRequestStub +sys.modules["consts.model"] = consts_model_mod # Now import the module under test from backend.services import northbound_service as ns @@ -107,13 +159,12 @@ def __init__(self, request_id="req-123", tenant_id="tenant-1", user_id="user-1", @pytest.fixture(autouse=True) def reset_test_isolation(): """Reset test isolation state before each test.""" - # Clear idempotency state ns._IDEMPOTENCY_RUNNING.clear() - # Reset mock call counts - token_db_mock.log_token_usage.reset_mock() + ns._RATE_STATE.clear() + token_db_mod.log_token_usage.reset_mock() yield - # Cleanup after test ns._IDEMPOTENCY_RUNNING.clear() + ns._RATE_STATE.clear() class TestNorthboundContext: @@ -149,23 +200,155 @@ def test_build_idempotency_key_normal(self): key = ns._build_idempotency_key("tenant1", "123", "agent1", "query") assert "tenant1" in key assert "123" in key + assert key.count(":") == 3 def test_build_idempotency_key_with_none(self): - """Test with None values.""" + """Test with None values are converted to empty string.""" key = ns._build_idempotency_key("tenant1", None, "query") assert "tenant1" in key - # None values are converted to empty string assert "None" not in key - # Should contain the empty string from None conversion - assert "tenant1::" in key or ":query" in key - def test_build_idempotency_key_long_string(self): + def test_build_idempotency_key_long_string_hashed(self): """Test with long string gets hashed.""" long_string = "a" * 100 key = ns._build_idempotency_key(long_string) - # Should be hashed (not the full string) assert len(key) < 100 + def test_build_idempotency_key_mixed_long_short(self): + """Test with mixed long and short values.""" + long_val = "x" * 100 + key = ns._build_idempotency_key("short", long_val, "another_short") + assert len(key) < 200 + + def test_build_idempotency_key_empty(self): + """Test with all empty values.""" + key = ns._build_idempotency_key() + assert key == "" + + def test_build_idempotency_key_single_value(self): + """Test with single value.""" + key = ns._build_idempotency_key("only") + assert key == "only" + + +class TestBuildTitleUpdateIdempotencyKey: + """Tests for _build_title_update_idempotency_key function.""" + + def test_title_update_key_format(self): + """Test that title is hashed in the key.""" + key = ns._build_title_update_idempotency_key("tenant1", 123, "My Title") + assert "tenant1" in key + assert "123" in key + # Title should be hashed (SHA256 hex = 64 chars) + parts = key.split(":") + assert len(parts) == 3 + assert len(parts[2]) == 64 # SHA256 hex digest + + def test_title_update_key_different_titles_different_keys(self): + """Test that different titles produce different keys.""" + key1 = ns._build_title_update_idempotency_key("tenant", 1, "Title A") + key2 = ns._build_title_update_idempotency_key("tenant", 1, "Title B") + assert key1 != key2 + + def test_title_update_key_same_inputs_same_key(self): + """Test that same inputs produce same key.""" + key1 = ns._build_title_update_idempotency_key("tenant", 1, "Same Title") + key2 = ns._build_title_update_idempotency_key("tenant", 1, "Same Title") + assert key1 == key2 + + +class TestIdempotencyStartEnd: + """Tests for idempotency_start and idempotency_end functions.""" + + @pytest.mark.asyncio + async def test_idempotency_start_new_key(self): + """Test starting idempotency with new key succeeds.""" + await ns.idempotency_start("new-key") + assert "new-key" in ns._IDEMPOTENCY_RUNNING + + @pytest.mark.asyncio + async def test_idempotency_start_duplicate_key_raises(self): + """Test that duplicate key raises LimitExceededError.""" + await ns.idempotency_start("duplicate-key") + with pytest.raises(LimitExceededError): + await ns.idempotency_start("duplicate-key") + + @pytest.mark.asyncio + async def test_idempotency_end_removes_key(self): + """Test that idempotency_end removes the key.""" + await ns.idempotency_start("end-key") + assert "end-key" in ns._IDEMPOTENCY_RUNNING + await ns.idempotency_end("end-key") + assert "end-key" not in ns._IDEMPOTENCY_RUNNING + + @pytest.mark.asyncio + async def test_idempotency_end_nonexistent_key(self): + """Test that ending nonexistent key does not raise.""" + await ns.idempotency_end("nonexistent-key") # Should not raise + + @pytest.mark.asyncio + async def test_idempotency_expired_key_can_be_reused(self, reset_test_isolation): + """Test that expired keys can be reused after TTL.""" + # Use a very short TTL + await ns.idempotency_start("expire-key", ttl_seconds=1) + assert "expire-key" in ns._IDEMPOTENCY_RUNNING + # Wait for expiration + import asyncio + await asyncio.sleep(1.1) + # Should be able to start again with same key + await ns.idempotency_start("expire-key", ttl_seconds=1) + + +class TestRateLimiting: + """Tests for rate limiting functionality.""" + + @pytest.mark.asyncio + async def test_rate_limit_first_request_allowed(self): + """Test first request under limit is allowed.""" + await ns.check_and_consume_rate_limit("tenant-rate") + assert ns._RATE_STATE["tenant-rate"].get(ns._minute_bucket(), 0) == 1 + + @pytest.mark.asyncio + async def test_rate_limit_multiple_requests(self): + """Test multiple requests increment counter.""" + for _ in range(5): + await ns.check_and_consume_rate_limit("tenant-multi") + assert ns._RATE_STATE["tenant-multi"].get(ns._minute_bucket(), 0) == 5 + + @pytest.mark.asyncio + async def test_rate_limit_exceeded_raises(self): + """Test that exceeding limit raises LimitExceededError.""" + # Fill up to limit + for _ in range(ns._RATE_LIMIT_PER_MINUTE): + await ns.check_and_consume_rate_limit("tenant-limit") + with pytest.raises(LimitExceededError): + await ns.check_and_consume_rate_limit("tenant-limit") + + @pytest.mark.asyncio + async def test_rate_limit_different_tenants(self): + """Test that different tenants have separate limits.""" + for _ in range(10): + await ns.check_and_consume_rate_limit("tenant-a") + for _ in range(5): + await ns.check_and_consume_rate_limit("tenant-b") + assert ns._RATE_STATE["tenant-a"].get(ns._minute_bucket(), 0) == 10 + assert ns._RATE_STATE["tenant-b"].get(ns._minute_bucket(), 0) == 5 + + @pytest.mark.asyncio + async def test_rate_limit_cleanup_old_buckets(self): + """Test that old minute buckets are cleaned up.""" + # First, add a request to create an old bucket + old_bucket = str(int(ns._now_seconds() // 60) - 1) + ns._RATE_STATE["tenant-cleanup"] = {old_bucket: 50} + + # Make a new request - should trigger cleanup of old bucket + await ns.check_and_consume_rate_limit("tenant-cleanup") + + # Old bucket should be cleaned up, new bucket should have 1 request + current_bucket = ns._minute_bucket() + assert old_bucket not in ns._RATE_STATE["tenant-cleanup"] + assert ns._RATE_STATE["tenant-cleanup"].get(current_bucket, 0) == 1 + @pytest.mark.asyncio class TestStartStreamingChat: @@ -173,30 +356,25 @@ class TestStartStreamingChat: async def test_start_streaming_chat_creates_conversation(self): """Test that new conversation is created when conversation_id is None.""" - ctx = MockNorthboundContext(token_id=1) + ctx = MockNorthboundContext(token_id=0) - # Mock response mock_response = MagicMock() mock_response.headers = {} - agent_service_mock.run_agent_stream.return_value = mock_response - - with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock): - with patch.object(ns, 'idempotency_start', new_callable=AsyncMock): - with patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history: - mock_history.return_value = {"data": {"history": []}} - - try: - result = await ns.start_streaming_chat( - ctx=ctx, - conversation_id=None, - agent_name="test_agent", - query="test query" - ) - except Exception: - pass # May fail due to other mocks - - # Verify create_new_conversation was called - conv_mgmt_mock.create_new_conversation.assert_called() + agent_service_mod.run_agent_stream.return_value = mock_response + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history: + mock_history.return_value = {"data": {"history": []}} + + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=None, + agent_name="test_agent", + query="test query" + ) + + conv_mgmt_mod.create_new_conversation.assert_called() async def test_start_streaming_chat_logs_token_usage(self): """Test that token usage is logged when token_id > 0.""" @@ -204,27 +382,113 @@ async def test_start_streaming_chat_logs_token_usage(self): mock_response = MagicMock() mock_response.headers = {} - agent_service_mock.run_agent_stream.return_value = mock_response + agent_service_mod.run_agent_stream.return_value = mock_response + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history: + mock_history.return_value = {"data": {"history": []}} + + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query", + meta_data={"key": "value"} + ) + + token_db_mod.log_token_usage.assert_called() + + async def test_start_streaming_chat_rate_limit_exceeded(self): + """Test that rate limit exceeded is properly propagated.""" + ctx = MockNorthboundContext(token_id=0) + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock) as mock_limit: + mock_limit.side_effect = LimitExceededError("Rate exceeded") + with pytest.raises(LimitExceededError): + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query" + ) + + async def test_start_streaming_chat_uses_existing_conversation(self): + """Test that existing conversation_id is used without creating new one.""" + ctx = MockNorthboundContext(token_id=0) + conv_mgmt_mod.create_new_conversation.reset_mock() + + mock_response = MagicMock() + mock_response.headers = {} + agent_service_mod.run_agent_stream.return_value = mock_response + + async def mock_get_history(*args, **kwargs): + return {"data": {"history": []}} + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', side_effect=mock_get_history): + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=456, + agent_name="test_agent", + query="test query" + ) + + conv_mgmt_mod.create_new_conversation.assert_not_called() + + async def test_start_streaming_chat_no_token_id_no_logging(self): + """Test that token usage is not logged when token_id is 0.""" + ctx = MockNorthboundContext(token_id=0) + token_db_mod.log_token_usage.reset_mock() + + mock_response = MagicMock() + mock_response.headers = {} + agent_service_mod.run_agent_stream.return_value = mock_response + + async def mock_get_history(*args, **kwargs): + return {"data": {"history": []}} + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', side_effect=mock_get_history): + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query" + ) + + token_db_mod.log_token_usage.assert_not_called() + + async def test_start_streaming_chat_with_attachments(self): + """Test streaming chat with attachment normalization.""" + ctx = MockNorthboundContext(token_id=0) + attachments = ["s3://bucket/file.txt"] + + mock_response = MagicMock() + mock_response.headers = {} + agent_service_mod.run_agent_stream.return_value = mock_response - with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock): - with patch.object(ns, 'idempotency_start', new_callable=AsyncMock): - with patch.object(ns, 'idempotency_end', new_callable=AsyncMock): - with patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history: - mock_history.return_value = {"data": {"history": []}} + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history, \ + patch.object(ns, '_normalize_northbound_attachments', return_value=[{"name": "file.txt"}]) as mock_norm: + mock_history.return_value = {"data": {"history": []}} - try: - await ns.start_streaming_chat( - ctx=ctx, - conversation_id=123, - agent_name="test_agent", - query="test query", - meta_data={"key": "value"} - ) - except Exception: - pass + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query", + attachments=attachments + ) - # Verify log_token_usage was called - token_db_mock.log_token_usage.assert_called() + mock_norm.assert_called_once() @pytest.mark.asyncio @@ -234,7 +498,7 @@ class TestStopChat: async def test_stop_chat_success(self): """Test successful stop chat.""" ctx = MockNorthboundContext(token_id=1) - agent_service_mock.stop_agent_tasks.return_value = {"message": "stopped"} + agent_service_mod.stop_agent_tasks.return_value = {"message": "stopped"} result = await ns.stop_chat(ctx=ctx, conversation_id=123) @@ -242,12 +506,22 @@ async def test_stop_chat_success(self): assert result["data"] == 123 async def test_stop_chat_logs_token_usage(self): - """Test that token usage is logged.""" + """Test that token usage is logged when token_id > 0.""" ctx = MockNorthboundContext(token_id=1) + token_db_mod.log_token_usage.reset_mock() await ns.stop_chat(ctx=ctx, conversation_id=123, meta_data={"test": "data"}) - token_db_mock.log_token_usage.assert_called() + token_db_mod.log_token_usage.assert_called() + + async def test_stop_chat_no_token_id_no_logging(self): + """Test that token usage is not logged when token_id is 0.""" + ctx = MockNorthboundContext(token_id=0) + token_db_mod.log_token_usage.reset_mock() + + await ns.stop_chat(ctx=ctx, conversation_id=123) + + token_db_mod.log_token_usage.assert_not_called() @pytest.mark.asyncio @@ -256,7 +530,7 @@ class TestListConversations: async def test_list_conversations_success(self): """Test successful conversation listing.""" - ctx = MockNorthboundContext(token_id=0) # No token_id, no metadata lookup + ctx = MockNorthboundContext(token_id=0) result = await ns.list_conversations(ctx=ctx) @@ -266,12 +540,11 @@ async def test_list_conversations_success(self): async def test_list_conversations_with_metadata(self): """Test that metadata is added when token_id > 0.""" ctx = MockNorthboundContext(token_id=1) - token_db_mock.get_latest_usage_metadata.return_value = {"query": "test query"} + token_db_mod.get_latest_usage_metadata.return_value = {"query": "test query"} result = await ns.list_conversations(ctx=ctx) - # Should have called get_latest_usage_metadata - token_db_mock.get_latest_usage_metadata.assert_called() + token_db_mod.get_latest_usage_metadata.assert_called() @pytest.mark.asyncio @@ -281,7 +554,7 @@ class TestGetConversationHistory: async def test_get_conversation_history_success(self): """Test successful history retrieval.""" ctx = MockNorthboundContext(token_id=1) - conversation_db_mock.get_conversation_messages.return_value = [ + conversation_db_mod.get_conversation_messages.return_value = [ {"message_role": "user", "message_content": "Hello"}, {"message_role": "assistant", "message_content": "Hi there"} ] @@ -292,6 +565,19 @@ async def test_get_conversation_history_success(self): assert "data" in result assert "history" in result["data"] + async def test_get_conversation_history_fields_transformed(self): + """Test that message fields are properly transformed.""" + ctx = MockNorthboundContext(token_id=0) + conversation_db_mod.get_conversation_messages.return_value = [ + {"message_role": "user", "message_content": "Hello"} + ] + + result = await ns.get_conversation_history(ctx=ctx, conversation_id=123) + + history = result["data"]["history"] + assert history[0]["role"] == "user" + assert history[0]["content"] == "Hello" + @pytest.mark.asyncio class TestGetConversationHistoryInternal: @@ -300,7 +586,7 @@ class TestGetConversationHistoryInternal: async def test_get_conversation_history_internal_success(self): """Test internal history retrieval without logging.""" ctx = MockNorthboundContext(token_id=0) - conversation_db_mock.get_conversation_messages.return_value = [ + conversation_db_mod.get_conversation_messages.return_value = [ {"message_role": "user", "message_content": "Hello"} ] @@ -313,12 +599,12 @@ async def test_get_conversation_history_internal_success(self): async def test_get_conversation_history_internal_no_logging(self): """Test that internal function does not log token usage.""" ctx = MockNorthboundContext(token_id=1) - conversation_db_mock.get_conversation_messages.return_value = [] + conversation_db_mod.get_conversation_messages.return_value = [] + token_db_mod.log_token_usage.reset_mock() await ns.get_conversation_history_internal(ctx=ctx, conversation_id=123) - # Should NOT call log_token_usage - token_db_mock.log_token_usage.assert_not_called() + token_db_mod.log_token_usage.assert_not_called() @pytest.mark.asyncio @@ -326,9 +612,10 @@ class TestGetAgentInfoList: """Tests for get_agent_info_list function.""" async def test_get_agent_info_list_success(self): - """Test successful agent info list retrieval.""" - ctx = MockNorthboundContext(token_id=1) - agent_service_mock.list_all_agent_info_impl.return_value = [ + """Test successful agent info list retrieval for asset owner tenant.""" + # Use asset owner tenant to avoid merging asset owner agents + ctx = MockNorthboundContext(tenant_id="asset-owner-tenant", token_id=1) + agent_version_mod.list_published_agents_impl.return_value = [ {"agent_id": 1, "name": "test_agent", "description": "Test"} ] @@ -336,9 +623,21 @@ async def test_get_agent_info_list_success(self): assert result["message"] == "success" assert len(result["data"]) == 1 - # agent_id should be removed assert "agent_id" not in result["data"][0] + async def test_get_agent_info_list_includes_asset_owner_agents(self): + """Test that asset owner agents are included for non-asset-owner tenants.""" + ctx = MockNorthboundContext(tenant_id="other-tenant", token_id=0) + agent_version_mod.list_published_agents_impl.side_effect = [ + [{"agent_id": 1, "name": "local_agent"}], + [{"agent_id": 2, "name": "asset_agent"}] + ] + + result = await ns.get_agent_info_list(ctx=ctx) + + assert len(result["data"]) == 2 + agent_version_mod.list_published_agents_impl.assert_called() + @pytest.mark.asyncio class TestUpdateConversationTitle: @@ -359,8 +658,9 @@ async def test_update_conversation_title_success(self): assert "idempotency_key" in result async def test_update_conversation_title_logs_token_usage(self): - """Test that token usage is logged.""" + """Test that token usage is logged when token_id > 0.""" ctx = MockNorthboundContext(token_id=1) + token_db_mod.log_token_usage.reset_mock() await ns.update_conversation_title( ctx=ctx, @@ -369,10 +669,10 @@ async def test_update_conversation_title_logs_token_usage(self): meta_data={"source": "api"} ) - token_db_mock.log_token_usage.assert_called() + token_db_mod.log_token_usage.assert_called() - async def test_update_conversation_title_idempotency_key(self): - """Test that idempotency key is properly built.""" + async def test_update_conversation_title_custom_idempotency_key(self): + """Test that custom idempotency key is used when provided.""" ctx = MockNorthboundContext(tenant_id="tenant-1", token_id=1) result = await ns.update_conversation_title( @@ -383,3 +683,361 @@ async def test_update_conversation_title_idempotency_key(self): ) assert result["idempotency_key"] == "custom-key" + + async def test_update_conversation_title_idempotency_prevents_duplicate(self): + """Test that duplicate requests within TTL are prevented.""" + ctx = MockNorthboundContext(tenant_id="tenant-1", token_id=0) + + # First call should succeed + await ns.update_conversation_title( + ctx=ctx, + conversation_id=123, + title="New Title" + ) + + # Second call with same params should raise LimitExceededError + with pytest.raises(LimitExceededError): + await ns.update_conversation_title( + ctx=ctx, + conversation_id=123, + title="New Title" + ) + + +class TestReleaseIdempotencyAfterDelay: + """Tests for _release_idempotency_after_delay function.""" + + @pytest.mark.asyncio + async def test_release_after_delay(self): + """Test that idempotency key is released after delay.""" + import asyncio + + await ns.idempotency_start("delayed-key") + assert "delayed-key" in ns._IDEMPOTENCY_RUNNING + + asyncio.create_task(ns._release_idempotency_after_delay("delayed-key", seconds=0.1)) + await asyncio.sleep(0.2) + + assert "delayed-key" not in ns._IDEMPOTENCY_RUNNING + + +class TestMinuteBucket: + """Tests for _minute_bucket helper function.""" + + def test_minute_bucket_returns_string(self): + """Test that minute bucket is a string.""" + bucket = ns._minute_bucket() + assert isinstance(bucket, str) + + def test_minute_bucket_consistent_for_same_time(self): + """Test that same time produces same bucket.""" + ts = 1234567890.0 + bucket1 = ns._minute_bucket(ts) + bucket2 = ns._minute_bucket(ts) + assert bucket1 == bucket2 + + def test_minute_bucket_different_for_different_minutes(self): + """Test that different minutes produce different buckets.""" + ts1 = 1000000.0 + ts2 = ts1 + 60 + bucket1 = ns._minute_bucket(ts1) + bucket2 = ns._minute_bucket(ts2) + assert bucket1 != bucket2 + + +class TestStartStreamingChatErrorHandling: + """Tests for error handling in start_streaming_chat function.""" + + async def test_start_streaming_chat_unauthorized_error(self): + """Test that UnauthorizedError is properly propagated.""" + ctx = MockNorthboundContext(token_id=0) + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock) as mock_limit: + mock_limit.side_effect = UnauthorizedError("Unauthorized") + with pytest.raises(UnauthorizedError): + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query" + ) + + async def test_start_streaming_chat_get_agent_id_error(self): + """Test that get_agent_id_by_name error is wrapped properly.""" + ctx = MockNorthboundContext(token_id=0) + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', new_callable=AsyncMock) as mock_history, \ + patch.object(ns, 'get_agent_id_by_name', new_callable=AsyncMock) as mock_get_id: + mock_history.return_value = {"data": {"history": []}} + mock_get_id.side_effect = Exception("Agent not found") + + with pytest.raises(Exception) as exc_info: + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="nonexistent_agent", + query="test query" + ) + # The exception is wrapped in the outer try/except block + assert "Agent not found" in str(exc_info.value) + + async def test_start_streaming_chat_save_message_error(self): + """Test that save_conversation_user error is wrapped properly.""" + ctx = MockNorthboundContext(token_id=0) + + mock_response = MagicMock() + mock_response.headers = {} + agent_service_mod.run_agent_stream.return_value = mock_response + + async def mock_get_history(*args, **kwargs): + return {"data": {"history": []}} + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', side_effect=mock_get_history), \ + patch.object(ns, 'save_conversation_user', side_effect=Exception("DB error")): + with pytest.raises(Exception) as exc_info: + await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query" + ) + assert "Failed to persist user message" in str(exc_info.value) + + async def test_start_streaming_chat_token_logging_failure(self): + """Test that token logging failure is handled gracefully.""" + ctx = MockNorthboundContext(token_id=1) + + mock_response = MagicMock() + mock_response.headers = {} + agent_service_mod.run_agent_stream.return_value = mock_response + token_db_mod.log_token_usage.side_effect = Exception("Logging failed") + + async def mock_get_history(*args, **kwargs): + return {"data": {"history": []}} + + with patch.object(ns, 'check_and_consume_rate_limit', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_start', new_callable=AsyncMock), \ + patch.object(ns, 'idempotency_end', new_callable=AsyncMock), \ + patch.object(ns, 'get_conversation_history_internal', side_effect=mock_get_history): + # Should not raise even if token logging fails + result = await ns.start_streaming_chat( + ctx=ctx, + conversation_id=123, + agent_name="test_agent", + query="test query", + meta_data={"key": "value"} + ) + assert result is not None + + +class TestStopChatErrorHandling: + """Tests for error handling in stop_chat function.""" + + async def test_stop_chat_error(self): + """Test that errors in stop_chat are wrapped properly.""" + ctx = MockNorthboundContext(token_id=0) + agent_service_mod.stop_agent_tasks.side_effect = Exception("Stop failed") + + with pytest.raises(Exception) as exc_info: + await ns.stop_chat(ctx=ctx, conversation_id=123) + assert "Failed to stop chat" in str(exc_info.value) + + async def test_stop_chat_token_logging_failure(self): + """Test that token logging failure is handled gracefully.""" + ctx = MockNorthboundContext(token_id=1) + token_db_mod.log_token_usage.side_effect = Exception("Logging failed") + + with patch("backend.services.northbound_service.stop_agent_tasks", return_value={"message": "stopped"}): + # Should not raise even if token logging fails + result = await ns.stop_chat(ctx=ctx, conversation_id=123, meta_data={"key": "value"}) + assert result is not None + + +class TestListConversationsErrorHandling: + """Tests for error handling in list_conversations function.""" + + async def test_list_conversations_with_metadata_error(self): + """Test that metadata fetch error is handled gracefully.""" + ctx = MockNorthboundContext(token_id=1) + conv_mgmt_mod.get_conversation_list_service.return_value = [ + {"conversation_id": "1", "title": "Test"} + ] + token_db_mod.get_latest_usage_metadata.side_effect = Exception("DB error") + + # Should not raise even if metadata fetch fails + result = await ns.list_conversations(ctx=ctx) + assert result["message"] == "success" + + async def test_list_conversations_empty_meta_data_removed(self): + """Test that empty meta_data keys are removed from items.""" + ctx = MockNorthboundContext(token_id=1) + conv_mgmt_mod.get_conversation_list_service.return_value = [ + {"conversation_id": "1", "title": "Test", "meta_data": {}} + ] + + result = await ns.list_conversations(ctx=ctx) + assert "meta_data" not in result["data"][0] + + async def test_list_conversations_meta_data_with_no_usage_record(self): + """Test that meta_data is removed when get_latest_usage_metadata returns empty.""" + ctx = MockNorthboundContext(token_id=1) + conv_mgmt_mod.get_conversation_list_service.return_value = [ + {"conversation_id": "1", "title": "Test"} + ] + token_db_mod.get_latest_usage_metadata.return_value = None + + result = await ns.list_conversations(ctx=ctx) + assert "meta_data" not in result["data"][0] + + async def test_list_conversations_meta_data_set_when_present(self): + """Test that meta_data is set on item when get_latest_usage_metadata returns a non-empty value.""" + ctx = MockNorthboundContext(token_id=1) + conv_mgmt_mod.get_conversation_list_service.return_value = [ + {"conversation_id": "1", "title": "Test"} + ] + # Reset side_effect and set return_value + token_db_mod.get_latest_usage_metadata.side_effect = None + token_db_mod.get_latest_usage_metadata.return_value = {"query": "test query"} + + result = await ns.list_conversations(ctx=ctx) + assert "meta_data" in result["data"][0] + assert result["data"][0]["meta_data"]["query"] == "test query" + + async def test_list_conversations_meta_data_empty_dict_removed(self): + """Test that empty meta_data (empty dict) is removed from item.""" + ctx = MockNorthboundContext(token_id=1) + conv_mgmt_mod.get_conversation_list_service.return_value = [ + {"conversation_id": "1", "title": "Test"} + ] + # Reset side_effect and set return_value to empty dict (falsy) + token_db_mod.get_latest_usage_metadata.side_effect = None + token_db_mod.get_latest_usage_metadata.return_value = {} + + result = await ns.list_conversations(ctx=ctx) + # Empty dict is falsy, so meta_data should be popped + assert "meta_data" not in result["data"][0] + + +class TestGetConversationHistoryErrorHandling: + """Tests for error handling in get_conversation_history function.""" + + async def test_get_conversation_history_error(self): + """Test that errors in get_conversation_history are wrapped properly.""" + ctx = MockNorthboundContext(token_id=0) + # Mock get_conversation_messages to raise an error + conversation_db_mod.get_conversation_messages.side_effect = Exception("DB error") + + with pytest.raises(Exception) as exc_info: + await ns.get_conversation_history(ctx=ctx, conversation_id=123) + assert "Failed to get conversation history" in str(exc_info.value) + + +class TestGetAgentInfoListErrorHandling: + """Tests for get_agent_info_list function.""" + + @pytest.mark.asyncio + async def test_get_agent_info_by_name_success(self): + """Test successful agent ID retrieval.""" + agent_service_mod.get_agent_id_by_name.return_value = 42 + + result = await ns.get_agent_info_by_name("test_agent", "tenant-1") + assert result == 42 + + @pytest.mark.asyncio + async def test_get_agent_info_by_name_error(self): + """Test that errors are wrapped properly.""" + agent_service_mod.get_agent_id_by_name.side_effect = Exception("Agent not found") + + with pytest.raises(Exception) as exc_info: + await ns.get_agent_info_by_name("nonexistent", "tenant-1") + assert "Failed to get agent id" in str(exc_info.value) + assert "nonexistent" in str(exc_info.value) + assert "tenant-1" in str(exc_info.value) + + async def test_get_agent_info_list_error(self): + """Test that errors in get_agent_info_list are wrapped properly.""" + ctx = MockNorthboundContext(tenant_id="asset-owner-tenant", token_id=0) + agent_version_mod.list_published_agents_impl.side_effect = Exception("DB error") + + with pytest.raises(Exception) as exc_info: + await ns.get_agent_info_list(ctx=ctx) + assert "Failed to get agent info list" in str(exc_info.value) + + +class TestUpdateConversationTitleErrorHandling: + """Tests for error handling in update_conversation_title function.""" + + async def test_update_conversation_title_error(self): + """Test that errors in update_conversation_title are wrapped properly.""" + ctx = MockNorthboundContext(token_id=0) + conv_mgmt_mod.update_conversation_title.side_effect = Exception("DB error") + + with pytest.raises(Exception) as exc_info: + await ns.update_conversation_title( + ctx=ctx, + conversation_id=123, + title="New Title" + ) + assert "Failed to update conversation title" in str(exc_info.value) + + async def test_update_conversation_title_token_logging_failure(self): + """Test that token logging failure is handled gracefully.""" + ctx = MockNorthboundContext(token_id=1) + token_db_mod.log_token_usage.side_effect = Exception("Logging failed") + # Ensure update_conversation_title_service succeeds + conv_mgmt_mod.update_conversation_title.side_effect = None + conv_mgmt_mod.update_conversation_title.return_value = True + + # Should not raise even if token logging fails + result = await ns.update_conversation_title( + ctx=ctx, + conversation_id=123, + title="New Title", + meta_data={"key": "value"} + ) + assert result["message"] == "success" + + async def test_update_conversation_title_conversation_not_found(self): + """Test that ConversationNotFoundError is propagated without wrapping.""" + ctx = MockNorthboundContext(token_id=0) + conv_mgmt_mod.update_conversation_title.side_effect = ConversationNotFoundError("Not found") + + with pytest.raises(ConversationNotFoundError): + await ns.update_conversation_title( + ctx=ctx, + conversation_id=123, + title="New Title" + ) + + +class TestNormalizeAttachmentsErrorHandling: + """Tests for error handling in _normalize_northbound_attachments function.""" + + def test_normalize_attachments_parse_s3_url_error(self): + """Test that parse_s3_url ValueError is converted to ValueError.""" + with patch("backend.services.northbound_service.parse_s3_url", side_effect=ValueError("Parse error")): + with pytest.raises(ValueError) as exc_info: + ns._normalize_northbound_attachments( + ["s3://bucket/file.txt"], + "user123", + "tenant123" + ) + assert "Invalid S3 URL format" in str(exc_info.value) + + def test_normalize_attachments_permission_error_invalid_url(self): + """Test that PermissionError with invalid URL is converted to ValueError.""" + with patch("backend.services.northbound_service.parse_s3_url", return_value=("bucket", "path/file.txt")), \ + patch("backend.services.northbound_service.validate_urls_access", + side_effect=PermissionError("Invalid S3 URL format: bad")): + with pytest.raises(ValueError) as exc_info: + ns._normalize_northbound_attachments( + ["s3://bucket/path/file.txt"], + "user123", + "tenant123" + ) + assert "Invalid S3 URL format" in str(exc_info.value) + diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 53d02206a..acb94f43f 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -118,9 +118,67 @@ class VectorDatabaseCore: smolagents_mod = types.ModuleType("smolagents") smolagents_tools_mod = types.ModuleType("smolagents.tools") + class Tool: + """Mock Tool class that properly handles Pydantic Field definitions.""" + def __init__(self, *args, **kwargs): - pass + from pydantic.fields import FieldInfo + + # Set all provided kwargs as instance attributes + for key, value in kwargs.items(): + setattr(self, key, value) + + # For any Pydantic Field attributes defined in class hierarchy that weren't provided, + # extract their default values + for cls in type(self).__mro__: + if cls is Tool: + continue + if hasattr(cls, '__annotations__'): + for name, hint in cls.__annotations__.items(): + # Skip if already set from kwargs + if name in self.__dict__: + continue + # Check if there's a class attribute that's a FieldInfo + if hasattr(cls, name): + value = getattr(cls, name) + # Unwrap FieldInfo to get the default + if isinstance(value, FieldInfo): + # Handle default_factory + if value.default_factory is not None: + value = value.default_factory() + else: + value = value.default + setattr(self, name, value) + + def __setattr__(self, name, value): + from pydantic.fields import FieldInfo + # Unwrap FieldInfo when it's set after __init__ completes (not from kwargs) + if isinstance(value, FieldInfo): + # Check if this is a class-level default by looking at the class + for cls in type(self).__mro__: + if cls is Tool: + continue + if hasattr(cls, name): + class_attr = getattr(cls, name) + if class_attr is value: + # This is a class-level FieldInfo default, unwrap it + if value.default_factory is not None: + value = value.default_factory() + else: + value = value.default + break + else: + # Not found in class hierarchy, unwrap it anyway + if value.default_factory is not None: + value = value.default_factory() + else: + value = value.default + self.__dict__[name] = value + + def __repr__(self): + return f"" + smolagents_tools_mod.Tool = Tool smolagents_mod.tools = smolagents_tools_mod @@ -497,15 +555,10 @@ def test_init_without_rerank_params(self, mock_observer): observer=mock_observer, ) - # smolagents Tool doesn't properly handle Field defaults, so we check FieldInfo.default - try: - from pydantic import FieldInfo - except ImportError: - from pydantic.fields import FieldInfo - assert isinstance(tool.rerank, FieldInfo) - assert tool.rerank.default is False - assert tool.rerank_model_name.default == "" - assert tool.rerank_model.default is None + # Mock Tool properly unwraps Field defaults, so we check the actual values + assert tool.rerank is False + assert tool.rerank_model_name == "" + assert tool.rerank_model is None def test_forward_with_rerank_enabled(self, mock_observer, mock_vdb_core, mock_embedding_model, mocker): """Test forward method when rerank is enabled and model is provided.""" @@ -1516,3 +1569,210 @@ def test_forward_with_fieldinfo_rerank_default_only(self, mock_observer, mock_vd call_kwargs = mock_vdb_core.hybrid_search.call_args[1] # top_k from default is 3, multiplied by RERANK_OVERSEARCH_MULTIPLIER assert call_kwargs["top_k"] == 3 * RERANK_OVERSEARCH_MULTIPLIER + + +class TestDocumentPathsAccessControl: + """Tests for document_paths access control functionality.""" + + def _create_mock_formatted_results_with_paths(self, paths: list) -> list: + """Create mock search results in FORMATTED format for _filter_by_document_paths tests. + + After search_hybrid processes VDB results, the path_or_url is at the top level. + """ + results = [] + for path in paths: + results.append({ + "path_or_url": path, + "title": f"Document {path}", + "content": f"Content for {path}", + "filename": f"{path}.txt", + "source_type": "file", + "create_time": "2024-01-01T12:00:00Z", + "score": 0.9, + "index": "test_index" + }) + return results + + def _create_mock_vdb_results_with_paths(self, paths: list) -> list: + """Create mock search results in VDB format for forward() tests. + + VDB returns results with a nested 'document' object. + """ + results = [] + for path in paths: + results.append({ + "document": { + "path_or_url": path, + "title": f"Document {path}", + "content": f"Content for {path}", + "filename": f"{path}.txt", + "source_type": "file", + "create_time": "2024-01-01T12:00:00Z", + }, + "score": 0.9, + "index": "test_index" + }) + return results + return results + + def test_filter_by_document_paths_allows_matching(self, mock_vdb_core, mock_embedding_model): + """Test that results with path_or_url in the allowed list are returned.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=["s3://bucket/doc1.txt", "s3://bucket/doc2.txt"], + ) + + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + filtered = tool._filter_by_document_paths(results) + + # Only doc1 and doc2 should be returned + assert len(filtered) == 2 + assert all(r.get("path_or_url") in ["s3://bucket/doc1.txt", "s3://bucket/doc2.txt"] for r in filtered) + + def test_filter_by_document_paths_rejects_non_matching(self, mock_vdb_core, mock_embedding_model): + """Test that results with path_or_url NOT in the allowed list are filtered out.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=["s3://bucket/doc1.txt"], + ) + + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + filtered = tool._filter_by_document_paths(results) + + # Only doc1 should be returned + assert len(filtered) == 1 + assert filtered[0].get("path_or_url") == "s3://bucket/doc1.txt" + + def test_filter_by_document_paths_empty_list_returns_all(self, mock_vdb_core, mock_embedding_model): + """Test that empty document_paths list returns all results.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=[], + ) + + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + filtered = tool._filter_by_document_paths(results) + + # All results should be returned + assert len(filtered) == 3 + + def test_filter_by_document_paths_none_returns_all(self, mock_vdb_core, mock_embedding_model): + """Test that None document_paths (no filter) returns all results.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=None, + ) + + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + filtered = tool._filter_by_document_paths(results) + + # All results should be returned + assert len(filtered) == 3 + + def test_filter_by_document_paths_results_missing_path(self, mock_vdb_core, mock_embedding_model): + """Test that results without path_or_url field are filtered out when filter is active.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=["s3://bucket/doc1.txt"], + ) + + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt"]) + # Add a result without path_or_url (flat format, no nested document) + results.append({ + "title": "No Path", + "content": "This document has no path_or_url", + "filename": "no_path.txt", + "source_type": "file", + "score": 0.8, + "index": "test_index" + }) + + filtered = tool._filter_by_document_paths(results) + + # Only doc1 should be returned + assert len(filtered) == 1 + assert filtered[0].get("path_or_url") == "s3://bucket/doc1.txt" + + def test_set_document_paths_method(self, mock_vdb_core, mock_embedding_model): + """Test the set_document_paths method updates the internal filter.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + document_paths=None, + ) + + # Initially no filter + results = self._create_mock_formatted_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt"]) + assert len(tool._filter_by_document_paths(results)) == 2 + + # Set document_paths filter + tool.set_document_paths(["s3://bucket/doc1.txt"]) + filtered = tool._filter_by_document_paths(results) + + # Only doc1 should be returned + assert len(filtered) == 1 + assert filtered[0].get("path_or_url") == "s3://bucket/doc1.txt" + + def test_forward_with_document_paths_filter(self, mock_vdb_core, mock_embedding_model, mock_observer): + """Test that forward method applies document_paths filter to search results.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + document_paths=["s3://bucket/doc1.txt"], + top_k=5, + ) + + # Mock VDB returns 3 results, but only 1 matches the filter + # VDB returns nested 'document' format + mock_results = self._create_mock_vdb_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + mock_vdb_core.hybrid_search.return_value = mock_results + + result = tool.forward("test query") + search_results = json.loads(result) + + # Only doc1 should be in the result + assert len(search_results) == 1 + assert search_results[0].get("url") == "s3://bucket/doc1.txt" + + def test_forward_with_document_paths_filter_no_results_after_filter(self, mock_vdb_core, mock_embedding_model, mock_observer): + """Test that forward raises exception when all results are filtered out.""" + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + document_paths=["s3://bucket/nonexistent.txt"], + top_k=5, + ) + + # Mock VDB returns 3 results, none match the filter + mock_results = self._create_mock_vdb_results_with_paths(["s3://bucket/doc1.txt", "s3://bucket/doc2.txt", "s3://bucket/doc3.txt"]) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Should raise exception because after filtering, no results remain + with pytest.raises(Exception) as excinfo: + tool.forward("test query") + + assert "No results found" in str(excinfo.value) +