Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions src/tests/test_artifact_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,80 @@ async def test_explicit_profile_maps_correctly(self, mock_get_api_key):

call_args = mock_client.post.call_args
assert call_args[1]["json"]["profile"] == "InheritanceOnly"
# No data_source supplied => omitted from the body.
assert "dataSource" not in call_args[1]["json"]

@pytest.mark.asyncio
@patch("tools.artifact_relationships.get_api_key_from_context")
async def test_forwards_data_source(self, mock_get_api_key):
    """A supplied ``data_source`` is sent as ``dataSource`` in the POST JSON body."""
    mock_get_api_key.return_value = "test_key"

    # Canned successful backend reply; raise_for_status is a no-op mock.
    backend_reply = MagicMock()
    backend_reply.raise_for_status = MagicMock()
    backend_reply.json.return_value = {
        "sourceIdentifier": "id",
        "profile": "CallsOnly",
        "found": True,
        "relationships": [],
    }

    http_client = AsyncMock()
    http_client.post.return_value = backend_reply

    # Lifespan context the tool reads its HTTP client and base URL from.
    lifespan = MagicMock()
    lifespan.base_url = "https://app.codealive.ai"
    lifespan.client = http_client

    ctx = MagicMock(spec=Context)
    ctx.debug = AsyncMock()
    ctx.error = AsyncMock()
    ctx.request_context.lifespan_context = lifespan

    await get_artifact_relationships(
        ctx=ctx,
        identifier="id",
        data_source="repo (main)",
    )

    posted_json = http_client.post.call_args.kwargs["json"]
    assert posted_json["dataSource"] == "repo (main)"

@pytest.mark.asyncio
@patch("tools.artifact_relationships.get_api_key_from_context")
async def test_ambiguous_409_surfaces_candidate_data_sources(self, mock_get_api_key):
    """A backend 409 becomes a ToolError naming the candidates and the ``data_source`` retry."""
    import httpx

    mock_get_api_key.return_value = "test_key"

    # Simulated 409 reply whose detail lists two candidate data sources.
    conflict_reply = MagicMock()
    conflict_reply.status_code = 409
    conflict_reply.text = (
        '{"detail": "Identifier matches 2 data sources: '
        "Name='repo (main)' Id='ds-main', Name='repo (master)' Id='ds-master'\"}"
    )
    conflict_reply.raise_for_status.side_effect = httpx.HTTPStatusError(
        "Conflict", request=MagicMock(), response=conflict_reply
    )

    http_client = AsyncMock()
    http_client.post.return_value = conflict_reply

    lifespan = MagicMock()
    lifespan.base_url = "https://app.codealive.ai"
    lifespan.client = http_client

    ctx = MagicMock(spec=Context)
    ctx.debug = AsyncMock()
    ctx.error = AsyncMock()
    ctx.request_context.lifespan_context = lifespan

    with pytest.raises(ToolError) as raised:
        await get_artifact_relationships(ctx=ctx, identifier="org/repo::path::Symbol")

    message = str(raised.value)
    assert "409" in message
    # Both candidate data sources from the backend 409 must be surfaced ...
    for candidate in ("repo (main)", "repo (master)"):
        assert candidate in message
    # ... together with the hint to retry using the data_source parameter.
    assert "data_source" in message

@pytest.mark.asyncio
async def test_empty_identifier_raises_tool_error(self):
Expand Down Expand Up @@ -446,3 +520,21 @@ async def test_not_found_response_renders_correctly(self, mock_get_api_key):

assert data["found"] is False
assert "relationships" not in data

def test_not_found_hint_with_data_source_suggests_retry_or_omit(self):
    """A not-found result built with a data_source yields a hint naming it and how to recover."""
    selector = "repo (main)"
    not_found = {
        "sourceIdentifier": "org/repo::path::S",
        "profile": "CallsOnly",
        "found": False,
    }

    hint_text = _build_relationships_dict(not_found, data_source=selector)["hint"]

    # The hint echoes the selector, mentions the parameter, and offers omitting it.
    assert selector in hint_text
    assert "data_source" in hint_text
    assert "omit" in hint_text.lower()

def test_not_found_hint_without_data_source_is_generic(self):
    """Without a data_source, the not-found hint does not mention the parameter."""
    not_found = {
        "sourceIdentifier": "org/repo::path::S",
        "profile": "CallsOnly",
        "found": False,
    }

    hint_text = _build_relationships_dict(not_found)["hint"]

    assert "data_source" not in hint_text
    assert "fresh identifier" in hint_text
70 changes: 70 additions & 0 deletions src/tests/test_fetch_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,39 @@ def test_hint_absent_when_no_artifacts_have_content(self):
assert "<hint>" not in result


