diff --git a/pyproject.toml b/pyproject.toml
index fbbd8d82b7a0..933e51759718 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -145,7 +145,13 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
 elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
 exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
 excel = ["xlrd>=1.2.0, <1.3"]
-fastmcp = ["fastmcp>=3.2.4,<4.0"]
+fastmcp = [
+    "fastmcp>=3.2.4,<4.0",
+    # tiktoken backs the response-size-guard token estimator. Without
+    # it, the middleware falls back to a coarser character-based
+    # heuristic that is less accurate on JSON-heavy MCP responses.
+    "tiktoken>=0.7.0,<1.0",
+]
 firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
 firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
 gevent = ["gevent>=23.9.1"]
diff --git a/requirements/base.txt b/requirements/base.txt
index 3c0575c1f4a4..9c99184f9202 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -183,7 +183,9 @@ idna==3.10
     #   trio
     #   url-normalize
 isodate==0.7.2
-    # via apache-superset (pyproject.toml)
+    # via
+    #   apache-superset (pyproject.toml)
+    #   apache-superset-core
 itsdangerous==2.2.0
     # via
     #   flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
     # via
     #   -r requirements/base.in
     #   apache-superset (pyproject.toml)
+    #   apache-superset-core
 pyasn1==0.6.3
     # via
     #   pyasn1-modules
diff --git a/requirements/development.txt b/requirements/development.txt
index 28219acfd9c8..b7664ddae6b4 100644
--- a/requirements/development.txt
+++ b/requirements/development.txt
@@ -442,6 +442,7 @@ isodate==0.7.2
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+    #   apache-superset-core
 isort==6.0.1
     # via pylint
 itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+    #   apache-superset-core
     #   db-dtypes
     #   pandas-gbq
 pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
     #   jsonschema
     #   jsonschema-path
     #   jsonschema-specifications
+regex==2026.4.4
+    # via tiktoken
 requests==2.33.0
     # via
     #   -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
     #   requests-cache
     #   requests-oauthlib
     #   shillelagh
+    #   tiktoken
     #   trino
 requests-cache==1.2.1
     # via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+tiktoken==0.12.0
+    # via apache-superset
 tomli-w==1.2.0
     # via apache-superset-extensions-cli
 tomlkit==0.13.3
diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py
index de592a3da024..64a5145d7fc5 100644
--- a/superset/mcp_service/middleware.py
+++ b/superset/mcp_service/middleware.py
@@ -41,6 +41,12 @@
     DEFAULT_TOKEN_LIMIT,
     DEFAULT_WARN_THRESHOLD_PCT,
 )
+from superset.mcp_service.utils.token_utils import (
+    estimate_response_tokens,
+    format_size_limit_error,
+    INFO_TOOLS,
+    truncate_oversized_response,
+)
 from superset.utils.core import get_user_id
 
 logger = logging.getLogger(__name__)
@@ -1104,11 +1110,6 @@ def _try_truncate_info_response(
         ``content[0].text`` as a JSON string. We parse that string, run the
         truncation phases on the resulting dict, then re-wrap the result.
""" - from superset.mcp_service.utils.token_utils import ( - estimate_response_tokens, - truncate_oversized_response, - ) - # Unwrap ToolResult so truncation operates on the real payload extracted = self._extract_payload_from_tool_result(response) if extracted is not None: @@ -1191,12 +1192,6 @@ async def on_call_tool( # Execute the tool response = await call_next(context) - # Estimate response token count (guard against huge responses causing OOM) - from superset.mcp_service.utils.token_utils import ( - estimate_response_tokens, - format_size_limit_error, - ) - # When the response is a ToolResult, estimate tokens on the actual # payload inside content[0].text rather than on the ToolResult # wrapper (which would double-serialize the JSON string). @@ -1233,8 +1228,6 @@ async def on_call_tool( params = getattr(context.message, "params", {}) or {} # For info tools, try dynamic truncation before blocking - from superset.mcp_service.utils.token_utils import INFO_TOOLS - if tool_name in INFO_TOOLS: truncated = self._try_truncate_info_response( tool_name, response, estimated_tokens diff --git a/superset/mcp_service/utils/token_utils.py b/superset/mcp_service/utils/token_utils.py index 00e6664e729a..b14c8f5a5641 100644 --- a/superset/mcp_service/utils/token_utils.py +++ b/superset/mcp_service/utils/token_utils.py @@ -21,6 +21,26 @@ This module provides utilities to estimate token counts and generate smart suggestions when responses exceed configured limits. This prevents large responses from overwhelming LLM clients like Claude Desktop. + +Token counting strategy: + +1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is + installed (it is shipped as part of the ``fastmcp`` extra). This is a + real BPE tokenizer trained on a similar vocabulary to Claude's; for + English and JSON-heavy MCP payloads it tracks Claude's tokenizer + within roughly ±10%, which is far more accurate than the legacy + character heuristic. +2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not + importable. The fallback uses a slightly more conservative ratio than + before (3.0 chars/token instead of 3.5) so that JSON-heavy responses + are not under-counted, which previously let oversized payloads slip + past the response-size guard. + +The exact-Claude tokenizer is only available via Anthropic's network +``count_tokens`` API; calling it from a synchronous middleware on every +tool result is too slow and adds an external dependency on every +response. ``tiktoken`` is the closest approximation we can ship without +that risk. """ from __future__ import annotations @@ -36,18 +56,63 @@ # Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes) ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes] -# Approximate characters per token for estimation -# Claude tokenizer averages ~4 chars per token for English text -# JSON tends to be more verbose, so we use a slightly lower ratio -CHARS_PER_TOKEN = 3.5 +# Fallback character-to-token ratio used when tiktoken is unavailable. +# 3.0 is conservative for JSON content (the previous 3.5 under-counted +# JSON-heavy payloads relative to Claude's actual tokenizer, which let +# oversized responses slip past the response-size guard). +CHARS_PER_TOKEN = 3.0 + +# Encoding used when tiktoken is available. cl100k_base is OpenAI's +# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to +# Claude's and tracks Claude's token counts within roughly ±10% for +# English and JSON-heavy MCP responses. 
+_TIKTOKEN_ENCODING_NAME = "cl100k_base"
+
+
+def _load_tiktoken_encoding() -> Any:
+    """Return a tiktoken encoding instance, or None if tiktoken is unavailable.
+
+    The import is guarded so that a missing tiktoken degrades to the
+    char-based fallback instead of breaking module import. The encoding
+    is small (~1 MB) and is cached once in the module-level ``_ENCODING``.
+    """
+    try:
+        import tiktoken
+    except ImportError:
+        logger.info(
+            "tiktoken not installed; falling back to char-based token "
+            "estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
+            "for accurate counts.",
+            CHARS_PER_TOKEN,
+        )
+        return None
+
+    try:
+        return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
+    except (KeyError, ValueError) as exc:
+        # tiktoken is installed but the requested encoding is missing,
+        # which only happens on partial installs. Treat this as "no
+        # tokenizer" rather than crashing on every tool call.
+        logger.warning(
+            "tiktoken encoding '%s' unavailable: %s; falling back to "
+            "char-based token estimation",
+            _TIKTOKEN_ENCODING_NAME,
+            exc,
+        )
+        return None
+
+
+# Cached encoding instance (None if tiktoken not importable).
+_ENCODING = _load_tiktoken_encoding()
 
 
 def estimate_token_count(text: str | bytes) -> int:
     """
     Estimate the token count for a given text.
 
-    Uses a character-based heuristic since we don't have direct access to
-    the actual tokenizer. This is conservative to avoid underestimating.
+    Uses tiktoken's ``cl100k_base`` encoding when available for
+    Claude-aligned accuracy (within ~10%), falling back to a
+    character-based heuristic otherwise.
 
     Args:
         text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
     if isinstance(text, bytes):
         text = text.decode("utf-8", errors="replace")
 
-    # Simple heuristic: ~3.5 characters per token for JSON/code
-    text_length = len(text)
-    if text_length == 0:
+    if not text:
         return 0
-    return max(1, int(text_length / CHARS_PER_TOKEN))
+
+    if _ENCODING is not None:
+        try:
+            return len(_ENCODING.encode(text))
+        except (ValueError, UnicodeError) as exc:
+            # Defensive: tiktoken raises ValueError on e.g. disallowed
+            # special-token markers; fall back to the char heuristic for
+            # this call instead of raising, so the size guard keeps working.
+            logger.warning("tiktoken encode failed (%s); using fallback", exc)
+
+    return max(1, int(len(text) / CHARS_PER_TOKEN))
 
 
 def estimate_response_tokens(response: ToolResponse) -> int:
diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py
index e95c0d872205..948ba2547cb2 100644
--- a/tests/unit_tests/mcp_service/test_middleware.py
+++ b/tests/unit_tests/mcp_service/test_middleware.py
@@ -146,7 +146,13 @@ async def test_skips_excluded_tools(self) -> None:
 
     @pytest.mark.asyncio
     async def test_logs_warning_at_threshold(self) -> None:
-        """Should log warning when approaching limit."""
+        """Should log warning when approaching limit.
+
+        Mocks the token estimator to return a specific value above the
+        warn threshold but below the hard limit, decoupling the test
+        from whichever tokenizer (tiktoken or char heuristic) happens
+        to be loaded.
+ """ middleware = ResponseSizeGuardMiddleware( token_limit=1000, warn_threshold_pct=80 ) @@ -155,18 +161,21 @@ async def test_logs_warning_at_threshold(self) -> None: context.message.name = "list_charts" context.message.params = {} - # Response at ~85% of limit (should trigger warning but not block) - response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token + response = {"data": "approaching the limit"} call_next = AsyncMock(return_value=response) with ( patch("superset.mcp_service.middleware.get_user_id", return_value=1), patch("superset.mcp_service.middleware.event_logger"), + patch( + "superset.mcp_service.middleware.estimate_response_tokens", + return_value=850, + ), patch("superset.mcp_service.middleware.logger") as mock_logger, ): result = await middleware.on_call_tool(context, call_next) - # Should return response (not blocked) + # Should return response (not blocked at 85% of limit) assert result == response # Should log warning mock_logger.warning.assert_called() diff --git a/tests/unit_tests/mcp_service/utils/test_token_utils.py b/tests/unit_tests/mcp_service/utils/test_token_utils.py index 9a49264bd935..4254bd9f539e 100644 --- a/tests/unit_tests/mcp_service/utils/test_token_utils.py +++ b/tests/unit_tests/mcp_service/utils/test_token_utils.py @@ -20,9 +20,11 @@ """ from typing import Any, List +from unittest.mock import patch from pydantic import BaseModel +from superset.mcp_service.utils import token_utils from superset.mcp_service.utils.token_utils import ( _replace_collections_with_summaries, _summarize_large_dicts, @@ -45,29 +47,65 @@ class TestEstimateTokenCount: """Test estimate_token_count function.""" def test_estimate_string(self) -> None: - """Should estimate tokens for a string.""" + """Should produce a positive non-zero estimate for a normal string. + + We don't assert on a specific number because the result depends on + which tokenizer is loaded (tiktoken when available, char heuristic + otherwise). + """ text = "Hello world" result = estimate_token_count(text) - expected = int(len(text) / CHARS_PER_TOKEN) - assert result == expected + assert result > 0 def test_estimate_bytes(self) -> None: - """Should estimate tokens for bytes.""" - text = b"Hello world" - result = estimate_token_count(text) - expected = int(len(text) / CHARS_PER_TOKEN) - assert result == expected + """Bytes input should be decoded and produce the same count as the + equivalent string.""" + text = "Hello world" + assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text) def test_empty_string(self) -> None: - """Should return 0 for empty string.""" + """Should return 0 for empty string and empty bytes.""" assert estimate_token_count("") == 0 + assert estimate_token_count(b"") == 0 def test_json_like_content(self) -> None: - """Should estimate tokens for JSON-like content.""" + """JSON content should produce a positive estimate.""" json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}' - result = estimate_token_count(json_str) - assert result > 0 - assert result == int(len(json_str) / CHARS_PER_TOKEN) + assert estimate_token_count(json_str) > 0 + + def test_long_text_roughly_scales_with_length(self) -> None: + """A doubled string should produce roughly double the token count + (within ±10%).""" + small = "the quick brown fox jumps over the lazy dog. " * 20 + large = small * 2 + small_n = estimate_token_count(small) + large_n = estimate_token_count(large) + # Within 10% of 2x — both tokenizers (tiktoken and the char + # fallback) preserve length monotonicity. 
+        assert 1.8 * small_n <= large_n <= 2.2 * small_n
+
+    def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
+        self,
+    ) -> None:
+        """When the tiktoken encoding is None (not installed), the
+        function falls back to len/CHARS_PER_TOKEN math."""
+        text = "x" * 100
+        with patch.object(token_utils, "_ENCODING", None):
+            result = estimate_token_count(text)
+        assert result == int(100 / CHARS_PER_TOKEN)
+
+    def test_fallback_when_tiktoken_encode_raises(self) -> None:
+        """A misbehaving encoding should fall back to the char heuristic
+        rather than raise, so a tokenizer bug cannot disable the guard."""
+
+        class BoomEncoding:
+            def encode(self, text: str) -> list[int]:
+                raise ValueError("simulated tiktoken failure")
+
+        text = "abc" * 50
+        with patch.object(token_utils, "_ENCODING", BoomEncoding()):
+            result = estimate_token_count(text)
+        assert result == int(len(text) / CHARS_PER_TOKEN)
 
 
 class TestEstimateResponseTokens:
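
Reviewer note: for anyone who wants to sanity-check the tokenizer change locally, a minimal sketch follows (illustrative only, not part of the diff; the script name and payload shape are invented). It compares the new 3.0 chars/token fallback against a real cl100k_base count on a JSON-heavy payload:

    # compare_token_estimates.py -- illustrative sketch, not part of the PR.
    import json

    CHARS_PER_TOKEN = 3.0  # the fallback ratio introduced by this PR

    # Invented payload shaped like a chart-listing MCP response.
    payload = json.dumps(
        {
            "charts": [
                {"id": i, "slice_name": f"chart-{i}", "viz_type": "table"}
                for i in range(200)
            ]
        }
    )

    # Char-based fallback estimate (what runs when tiktoken is absent).
    heuristic = max(1, int(len(payload) / CHARS_PER_TOKEN))
    print(f"chars={len(payload)} heuristic_tokens={heuristic}")

    # Real BPE count (what runs when the 'fastmcp' extra is installed).
    try:
        import tiktoken

        enc = tiktoken.get_encoding("cl100k_base")
        print(f"cl100k_base_tokens={len(enc.encode(payload))}")
    except ImportError:
        print("tiktoken not installed; the heuristic is the operative estimate")

On dense JSON like this, the BPE count typically sits above what a 3.5 chars/token estimate would report, which is the under-counting the tighter 3.0 fallback ratio guards against when tiktoken is missing.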