Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 122 additions & 16 deletions backend/agents/create_agent_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import threading
import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin

from jinja2 import Template, StrictUndefined
Expand Down Expand Up @@ -33,12 +33,71 @@
from utils.config_utils import tenant_config_manager, get_model_name_from_config
from utils.context_utils import build_context_components
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
from consts.model import AgentToolParamsRequest, ToolParamsRequest
from consts.exceptions import ValidationError

logger = logging.getLogger("create_agent_info")
logger.setLevel(logging.DEBUG)


def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest:
"""Normalize request-scoped tool parameter overrides into a ToolParamsRequest."""
if tool_params is None:
return ToolParamsRequest()
if isinstance(tool_params, ToolParamsRequest):
return tool_params
if not isinstance(tool_params, dict):
raise ValidationError("tool_params must be an object.")
try:
return ToolParamsRequest.model_validate(tool_params)
except Exception as exc:
raise ValidationError(f"Invalid tool_params payload: {exc}") from exc


def _get_agent_tool_overrides(
tool_params: Optional[ToolParamsRequest],
agent_name: Optional[str],
) -> Dict[str, Dict[str, Any]]:
"""Resolve tool overrides for a specific agent by its name."""
if tool_params is None:
return {}
if not agent_name:
return {}
agent_override = tool_params.agents.get(agent_name)
if agent_override is None:
return {}
return dict(agent_override.tools)


def _merge_tool_params(
tool_record: Dict[str, Any],
override_params: Optional[Dict[str, Any]],
extra_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Merge request overrides on top of tool instance defaults from DB.

Args:
tool_record: Tool configuration from database
override_params: Request-scoped overrides from tool_params
extra_params: Additional internal params not in DB schema (e.g., document_paths)

Returns:
Merged params dict with DB defaults, overrides, and extra params
"""
merged_params: Dict[str, Any] = {}
for param in tool_record.get("params", []):
merged_params[param["name"]] = param.get("default")

if override_params:
merged_params.update(override_params)

# Extra params (e.g., internal access control params) always take precedence
if extra_params:
merged_params.update(extra_params)

return merged_params


def _build_internal_s3_url(file: dict) -> str:
"""Build a valid S3 URL for internal tools from uploaded file metadata."""
if not isinstance(file, dict):
Expand Down Expand Up @@ -310,7 +369,9 @@
allow_memory_search: bool = True,
version_no: int = 0,
override_model_id: int | None = None,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
normalized_tool_params = _normalize_tool_params_request(tool_params)
agent_info = search_agent_info_by_agent_id(
agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)

Expand All @@ -331,13 +392,20 @@
allow_memory_search=allow_memory_search,
version_no=sub_agent_version_no,
override_model_id=None,
tool_params=normalized_tool_params,
)
managed_agents.append(sub_agent_config)

# create external A2A agents (synchronous function, no await needed)
external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no)

tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no)
tool_list = await create_tool_config_list(
agent_id,
tenant_id,
user_id,
version_no=version_no,
tool_params=normalized_tool_params,
)

# Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
duty_prompt = agent_info.get("duty_prompt", "")
Expand Down Expand Up @@ -562,17 +630,43 @@
return agent_config


async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0):
# create tool
async def create_tool_config_list(

Check failure on line 633 in backend/agents/create_agent_info.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 55 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=ModelEngine-Group_nexent&issues=AZ62g2Y_t7nt5E6mViji&open=AZ62g2Y_t7nt5E6mViji&pullRequest=3223
agent_id,
tenant_id,
user_id,
version_no: int = 0,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
tool_config_list = []
langchain_tools = await discover_langchain_tools()
normalized_tool_params = _normalize_tool_params_request(tool_params)

# now only admin can modify the agent, user_id is not used
tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no)

# Look up agent name for use in error messages.
# Agent name is optional for tool_params matching (matching uses tool identifiers only),
# but we include it in error messages so callers can identify which agent/tool caused a failure.
agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
agent_name = agent_info.get("name") if agent_info else None
agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name)

tool_keys_seen = set()
for tool in tools_list:
param_dict = {}
for param in tool.get("params", []):
param_dict[param["name"]] = param.get("default")
tool_identifier = tool.get("name") or tool.get("class_name")
if tool_identifier in tool_keys_seen:
raise ValidationError(
f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'."
)
tool_keys_seen.add(tool_identifier)

override_params = None
if tool.get("name") in agent_tool_overrides:
override_params = agent_tool_overrides[tool.get("name")]
elif tool.get("class_name") in agent_tool_overrides:
override_params = agent_tool_overrides[tool.get("class_name")]

param_dict = _merge_tool_params(tool, override_params)
tool_config = ToolConfig(
class_name=tool.get("class_name"),
name=tool.get("name"),
Expand All @@ -591,20 +685,29 @@
tool_config.metadata = langchain_tool
break

# Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
document_paths = None
if override_params and "document_paths" in override_params:
document_paths = override_params.get("document_paths")
# Also check using the tool name as key
if not document_paths:
kb_overrides = agent_tool_overrides.get("knowledge_base_search")
if kb_overrides and "document_paths" in kb_overrides:
document_paths = kb_overrides.get("document_paths")

# special logic for search tools that may use reranking models
if tool_config.class_name == "KnowledgeBaseSearchTool":
rerank = param_dict.get("rerank", False)
rerank_model_name = param_dict.get("rerank_model_name", "")
rerank = tool_config.params.get("rerank", False)
rerank_model_name = tool_config.params.get("rerank_model_name", "")
rerank_model = None
is_multimodal = bool(tool_config.params.pop("multimodal", False))
if rerank and rerank_model_name:
rerank_model = get_rerank_model(
tenant_id=tenant_id, model_name=rerank_model_name
)

# Build display_name to index_name mapping for LLM parameter conversion
# Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
index_names = param_dict.get("index_names", [])
index_names = tool_config.params.get("index_names", [])
display_name_to_index_map = {}
index_name_to_display_map = {}
if index_names:
Expand All @@ -620,12 +723,14 @@
"rerank_model": rerank_model,
"display_name_to_index_map": display_name_to_index_map,
"index_name_to_display_map": index_name_to_display_map,
# Internal access control: restrict results to specific document paths (path_or_urls)
"document_paths": document_paths,
}

# Must have embedding model for knowledge base search
if not index_names:
raise ValidationError(
"Embedding model is required for knowledge_base_search but index_names is empty")
f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, "
f"but it is not configured in the agent and not provided via tool_params.")

embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0])
if not embedding_model:
Expand All @@ -634,8 +739,8 @@
f"Please configure an embedding model for this knowledge base.")
tool_config.metadata["embedding_model"] = embedding_model
elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
rerank = param_dict.get("rerank", False)
rerank_model_name = param_dict.get("rerank_model_name", "")
rerank = tool_config.params.get("rerank", False)
rerank_model_name = tool_config.params.get("rerank_model_name", "")
rerank_model = None
if rerank and rerank_model_name:
rerank_model = get_rerank_model(
Expand Down Expand Up @@ -929,6 +1034,7 @@
is_debug: bool = False,
override_version_no: int | None = None,
override_model_id: int | None = None,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
# Determine which version_no to use based on is_debug flag
# If is_debug=false, use the current published version (current_version_no)
Expand Down Expand Up @@ -961,7 +1067,7 @@
if override_model_id is not None:
create_config_kwargs["override_model_id"] = override_model_id

agent_config = await create_agent_config(**create_config_kwargs)
agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params)

remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True)
default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse")
Expand Down
Loading
Loading