class TestBuildArtifactsXmlDataSourceMissHint:
    """Checks the data_source "miss" hint emitted by ``_build_artifacts_xml``.

    The hint is expected only when a data_source was supplied and no artifact
    content came back; it should steer the caller to retry or drop the selector.
    """

    def test_hint_when_data_source_scoped_returns_nothing(self):
        """An artifact without content plus a data_source produces the miss hint."""
        empty_artifact = {"identifier": "repo::a.ts::F", "content": None, "contentByteSize": None}
        xml = _build_artifacts_xml({"artifacts": [empty_artifact]}, data_source="repo (main)")

        assert "<hint>" in xml
        assert "repo (main)" in xml
        # The hint points at both recovery moves: another data_source, or omitting it.
        assert "data_source" in xml
        assert "omit" in xml.lower()

    def test_hint_when_empty_artifacts_and_data_source(self):
        """An empty artifact list plus a data_source also produces the hint."""
        xml = _build_artifacts_xml({"artifacts": []}, data_source="ds-main")

        assert "ds-main" in xml
        assert "<hint>" in xml

    def test_no_miss_hint_when_data_source_resolved_content(self):
        """When content was resolved, the "omit data_source" wording is absent."""
        resolved_artifact = {"identifier": "repo::a.ts::F", "content": "code", "contentByteSize": 4}
        xml = _build_artifacts_xml({"artifacts": [resolved_artifact]}, data_source="repo (main)")

        assert "omit data_source" not in xml

    def test_no_miss_hint_without_data_source(self):
        """No data_source supplied means no hint, even when nothing was found."""
        empty_artifact = {"identifier": "repo::a.ts::F", "content": None, "contentByteSize": None}
        xml = _build_artifacts_xml({"artifacts": [empty_artifact]})

        assert "<hint>" not in xml


@pytest.mark.asyncio
@patch('tools.fetch_artifacts.get_api_key_from_context')
async def test_fetch_artifacts_returns_xml(mock_get_api_key):
Expand Down Expand Up @@ -476,6 +509,43 @@ async def test_fetch_artifacts_posts_correct_body(mock_get_api_key):
body = call_args.kwargs["json"]
assert body["identifiers"] == ["id1", "id2"]
assert "names" not in body
# No data_source supplied => the field is omitted (preserves the 409-on-ambiguity fallback).
assert "dataSource" not in body


@pytest.mark.asyncio
@patch('tools.fetch_artifacts.get_api_key_from_context')
async def test_fetch_artifacts_forwards_data_source(mock_get_api_key):
    """fetch_artifacts sends a supplied data_source as the ``dataSource`` key of the POST JSON body."""
    mock_get_api_key.return_value = "test_key"

    # Minimal successful backend reply: no artifacts, raise_for_status is a no-op mock.
    backend_reply = MagicMock()
    backend_reply.raise_for_status = MagicMock()
    backend_reply.json.return_value = {"artifacts": []}

    http_client = AsyncMock()
    http_client.post.return_value = backend_reply

    lifespan = MagicMock()
    lifespan.base_url = "https://app.codealive.ai"
    lifespan.client = http_client

    ctx = MagicMock(spec=Context)
    ctx.info = AsyncMock()
    ctx.warning = AsyncMock()
    ctx.error = AsyncMock()
    ctx.request_context.headers = {"authorization": "Bearer test_key"}
    ctx.request_context.lifespan_context = lifespan

    await fetch_artifacts(
        ctx=ctx,
        identifiers=["id1"],
        data_source="repo (main)",
    )

    posted_json = http_client.post.call_args.kwargs["json"]
    assert posted_json["dataSource"] == "repo (main)"


@pytest.mark.asyncio
Expand Down
24 changes: 24 additions & 0 deletions src/tests/test_response_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,34 @@ def test_data_preservation(self):
assert first["identifier"] == "CodeAlive-AI/codealive-mcp::src/tools/search.py::codebase_search"
assert first["contentByteSize"] == 8500
assert first["description"] == "Main search function"
# Data-source identity must be surfaced (not stripped) so the agent can feed it back
# as `data_source` to disambiguate a branch-blind identifier.
assert first["dataSource"] == {"id": "685b21230e3822f4efa9d073", "name": "codealive-mcp"}

assert second["path"] == "README.md"
assert second["kind"] == "Chunk"
assert second["description"] == "Search documentation section"
assert second["dataSource"] == {"id": "685b21230e3822f4efa9d073", "name": "codealive-mcp"}

def test_grep_transform_surfaces_data_source(self):
    """The grep transform keeps ``dataSource`` on each result, reduced to ``id`` and ``name``."""
    line_match = {"lineNumber": 3, "startColumn": 0, "endColumn": 4, "lineText": "auth"}
    file_hit = {
        "kind": "File",
        "identifier": "owner/repo::src/auth.py",
        "location": {"path": "src/auth.py"},
        "matchCount": 1,
        "matches": [line_match],
        "dataSource": {"type": "repository", "id": "ds-main", "name": "repo (main)"},
    }

    transformed = transform_grep_response({"results": [file_hit]})

    # The input "type" field is not part of the expected output.
    first_result = transformed["results"][0]
    assert first_result["dataSource"] == {"id": "ds-main", "name": "repo (main)"}

def test_grep_transform_preserves_match_previews(self):
response = {
Expand Down
38 changes: 33 additions & 5 deletions src/tools/artifact_relationships.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Artifact relationships tool implementation."""

from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional
from urllib.parse import urljoin

import httpx
Expand Down Expand Up @@ -37,6 +37,7 @@ async def get_artifact_relationships(
identifier: str,
profile: Literal["callsOnly", "inheritanceOnly", "allRelevant", "referencesOnly"] = "callsOnly",
max_count_per_type: int = 50,
data_source: Optional[str] = None,
) -> Dict[str, Any]:
"""
Retrieve relationship groups for a single artifact by profile.
Expand Down Expand Up @@ -84,6 +85,11 @@ async def get_artifact_relationships(
- "allRelevant": calls + inheritance only; references are excluded
- "referencesOnly": where-used LSP references for non-call usage
max_count_per_type: Maximum related artifacts per relationship type (1–1000, default 50).
data_source: Optional data-source Name or Id used to disambiguate an identifier that
exists in more than one data source. Copy the `dataSource.name` or
`dataSource.id` from a search result. Omit it for normal lookups; if the
source identifier is ambiguous and you omit it, the backend returns a 409
listing the candidate data sources.

Returns:
A dict with grouped relationships:
Expand All @@ -103,6 +109,7 @@ async def get_artifact_relationships(
"identifier": identifier,
"profile": profile,
"max_count_per_type": max_count_per_type,
"data_source": data_source,
}

if not identifier:
Expand Down Expand Up @@ -143,6 +150,8 @@ async def get_artifact_relationships(
"profile": api_profile,
"maxCountPerType": max_count_per_type,
}
if data_source:
body["dataSource"] = data_source
Comment thread
sciapanCA marked this conversation as resolved.

await ctx.debug(f"Fetching {profile} relationships for artifact")

Expand All @@ -156,7 +165,7 @@ async def get_artifact_relationships(
log_api_response(response, request_id)
response.raise_for_status()

return _build_relationships_dict(response.json())
return _build_relationships_dict(response.json(), data_source=data_source)

except (httpx.HTTPStatusError, Exception) as e:
logger.bind(
Expand All @@ -173,15 +182,24 @@ async def get_artifact_relationships(
"(2) call semantic_search or grep_search again to get a fresh identifier — the index may have changed, "
"(3) check that the artifact is a function/class (relationships are not available for non-symbol artifacts)"
),
409: (
"(1) the identifier exists in more than one data source — see the candidate data sources in the Detail above; each one will resolve, "
"(2) retry get_artifact_relationships with data_source set to one candidate's Name or Id; if that data source isn't the one you want, retry with the next candidate, "
"(3) do NOT invent relation results — pick from the listed data sources"
),
},
)


def _build_relationships_dict(data: dict) -> Dict[str, Any]:
def _build_relationships_dict(data: dict, data_source: Optional[str] = None) -> Dict[str, Any]:
"""Build a dict representation of an artifact relationships response.

FastMCP serializes the dict via pydantic_core.to_json, which preserves UTF-8 —
don't reintroduce json.dumps here, it would re-escape non-ASCII identifiers.

``data_source`` is the selector the caller passed (if any); when the source is not
found it shapes the recovery hint so the agent can retry with another data source
or drop the selector.
"""
raw_source_id = data.get("sourceIdentifier") or ""
raw_profile = data.get("profile") or ""
Expand All @@ -208,9 +226,9 @@ def _build_relationships_dict(data: dict) -> Dict[str, Any]:
counts = _build_counts(data.get("availableRelationshipCounts"))
if counts is not None:
payload["availableRelationshipCounts"] = counts
payload["hint"] = _build_relationship_hint(found, mcp_profile, groups, counts)
payload["hint"] = _build_relationship_hint(found, mcp_profile, groups, counts, data_source)
else:
payload["hint"] = _build_relationship_hint(found, mcp_profile, [], None)
payload["hint"] = _build_relationship_hint(found, mcp_profile, [], None, data_source)

return payload

Expand Down Expand Up @@ -266,9 +284,19 @@ def _build_relationship_hint(
profile: str,
groups: List[Dict[str, Any]],
counts: Dict[str, int] | None,
data_source: Optional[str] = None,
) -> str:
"""Give model-facing next-step guidance for graph traversal results."""
if not found:
if data_source:
return (
f'No relationship data was found for this identifier in data source "{data_source}". '
"The identifier may belong to a different data source, or the data_source value may be "
"wrong. Try: re-run with data_source set to a different candidate (use the `dataSource` "
"name or id from your search results, or call get_data_sources), or omit data_source "
"entirely — if the identifier is ambiguous you then get a 409 listing the candidate data "
"sources. Otherwise re-run semantic_search or grep_search to get a fresh identifier."
)
return (
"No relationship data was found for this identifier. Verify that the identifier came from "
"a recent search/fetch result and points to a symbol-level artifact; otherwise re-run "
Expand Down
Loading
Loading