From dd45b6b1e244c6d80f107452a6e0d0b0aecb01c3 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 14:34:46 +0530 Subject: [PATCH 01/21] scaffold oauthbearer-aws extra + subpackage placeholders --- pyproject.toml | 10 ++++--- requirements/requirements-oauthbearer-aws.txt | 1 + src/confluent_kafka/_util/kv_string_parser.py | 18 +++++++++++++ src/confluent_kafka/oauthbearer/__init__.py | 19 ++++++++++++++ .../oauthbearer/aws/__init__.py | 26 +++++++++++++++++++ .../oauthbearer/aws/_aws_iam_marker.py | 23 ++++++++++++++++ .../aws/_aws_jwt_subject_extractor.py | 20 ++++++++++++++ .../aws/_aws_oauthbearer_config.py | 20 ++++++++++++++ .../aws/_aws_sasl_extensions_parser.py | 20 ++++++++++++++ .../aws/_aws_sts_token_provider.py | 24 +++++++++++++++++ .../oauthbearer/aws/aws_autowire.py | 24 +++++++++++++++++ tests/oauthbearer/__init__.py | 0 tests/oauthbearer/aws/__init__.py | 0 13 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 requirements/requirements-oauthbearer-aws.txt create mode 100644 src/confluent_kafka/_util/kv_string_parser.py create mode 100644 src/confluent_kafka/oauthbearer/__init__.py create mode 100644 src/confluent_kafka/oauthbearer/aws/__init__.py create mode 100644 src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py create mode 100644 src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py create mode 100644 src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py create mode 100644 src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py create mode 100644 src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py create mode 100644 src/confluent_kafka/oauthbearer/aws/aws_autowire.py create mode 100644 tests/oauthbearer/__init__.py create mode 100644 tests/oauthbearer/aws/__init__.py diff --git a/pyproject.toml b/pyproject.toml index f89aa33c7..da7de97fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ optional-dependencies.rules = { file = ["requirements/requirements-rules.txt", " optional-dependencies.avro = { file = ["requirements/requirements-avro.txt", "requirements/requirements-schemaregistry.txt"] } optional-dependencies.json = { file = ["requirements/requirements-json.txt", "requirements/requirements-schemaregistry.txt"] } optional-dependencies.protobuf = { file = ["requirements/requirements-protobuf.txt", "requirements/requirements-schemaregistry.txt"] } +optional-dependencies.oauthbearer-aws = { file = ["requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.dev = { file = [ "requirements/requirements-docs.txt", "requirements/requirements-examples.txt", @@ -114,7 +115,8 @@ optional-dependencies.dev = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.docs = { file = [ "requirements/requirements-docs.txt", "requirements/requirements-schemaregistry.txt", @@ -128,7 +130,8 @@ optional-dependencies.tests = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } optional-dependencies.examples = { file = ["requirements/requirements-examples.txt"] } optional-dependencies.soaktest = { file = ["requirements/requirements-soaktest.txt"] } optional-dependencies.all = { file = [ @@ -140,7 +143,8 @@ optional-dependencies.all = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/requirements/requirements-oauthbearer-aws.txt b/requirements/requirements-oauthbearer-aws.txt new file mode 100644 index 000000000..3c332eafc --- /dev/null +++ b/requirements/requirements-oauthbearer-aws.txt @@ -0,0 +1 @@ +boto3>=1.42.25 diff --git a/src/confluent_kafka/_util/kv_string_parser.py b/src/confluent_kafka/_util/kv_string_parser.py new file mode 100644 index 000000000..f8e15c64a --- /dev/null +++ b/src/confluent_kafka/_util/kv_string_parser.py @@ -0,0 +1,18 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic key=value string parser shared across internal features. + +Placeholder for Phase 1; implementation lands in Phase 2. +""" diff --git a/src/confluent_kafka/oauthbearer/__init__.py b/src/confluent_kafka/oauthbearer/__init__.py new file mode 100644 index 000000000..0f4eb3e04 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Namespace package for OAUTHBEARER provider integrations. + +Cloud-specific subpackages (e.g. :mod:`confluent_kafka.oauthbearer.aws`) live +under here and are gated by their own PEP 621 extras. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/__init__.py b/src/confluent_kafka/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..48ecf1f87 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS IAM OAUTHBEARER autowire subpackage. + +The only publicly importable name in this subpackage is +:func:`confluent_kafka.oauthbearer.aws.aws_autowire.create_handler`, loaded by +core's C extension when the user sets +``sasl.oauthbearer.metadata.authentication.type=aws_iam``. All other modules +are private (underscore-prefixed) and not re-exported here. + +Install with:: + + pip install 'confluent-kafka[oauthbearer-aws]' +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py new file mode 100644 index 000000000..b9721453e --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py @@ -0,0 +1,23 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Marker constants identifying the AWS IAM OAUTHBEARER autowire path. + +Placeholder for Phase 1; constants land in Phase 2. + +The C dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` keeps its +own literal copies of these values for compile-time use; a drift-guard test +in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the two stay in +sync. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py new file mode 100644 index 000000000..4a382e329 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py @@ -0,0 +1,20 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: extracts the ``sub`` claim from an unverified AWS-minted JWT. + +Placeholder for Phase 1; implementation lands in Phase 2. + +Mirrors .NET's ``AwsJwtSubjectExtractor.cs``. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py new file mode 100644 index 000000000..e75eda523 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -0,0 +1,20 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: validated ``sasl.oauthbearer.config`` dataclass + parser. + +Placeholder for Phase 1; implementation lands in Phase 3. + +Mirrors .NET's ``AwsOAuthBearerConfig.cs`` (dataclass + ``Parse`` classmethod). +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py new file mode 100644 index 000000000..36813a5a8 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py @@ -0,0 +1,20 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: parser for the ``sasl.oauthbearer.extensions`` config property. + +Placeholder for Phase 1; implementation lands in Phase 2. + +Mirrors .NET's ``AwsSaslExtensionsParser.cs``. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py new file mode 100644 index 000000000..d4ce2f58f --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -0,0 +1,24 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal: AWS STS ``GetWebIdentityToken``-based OAUTHBEARER token provider. + +Placeholder for Phase 1; implementation lands in Phase 3. + +Mirrors .NET's ``AwsStsTokenProvider.cs``. The top-level ``import boto3`` that +will land in Phase 3 is the gate that makes this whole subpackage opt-in via +the ``oauthbearer-aws`` PEP 621 extra: opt-out users have no ``boto3`` +installed, so any attempt to import this module raises ``ModuleNotFoundError`` +which core's C dispatcher rewrites into a friendly install hint. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py new file mode 100644 index 000000000..8b7060aa7 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -0,0 +1,24 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Frozen public entry-point for AWS IAM OAUTHBEARER autowire. + +Placeholder for Phase 1; ``create_handler`` lands in Phase 4. + +Mirrors .NET's ``AwsAutoWire.cs``. The C dispatcher in +``src/confluent_kafka/src/confluent_kafka.c`` will reach this module via +``PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire")``. +The function signature of ``create_handler`` is part of the cross-module +ABI and bumping it is a major-version change on ``confluent-kafka``. +""" diff --git a/tests/oauthbearer/__init__.py b/tests/oauthbearer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/oauthbearer/aws/__init__.py b/tests/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..e69de29bb From 8a9610ee5b211684951d2e0f58b5b812d60d33da Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 15:27:55 +0530 Subject: [PATCH 02/21] Implement KvStringParser + marker constants + extensions parser + JWT subject extractor --- src/confluent_kafka/_util/kv_string_parser.py | 69 ++++- .../oauthbearer/aws/_aws_iam_marker.py | 23 +- .../aws/_aws_jwt_subject_extractor.py | 101 +++++++- .../aws/_aws_sasl_extensions_parser.py | 39 ++- tests/_util/__init__.py | 0 tests/_util/test_kv_string_parser.py | 134 ++++++++++ tests/oauthbearer/aws/test_aws_iam_marker.py | 42 +++ .../aws/test_aws_jwt_subject_extractor.py | 244 ++++++++++++++++++ .../aws/test_aws_sasl_extensions_parser.py | 90 +++++++ 9 files changed, 730 insertions(+), 12 deletions(-) create mode 100644 tests/_util/__init__.py create mode 100644 tests/_util/test_kv_string_parser.py create mode 100644 tests/oauthbearer/aws/test_aws_iam_marker.py create mode 100644 tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py create mode 100644 tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py diff --git a/src/confluent_kafka/_util/kv_string_parser.py b/src/confluent_kafka/_util/kv_string_parser.py index f8e15c64a..2c84ba86e 100644 --- a/src/confluent_kafka/_util/kv_string_parser.py +++ b/src/confluent_kafka/_util/kv_string_parser.py @@ -12,7 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generic key=value string parser shared across internal features. +"""Shared utility for parsing ``key=value`` strings.""" -Placeholder for Phase 1; implementation lands in Phase 2. -""" +import re +from typing import Iterable, Iterator, Optional, Tuple + +__all__ = ["parse_kv"] + + +def parse_kv( + raw: str, + separators: Iterable[str], + context_label: Optional[str] = None, + trim_tokens: bool = True, +) -> Iterator[Tuple[str, str]]: + """Tokenize ``raw`` and yield each non-empty token as a ``(key, value)`` pair. + + Tokens are split on any character in ``separators`` (e.g. ``[',']`` for + comma-separated values or ``[' ', '\\t', '\\r', '\\n']`` for + whitespace-separated). Within each token the split is on the first ``=`` + only — values may legitimately contain ``=`` (e.g. URL query strings). + + Empty tokens (e.g. consecutive separators, or whitespace-only tokens when + ``trim_tokens`` is true) are skipped. Tokens with no ``=`` or with ``=`` + at position 0 (empty key) raise :class:`ValueError` with + ``context_label`` woven into the message when supplied. + + The default trimming behaviour mirrors librdkafka's + ``rd_string_split`` (``rdstring.c``). + + :param raw: Input string to tokenize. + :param separators: Iterable of single-character separators. Each element + may be a string of any length, but only its first character is used. + :param context_label: Optional label woven into error messages to + identify which config the malformed token came from. When ``None``, + error messages fall back to a generic ``"key=value"`` phrasing. + :param trim_tokens: When true (default), each token is stripped of + leading and trailing whitespace before being split on ``=``. Set to + false to preserve whitespace inside tokens. + :raises TypeError: ``raw`` or ``separators`` is ``None``. + :raises ValueError: A token is malformed (no ``=`` or empty key). + """ + if raw is None: + raise TypeError("raw must not be None") + if separators is None: + raise TypeError("separators must not be None") + + # Build a single-character split pattern from the supplied separators. + chars = "".join(str(s)[0] for s in separators if s) + if not chars: + # Degenerate "no separators" → treat raw as a single token. + raw_tokens = [raw] + else: + raw_tokens = re.split("[" + re.escape(chars) + "]", raw) + + for raw_token in raw_tokens: + token = raw_token.strip() if trim_tokens else raw_token + if len(token) == 0: + continue + + idx = token.find("=") + if idx <= 0: + what = f"'{context_label}'" if context_label else "key=value" + raise ValueError( + f"Malformed {what} entry '{token}' (expected key=value)." + ) + + yield token[:idx], token[idx + 1:] diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py index b9721453e..5bd95ac95 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py @@ -14,10 +14,23 @@ """Marker constants identifying the AWS IAM OAUTHBEARER autowire path. -Placeholder for Phase 1; constants land in Phase 2. - +Mirrors .NET's ``Confluent.Kafka.Internal.OAuthBearer.Aws.AwsIamMarker``. The C dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` keeps its -own literal copies of these values for compile-time use; a drift-guard test -in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the two stay in -sync. +own literal copies of these values for compile-time use; the drift-guard +test in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the C-side +literals and these Python constants stay in lock-step. + +These strings are part of the cross-language wire contract — bumping either +is a major version change on ``confluent-kafka``. """ + +__all__ = ["AWS_IAM_MARKER_KEY", "AWS_IAM_MARKER_VALUE"] + + +#: librdkafka config key that activates the AWS IAM autowire path when set +#: to :data:`AWS_IAM_MARKER_VALUE`. See the README for the full user contract. +AWS_IAM_MARKER_KEY: str = "sasl.oauthbearer.metadata.authentication.type" + +#: On-wire value of :data:`AWS_IAM_MARKER_KEY` that selects AWS IAM +#: authentication. +AWS_IAM_MARKER_VALUE: str = "aws_iam" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py index 4a382e329..7c268110b 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py @@ -14,7 +14,104 @@ """Internal: extracts the ``sub`` claim from an unverified AWS-minted JWT. -Placeholder for Phase 1; implementation lands in Phase 2. +Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsJwtSubjectExtractor``. -Mirrors .NET's ``AwsJwtSubjectExtractor.cs``. +No signature verification — AWS STS already signed the JWT and the broker +performs the cryptographic validation. We only decode the unprotected +payload segment to read the ``sub`` claim (the role ARN), which becomes the +``principal`` field handed to ``rd_kafka_oauthbearer_set_token``. + +Strict base64 decoding (``validate=True``) is used so stray non-alphabet +characters raise instead of being silently dropped — ``urlsafe_b64decode`` +is lenient and would let malformed payloads slip past JSON parsing. """ + +import base64 +import binascii +import json + +__all__ = ["extract_sub"] + + +_MAX_TOKEN_LENGTH_CHARS: int = 8192 + + +def extract_sub(jwt: str) -> str: + """Return the ``sub`` claim from the JWT payload. + + :raises ValueError: ``jwt`` is null, empty, oversized, has the wrong + segment count, fails base64url decoding, isn't valid JSON, isn't a + JSON object, or has no ``sub`` string claim (or its value is empty). + """ + if jwt is None: + raise ValueError("JWT is null.") + if jwt == "": + raise ValueError("JWT is empty.") + if len(jwt) > _MAX_TOKEN_LENGTH_CHARS: + raise ValueError( + f"JWT length {len(jwt)} exceeds maximum allowed " + f"({_MAX_TOKEN_LENGTH_CHARS})." + ) + + parts = jwt.split(".") + if len(parts) != 3: + raise ValueError( + f"JWT must have exactly 3 '.'-separated segments; got {len(parts)}." + ) + + payload_bytes = _decode_base64url_segment(parts[1]) + try: + payload_string = payload_bytes.decode("utf-8") + except UnicodeDecodeError as exc: + raise ValueError( + f"JWT payload is not valid UTF-8: {exc}" + ) from exc + + try: + token = json.loads(payload_string) + except json.JSONDecodeError as exc: + raise ValueError( + f"JWT payload is not valid JSON: {exc}" + ) from exc + + if not isinstance(token, dict): + raise ValueError("JWT payload is not a JSON object.") + + if "sub" not in token: + raise ValueError("JWT payload is missing a 'sub' string claim.") + sub = token["sub"] + if not isinstance(sub, str): + raise ValueError("JWT payload is missing a 'sub' string claim.") + if sub == "": + raise ValueError("JWT 'sub' claim value is empty.") + return sub + + +def _decode_base64url_segment(segment: str) -> bytes: + """Base64url-decode a JWT segment to bytes. + + Restores '=' padding to the next 4-byte boundary, swaps '-' → '+' and + '_' → '/', then defers to :func:`base64.b64decode` with strict + ``validate=True``. + """ + if len(segment) == 0: + raise ValueError("JWT payload segment is empty.") + + s = segment.replace("-", "+").replace("_", "/") + remainder = len(s) % 4 + if remainder == 0: + pass + elif remainder == 2: + s += "==" + elif remainder == 3: + s += "=" + else: + # remainder == 1 → not a valid base64url length + raise ValueError("JWT payload segment has invalid base64url length.") + + try: + return base64.b64decode(s.encode("ascii"), validate=True) + except (binascii.Error, UnicodeEncodeError, ValueError) as exc: + raise ValueError( + f"JWT payload segment is not valid base64url: {exc}" + ) from exc diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py index 36813a5a8..fc8514f5a 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py @@ -14,7 +14,42 @@ """Internal: parser for the ``sasl.oauthbearer.extensions`` config property. -Placeholder for Phase 1; implementation lands in Phase 2. +Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsSaslExtensionsParser``. -Mirrors .NET's ``AwsSaslExtensionsParser.cs``. +The ``sasl.oauthbearer.extensions`` config carries RFC 7628 §3.1 SASL +extensions as a comma-separated ``key=value`` list. Forwarded verbatim to +the broker alongside the JWT — not part of the JWT itself. """ + +from typing import Dict, Optional + +from confluent_kafka._util.kv_string_parser import parse_kv + +__all__ = ["CONFIG_KEY", "parse"] + + +#: Top-level librdkafka config key carrying the SASL extensions list. +CONFIG_KEY: str = "sasl.oauthbearer.extensions" + + +def parse(raw: Optional[str]) -> Optional[Dict[str, str]]: + """Parse the verbatim ``sasl.oauthbearer.extensions`` value into a dict. + + Grammar: comma-separated ``key=value`` tokens. Whitespace around each + token is trimmed (mirrors .NET / librdkafka). Empty tokens (e.g. + ``"a=1,,b=2,"``) are tolerated and skipped. + + Returns ``None`` for ``None`` / empty input so the autowire layer can + short-circuit without constructing an empty dict. + + :raises ValueError: A token is missing ``=`` or has an empty key. + """ + if raw is None or raw == "": + return None + + result: Dict[str, str] = {} + for key, value in parse_kv(raw, separators=[","], context_label=CONFIG_KEY): + # Last-wins on duplicate keys, mirroring librdkafka. + result[key] = value + + return result if result else None diff --git a/tests/_util/__init__.py b/tests/_util/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/_util/test_kv_string_parser.py b/tests/_util/test_kv_string_parser.py new file mode 100644 index 000000000..f11d99c84 --- /dev/null +++ b/tests/_util/test_kv_string_parser.py @@ -0,0 +1,134 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka._util.kv_string_parser. + +Mirrors .NET's KvStringParserTests. +""" + +import pytest + +from confluent_kafka._util.kv_string_parser import parse_kv + + +# ---- Null guards ---- + + +def test_null_raw_raises(): + with pytest.raises(TypeError): + list(parse_kv(None, [","])) + + +def test_null_separators_raises(): + with pytest.raises(TypeError): + list(parse_kv("a=1", None)) + + +# ---- Empty / whitespace-only input ---- + + +def test_empty_raw_yields_nothing(): + assert list(parse_kv("", [","])) == [] + + +def test_whitespace_only_raw_yields_nothing(): + # After split + trim, no real tokens remain. + assert list(parse_kv(" ", [","])) == [] + + +# ---- Single separator ---- + + +def test_single_separator_comma_basic_case(): + pairs = list(parse_kv("a=1,b=2", [","])) + assert pairs == [("a", "1"), ("b", "2")] + + +# ---- Multiple separators (whitespace-tolerant case) ---- + + +def test_multiple_separators_all_recognized(): + pairs = list(parse_kv("a=1\tb=2\nc=3 d=4", [" ", "\t", "\r", "\n"])) + assert [k for k, _ in pairs] == ["a", "b", "c", "d"] + + +# ---- Empty tokens skipped ---- + + +def test_consecutive_separators_skip_empty_tokens(): + pairs = list(parse_kv("a=1,,b=2,", [","])) + assert len(pairs) == 2 + assert pairs == [("a", "1"), ("b", "2")] + + +# ---- Trim toggle ---- + + +def test_trim_tokens_default_true_strips_leading_and_trailing(): + pairs = list(parse_kv(" a=1 , b=2 ", [","])) + assert pairs == [("a", "1"), ("b", "2")] + + +def test_trim_tokens_false_preserves_whitespace_inside_tokens(): + pairs = list(parse_kv(" a=1 ,b=2 ", [","], trim_tokens=False)) + assert pairs == [(" a", "1 "), ("b", "2 ")] + + +# ---- Malformed tokens ---- + + +def test_token_without_equals_raises(): + with pytest.raises(ValueError, match="Malformed.*bad"): + list(parse_kv("a=1,bad,c=3", [","])) + + +def test_token_starts_with_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + list(parse_kv("=value", [","])) + + +# ---- Value-content semantics ---- + + +def test_value_contains_equals_kept_verbatim(): + # Split is on FIRST '=' only — the rest stays in the value. + pairs = list(parse_kv("key=val=ue", [","])) + assert pairs == [("key", "val=ue")] + + +def test_empty_value_allowed(): + # RFC 7628 SASL extensions allow empty values; AWS allows empty tag values. + pairs = list(parse_kv("key=", [","])) + assert pairs == [("key", "")] + + +# ---- Duplicate keys preserved (caller's responsibility to dedupe) ---- + + +def test_duplicate_keys_both_yielded(): + pairs = list(parse_kv("a=1,a=2", [","])) + assert pairs == [("a", "1"), ("a", "2")] + + +# ---- context_label weaves into error messages ---- + + +def test_context_label_present_appears_in_error_message(): + with pytest.raises(ValueError, match="my.config.key"): + list(parse_kv("bad", [","], context_label="my.config.key")) + + +def test_context_label_none_uses_generic_error_phrase(): + with pytest.raises(ValueError, match="key=value"): + list(parse_kv("bad", [","])) diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py new file mode 100644 index 000000000..bbf6fd705 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -0,0 +1,42 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Drift-guard tests for the AWS IAM marker constants. + +Phase 2 covers the Python-side half of the drift guard: the literal values +must not move. The end-to-end half (asserting the C dispatcher's literal +strings match these) lands in Phase 5 once the C dispatcher exists. +""" + +from confluent_kafka.oauthbearer.aws._aws_iam_marker import ( + AWS_IAM_MARKER_KEY, + AWS_IAM_MARKER_VALUE, +) + + +def test_marker_key_is_locked_value(): + """The marker key is part of the cross-language wire contract. + + Bumping it would silently break .NET / Go / JS / Python parity. Any change + here MUST be coordinated as a major version bump across all four clients. + """ + assert AWS_IAM_MARKER_KEY == "sasl.oauthbearer.metadata.authentication.type" + + +def test_marker_value_is_locked_value(): + """The marker value is part of the cross-language wire contract. + + Same constraint as :func:`test_marker_key_is_locked_value`. + """ + assert AWS_IAM_MARKER_VALUE == "aws_iam" diff --git a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py new file mode 100644 index 000000000..e6fad53eb --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py @@ -0,0 +1,244 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor. + +Mirrors .NET's AwsJwtSubjectExtractorTests. +""" + +import base64 + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor import extract_sub + + +# ---- Test helpers ---- + + +def _base64url_encode(data: bytes) -> str: + """Mirror of .NET's AwsTestHelpers.Base64UrlEncode.""" + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _make_jwt(payload_json: str, alg: str = "ES384") -> str: + """Build a 3-segment JWT with the given payload JSON. + + Header and signature segments are placeholders — ``extract_sub`` only + decodes the payload (``parts[1]``). + """ + header = _base64url_encode(f'{{"alg":"{alg}","typ":"JWT"}}'.encode("utf-8")) + payload = _base64url_encode(payload_json.encode("utf-8")) + return f"{header}.{payload}.sig" + + +# ---- Real STS payloads (from .NET test suite — signature stripped, only payload matters) ---- + + +_EXPECTED_ROLE_ARN = "arn:aws:iam::708975691912:role/prashah-iam-sts-test-role" + +_REAL_ES384_JWT = ( + "eyJraWQiOiJFQzM4NF8wIiwidHlwIjoiSldUIiwiYWxnIjoiRVMzODQifQ" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + "GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjFkM2QzZmMyLTBl" + "NzktNDI2OS05NjcxLTJmODQ4NDYxOWZiNyJ9" + ".sig" +) + +_REAL_RS256_JWT = ( + "eyJraWQiOiJSU0FfMCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + "GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjU5OTliZDc1LTBm" + "NTctNDMzMS1iOGExLWJkYzYwMjgxN2UyZiJ9" + ".sig" +) + + +# ---- Happy paths ---- + + +def test_extract_sub_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:iam::123456789012:role/MyRole","iat":1}') + assert extract_sub(jwt) == "arn:aws:iam::123456789012:role/MyRole" + + +def test_extract_sub_assumed_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:sts::123456789012:assumed-role/MyRole/session-name"}') + assert extract_sub(jwt) == "arn:aws:sts::123456789012:assumed-role/MyRole/session-name" + + +def test_extract_sub_other_claims_ignored(): + jwt = _make_jwt( + '{"iss":"https://x","sub":"arn:aws:iam::1:role/R",' + '"aud":"a","exp":1,"iat":0,"jti":"j"}' + ) + assert extract_sub(jwt) == "arn:aws:iam::1:role/R" + + +@pytest.mark.parametrize("jwt", [_REAL_ES384_JWT, _REAL_RS256_JWT]) +def test_extract_sub_real_sts_jwt_returns_expected_arn(jwt): + """Cross-language wire-shape parity with .NET / Go / JS / librdkafka.""" + assert extract_sub(jwt) == _EXPECTED_ROLE_ARN + + +# ---- Base64url padding branches (% 4 ∈ {0, 2, 3}) ---- + + +def test_extract_sub_unpadded_base64url_works(): + unpadded = _make_jwt('{"sub":"a"}') + assert extract_sub(unpadded) == "a" + + +def test_extract_sub_padded_base64url_also_works(): + """If user pre-pads with '=' (non-canonical but tolerated), still decodes.""" + header = _base64url_encode(b'{"alg":"none"}') + # Encode WITHOUT stripping padding. + payload = base64.b64encode(b'{"sub":"abc"}').decode("ascii").replace("+", "-").replace("/", "_") + jwt = f"{header}.{payload}." + # Trailing empty segment makes len(parts) != 3 if payload contains '=' — fix segment count. + # The .NET test uses a 3-segment shape with empty signature. Mirror that: + jwt_3seg = f"{header}.{payload}.sig" + assert extract_sub(jwt_3seg) == "abc" + + +def test_extract_sub_url_safe_chars_handled(): + bytes_ = b'{"sub":"x"}' + normal = base64.b64encode(bytes_).decode("ascii") + url_safe = normal.replace("+", "-").replace("/", "_").rstrip("=") + jwt = "aGVhZGVy." + url_safe + ".c2ln" + assert extract_sub(jwt) == "x" + + +@pytest.mark.parametrize("payload_json,expected_sub", [ + ('{"sub":"a"}', "a"), # → encoded length % 4 = 3 (one '=' padded) + ('{"sub":"ab"}', "ab"), # → encoded length % 4 = 0 (no padding) + ('{"sub":"abc"}', "abc"), # → encoded length % 4 = 2 (two '==' padded) +]) +def test_extract_sub_padding_branches_all_hit_decodes_correctly(payload_json, expected_sub): + """Explicit per-branch coverage of the % 4 padding switch.""" + assert extract_sub(_make_jwt(payload_json)) == expected_sub + + +# ---- Top-level shape errors ---- + + +def test_extract_sub_null_raises(): + with pytest.raises(ValueError, match="null"): + extract_sub(None) + + +def test_extract_sub_empty_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("") + + +def test_extract_sub_one_segment_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("onlyonepart") + + +def test_extract_sub_two_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b") + + +def test_extract_sub_four_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b.c.d") + + +def test_extract_sub_oversized_input_raises(): + oversized = "a" * 8193 + with pytest.raises(ValueError, match="exceeds maximum"): + extract_sub(oversized) + + +def test_extract_sub_at_ceiling_reaches_parser(): + """8192-char input passes the size gate and fails downstream on segment count.""" + at_ceiling = "a" * 8192 + with pytest.raises(ValueError, match="3"): + extract_sub(at_ceiling) + + +def test_extract_sub_empty_payload_segment_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("header..sig") + + +# ---- Payload-content errors ---- + + +def test_extract_sub_malformed_base64_in_payload_raises(): + with pytest.raises(ValueError, match="base64url"): + extract_sub("aGVhZGVy.not!base64.c2ln") + + +def test_extract_sub_malformed_json_in_payload_raises(): + bad_payload = _base64url_encode(b"not json") + with pytest.raises(ValueError, match="not valid JSON"): + extract_sub(f"aGVhZGVy.{bad_payload}.c2ln") + + +def test_extract_sub_payload_is_json_array_raises(): + array_payload = _base64url_encode(b'["not","an","object"]') + with pytest.raises(ValueError, match="not a JSON object"): + extract_sub(f"aGVhZGVy.{array_payload}.c2ln") + + +def test_extract_sub_missing_sub_claim_raises(): + jwt = _make_jwt('{"iss":"https://x","aud":"a"}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_number_raises(): + jwt = _make_jwt('{"sub":12345}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_null_raises(): + jwt = _make_jwt('{"sub":null}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_empty_string_raises(): + jwt = _make_jwt('{"sub":""}') + with pytest.raises(ValueError, match="empty"): + extract_sub(jwt) diff --git a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py new file mode 100644 index 000000000..9b1b46d49 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py @@ -0,0 +1,90 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser. + +Mirrors .NET's AwsSaslExtensionsParserTests. +""" + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser import parse + + +# ---- Null / empty input → None ---- + + +def test_null_raw_returns_none(): + assert parse(None) is None + + +def test_empty_string_returns_none(): + assert parse("") is None + + +# ---- Happy paths ---- + + +def test_single_entry_returns_one_item(): + result = parse("logicalCluster=lkc-abc") + assert result == {"logicalCluster": "lkc-abc"} + + +def test_multiple_entries_returns_all(): + result = parse("logicalCluster=lkc-abc,identityPoolId=pool-x") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_whitespace_around_commas_trimmed_and_parsed(): + # Each comma-delimited entry is trimmed before parsing. + result = parse(" logicalCluster=lkc-abc , identityPoolId=pool-x ") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_entries_tolerated(): + result = parse("logicalCluster=lkc-abc,,identityPoolId=pool-x,") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_value_accepted(): + # RFC 7628 SASL extensions allow empty values; mirror that. + result = parse("logicalCluster=") + assert result == {"logicalCluster": ""} + + +def test_duplicate_key_last_wins(): + result = parse("k=a,k=b") + assert result == {"k": "b"} + + +# ---- Malformed input → throws ---- + + +def test_missing_equals_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("noEqualsHere") + + +def test_empty_key_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("=value") From 82190640b3b69ea0d967ae94bfc4fc2636b2b7dd Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 17:15:20 +0530 Subject: [PATCH 03/21] Implement AwsOAuthBearerConfig + AwsStsTokenProvider --- .../aws/_aws_oauthbearer_config.py | 291 ++++++++++++- .../aws/_aws_sts_token_provider.py | 167 ++++++- .../aws/test_aws_oauthbearer_config.py | 408 ++++++++++++++++++ .../aws/test_aws_sts_token_provider.py | 406 +++++++++++++++++ 4 files changed, 1264 insertions(+), 8 deletions(-) create mode 100644 tests/oauthbearer/aws/test_aws_oauthbearer_config.py create mode 100644 tests/oauthbearer/aws/test_aws_sts_token_provider.py diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py index e75eda523..79dbf50cd 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -14,7 +14,294 @@ """Internal: validated ``sasl.oauthbearer.config`` dataclass + parser. -Placeholder for Phase 1; implementation lands in Phase 3. +Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsOAuthBearerConfig``. +Encapsulates the wire-grammar parser and the validated, immutable typed view +of the AWS path's ``sasl.oauthbearer.config`` value. -Mirrors .NET's ``AwsOAuthBearerConfig.cs`` (dataclass + ``Parse`` classmethod). +The full grammar (whitespace-separated ``key=value`` pairs, no quoting): + + region= (required) + audience= (required) + duration_seconds=<60..3600> (default: 300) + signing_algorithm=ES384|RS256 (default: ES384) + sts_endpoint= (optional, FIPS / VPC) + principal_name= (optional, override JWT 'sub') + aws_debug=none|console (default: none, Pythonic subset) + tag_= (zero or more JWT custom claims, max 50) + +SASL extensions arrive separately via :data:`sasl_extensions` (parsed from +the typed ``sasl.oauthbearer.extensions`` config property). They are NOT +accepted inside this string under any ``extension_*`` prefix — that key +shape is rejected as an unknown key. """ + +from dataclasses import dataclass +from typing import Dict, Optional + +from confluent_kafka._util.kv_string_parser import parse_kv + +__all__ = [ + "CONFIG_KEY", + "DEFAULT_SIGNING_ALGORITHM", + "ALLOWED_SIGNING_ALGORITHMS", + "MIN_DURATION_SECONDS", + "MAX_DURATION_SECONDS", + "DEFAULT_DURATION_SECONDS", + "TAG_KEY_PREFIX", + "MAX_TAGS", + "AWS_DEBUG_NONE", + "AWS_DEBUG_CONSOLE", + "ALLOWED_AWS_DEBUG_VALUES", + "NET_ONLY_AWS_DEBUG_VALUES", + "AwsOAuthBearerConfig", +] + + +#: Top-level librdkafka config key carrying the AWS-path wire-grammar string. +CONFIG_KEY: str = "sasl.oauthbearer.config" + +#: Default JWT signing algorithm. Matches .NET / cross-language default. +DEFAULT_SIGNING_ALGORITHM: str = "ES384" + +#: Signing algorithms accepted by AWS STS ``GetWebIdentityToken``. +ALLOWED_SIGNING_ALGORITHMS = ("ES384", "RS256") + +#: Minimum / default / maximum token lifetime AWS STS allows +#: (60s and 3600s are AWS-enforced bounds; 300s is the .NET-aligned default). +MIN_DURATION_SECONDS: int = 60 +MAX_DURATION_SECONDS: int = 3600 +DEFAULT_DURATION_SECONDS: int = 300 + +#: Wire-grammar prefix for STS ``Tags`` entries (e.g. ``tag_team=platform``). +TAG_KEY_PREFIX: str = "tag_" + +#: AWS-enforced upper bound on number of tags per ``GetWebIdentityToken`` call. +MAX_TAGS: int = 50 + +#: Sentinel string values for ``aws_debug``. +AWS_DEBUG_NONE: str = "none" +AWS_DEBUG_CONSOLE: str = "console" + +#: ``aws_debug`` values accepted by the Python client. +#: +#: Python takes the **Pythonic subset** decision (locked 2026-05-26): only +#: ``none`` (no-op) and ``console`` (routes botocore logs to stderr via +#: :func:`boto3.set_stream_logger`). The .NET-only sinks ``log4net`` and +#: ``systemdiagnostics`` are rejected with a "not supported on this platform" +#: error so users moving between language clients get a clear signal. +ALLOWED_AWS_DEBUG_VALUES = (AWS_DEBUG_NONE, AWS_DEBUG_CONSOLE) + +#: ``aws_debug`` values that exist in .NET but are not supported in Python. +NET_ONLY_AWS_DEBUG_VALUES = ("log4net", "systemdiagnostics") + + +# Recognised non-tag keys for the wire grammar. Anything else (other than +# ``tag_``) raises "Unknown key" during :meth:`AwsOAuthBearerConfig.parse`. +_RECOGNISED_KEYS = frozenset({ + "region", + "audience", + "duration_seconds", + "signing_algorithm", + "sts_endpoint", + "principal_name", + "aws_debug", +}) + +# Keys whose value must NOT be the empty string (``region= audience=...``). +# Empty values are tolerated for ``tag_`` (mirrors AWS allowing empty +# tag values) and for ``duration_seconds`` (which separately fails integer +# parsing with a clearer error). +_NON_EMPTY_KEYS = frozenset({ + "region", + "audience", + "signing_algorithm", + "sts_endpoint", + "principal_name", + "aws_debug", +}) + + +@dataclass(frozen=True) +class AwsOAuthBearerConfig: + """Validated, immutable view of the AWS path's ``sasl.oauthbearer.config``. + + Construct via :meth:`parse` rather than directly — the classmethod is the + only path that exercises the wire-grammar parser. The dataclass's + :meth:`__post_init__` validates final state regardless of construction + path, so direct construction with bad values still raises. + """ + + region: str + audience: str + signing_algorithm: str = DEFAULT_SIGNING_ALGORITHM + duration_seconds: int = DEFAULT_DURATION_SECONDS + sts_endpoint: Optional[str] = None + principal_name: Optional[str] = None + aws_debug: str = AWS_DEBUG_NONE + tags: Optional[Dict[str, str]] = None + sasl_extensions: Optional[Dict[str, str]] = None + + def __post_init__(self) -> None: + """Final-state validation. Raises :class:`ValueError` on bad input.""" + if not isinstance(self.region, str) or self.region == "": + raise ValueError( + f"{CONFIG_KEY} 'region' must not be empty." + ) + if not isinstance(self.audience, str) or self.audience == "": + raise ValueError( + f"{CONFIG_KEY} 'audience' must not be empty." + ) + if self.signing_algorithm not in ALLOWED_SIGNING_ALGORITHMS: + raise ValueError( + f"{CONFIG_KEY} 'signing_algorithm' must be 'ES384' or 'RS256'; " + f"got {self.signing_algorithm!r}." + ) + # bool is-a int in Python — reject explicitly so True/False doesn't + # slip through as 1/0. + if (not isinstance(self.duration_seconds, int) + or isinstance(self.duration_seconds, bool)): + raise ValueError( + f"{CONFIG_KEY} 'duration_seconds' must be an integer." + ) + if not (MIN_DURATION_SECONDS + <= self.duration_seconds + <= MAX_DURATION_SECONDS): + raise ValueError( + f"{CONFIG_KEY} 'duration_seconds' must be between " + f"{MIN_DURATION_SECONDS} and {MAX_DURATION_SECONDS} inclusive; " + f"got {self.duration_seconds}." + ) + if self.sts_endpoint is not None and self.sts_endpoint == "": + raise ValueError( + f"{CONFIG_KEY} 'sts_endpoint' must not be empty." + ) + if self.principal_name is not None and self.principal_name == "": + raise ValueError( + f"{CONFIG_KEY} 'principal_name' must not be empty." + ) + if self.aws_debug not in ALLOWED_AWS_DEBUG_VALUES: + if self.aws_debug in NET_ONLY_AWS_DEBUG_VALUES: + raise ValueError( + f"{CONFIG_KEY} 'aws_debug={self.aws_debug}' is not " + f"supported on this platform. The Python client supports " + f"{list(ALLOWED_AWS_DEBUG_VALUES)} only; " + f"'log4net' and 'systemdiagnostics' are .NET-only sinks." + ) + raise ValueError( + f"{CONFIG_KEY} 'aws_debug' must be one of: " + f"none, console. Got {self.aws_debug!r}." + ) + if self.tags is not None: + if not isinstance(self.tags, dict): + raise ValueError( + f"{CONFIG_KEY} 'tags' must be a dict." + ) + if len(self.tags) > MAX_TAGS: + raise ValueError( + f"{CONFIG_KEY} has {len(self.tags)} tags; AWS allows at " + f"most {MAX_TAGS}." + ) + if (self.sasl_extensions is not None + and not isinstance(self.sasl_extensions, dict)): + raise ValueError("sasl_extensions, if set, must be a dict.") + + @classmethod + def parse( + cls, + raw: str, + sasl_extensions: Optional[Dict[str, str]] = None, + ) -> "AwsOAuthBearerConfig": + """Parse the verbatim ``sasl.oauthbearer.config`` value. + + Whitespace-separated ``key=value`` tokens; the union of recognised + keys plus ``tag_`` entries. Anything else raises + :class:`ValueError`. Empty values for required-non-empty keys raise + the same. Duplicate keys → last-wins (mirrors .NET). + + :param raw: The verbatim ``sasl.oauthbearer.config`` string. + :param sasl_extensions: Pre-parsed dict from the sibling + ``sasl.oauthbearer.extensions`` property (see + :mod:`._aws_sasl_extensions_parser`). Stored on the config + unchanged. + :raises TypeError: ``raw`` is ``None``. + :raises ValueError: grammar, range, or enum violations. + """ + if raw is None: + raise TypeError("raw must not be None") + + region: Optional[str] = None + audience: Optional[str] = None + signing_algorithm: str = DEFAULT_SIGNING_ALGORITHM + duration_seconds: int = DEFAULT_DURATION_SECONDS + sts_endpoint: Optional[str] = None + principal_name: Optional[str] = None + aws_debug: str = AWS_DEBUG_NONE + tags: Optional[Dict[str, str]] = None + + for key, value in parse_kv( + raw, + separators=[" ", "\t", "\r", "\n"], + context_label=CONFIG_KEY, + ): + if key in _NON_EMPTY_KEYS and value == "": + raise ValueError( + f"{CONFIG_KEY} {key!r} must not be empty." + ) + + if key == "region": + region = value + elif key == "audience": + audience = value + elif key == "signing_algorithm": + signing_algorithm = value + elif key == "duration_seconds": + try: + duration_seconds = int(value) + except ValueError as exc: + raise ValueError( + f"{CONFIG_KEY} 'duration_seconds' must be an integer; " + f"got {value!r}." + ) from exc + elif key == "sts_endpoint": + sts_endpoint = value + elif key == "principal_name": + principal_name = value + elif key == "aws_debug": + # Normalize case (matches .NET's case-insensitive parsing) so + # downstream comparisons against ALLOWED_AWS_DEBUG_VALUES are + # straightforward. + aws_debug = value.lower() + elif key.startswith(TAG_KEY_PREFIX): + tag_name = key[len(TAG_KEY_PREFIX):] + if tag_name == "": + raise ValueError( + f"{CONFIG_KEY} tag key {key!r} has empty name." + ) + if tags is None: + tags = {} + tags[tag_name] = value # last-wins on duplicate tag names + else: + raise ValueError( + f"Unknown key {key!r} in {CONFIG_KEY}." + ) + + if region is None: + raise ValueError( + f"'region' is required in {CONFIG_KEY}." + ) + if audience is None: + raise ValueError( + f"'audience' is required in {CONFIG_KEY}." + ) + + return cls( + region=region, + audience=audience, + signing_algorithm=signing_algorithm, + duration_seconds=duration_seconds, + sts_endpoint=sts_endpoint, + principal_name=principal_name, + aws_debug=aws_debug, + tags=tags, + sasl_extensions=sasl_extensions, + ) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py index d4ce2f58f..fac6715ac 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -14,11 +14,166 @@ """Internal: AWS STS ``GetWebIdentityToken``-based OAUTHBEARER token provider. -Placeholder for Phase 1; implementation lands in Phase 3. +Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsStsTokenProvider``. -Mirrors .NET's ``AwsStsTokenProvider.cs``. The top-level ``import boto3`` that -will land in Phase 3 is the gate that makes this whole subpackage opt-in via -the ``oauthbearer-aws`` PEP 621 extra: opt-out users have no ``boto3`` -installed, so any attempt to import this module raises ``ModuleNotFoundError`` -which core's C dispatcher rewrites into a friendly install hint. +This is the file whose top-level ``import boto3`` is the actual gate that +makes the entire ``oauthbearer-aws`` extra opt-in: opt-out users have no +``boto3`` installed, so transitively importing this module from +:mod:`...aws_autowire` raises ``ModuleNotFoundError`` and the C dispatcher +rewrites it as a friendly install hint at client construction time. + +The class :class:`AwsStsTokenProvider` is constructed once per autowired +client by :func:`...aws_autowire.create_handler` and its bound ``token`` +method is installed as the ``oauth_cb`` Python callable. librdkafka invokes +this method from its background thread on every token refresh; the return +4-tuple matches the C extension's ``oauth_cb`` contract (see +``confluent_kafka.c`` around ``L2291``: ``PyArg_ParseTuple(result, "sd|sO!", +...)``). + +Credential resolution is **lazy**: ``__init__`` constructs the boto3 +``Session`` and STS client objects without making any HTTP calls. The first +``token()`` invocation triggers boto3's default credential chain (env → +shared config → IMDS → ECS → IRSA → SSO). """ + +import logging +from typing import Any, Dict, Optional, Tuple + +import boto3 + +from . import _aws_jwt_subject_extractor +from ._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AWS_DEBUG_NONE, + AwsOAuthBearerConfig, +) + +__all__ = ["AwsStsTokenProvider"] + + +# Logger name targeted by ``aws_debug=console``. Routes botocore's HTTP / +# credential-chain / signing diagnostic logs to stderr at DEBUG level. +_BOTOCORE_LOGGER_NAME = "botocore" + + +class AwsStsTokenProvider: + """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``. + + The instance method :meth:`token` is shaped to slot directly into the + ``oauth_cb`` C contract (4-tuple of ``(token, expiry_seconds, principal, + extensions_dict)``). + + Construction is lightweight and side-effect-light: the boto3 STS client + is built eagerly but no network call is made until :meth:`token` runs. + The ``aws_debug=console`` side-effect (process-wide + :func:`boto3.set_stream_logger` configuration) does fire at construction + when configured. + """ + + def __init__( + self, + config: AwsOAuthBearerConfig, + sts_client: Optional[Any] = None, + ) -> None: + """Construct a provider bound to ``config``. + + :param config: Validated :class:`AwsOAuthBearerConfig` instance. + :param sts_client: Test seam — when supplied, the provider uses this + client directly instead of constructing a real boto3 STS client. + Production callers pass ``None``. + :raises TypeError: ``config`` is ``None``. + """ + if config is None: + raise TypeError("config must not be None") + self._cfg = config + + self._apply_aws_debug(config.aws_debug) + + if sts_client is not None: + self._sts = sts_client + else: + # boto3.Session() with explicit region_name short-circuits the + # AWS_DEFAULT_REGION env lookup; client() still re-asserts the + # region for STS endpoint resolution. + session = boto3.Session(region_name=config.region) + client_kwargs: Dict[str, Any] = {"region_name": config.region} + if config.sts_endpoint: + client_kwargs["endpoint_url"] = config.sts_endpoint + self._sts = session.client("sts", **client_kwargs) + + @staticmethod + def _apply_aws_debug(aws_debug: str) -> None: + """Apply the ``aws_debug`` side-effect to botocore's logger. + + Process-wide effect, intentionally — mirrors .NET's + ``AWSConfigs.LoggingConfig.LogTo`` behaviour. When the user opts in + with ``aws_debug=console``, every boto3 client in the process gets + DEBUG-level stderr logs. ``aws_debug=none`` is a no-op so any + logging the user has configured elsewhere is preserved. + """ + if aws_debug == AWS_DEBUG_CONSOLE: + boto3.set_stream_logger(_BOTOCORE_LOGGER_NAME, logging.DEBUG) + # AWS_DEBUG_NONE → no-op. Other values are rejected by config validation. + + def token( + self, + oauthbearer_config: str = "", + ) -> Tuple[str, float, str, Dict[str, str]]: + """Mint a fresh JWT and return the ``oauth_cb`` 4-tuple. + + :param oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + string librdkafka passes back on every refresh. Accepted for + interface completeness but unused — the AWS path's fields are + sourced from the bound :class:`AwsOAuthBearerConfig` at + construction time, not re-parsed per refresh. + + :returns: 4-tuple ``(token, expiry_epoch_seconds, principal, + extensions)`` matching the C ``oauth_cb`` contract. + + :raises botocore.exceptions.ClientError: STS-side error + (``AccessDenied``, ``OutboundWebIdentityFederationDisabled``, + ...). The C ``oauth_cb`` wrapper converts raised exceptions + into ``rd_kafka_oauthbearer_set_token_failure``. + :raises ValueError: STS returned a malformed JWT or missing + ``Expiration``. + """ + request_kwargs: Dict[str, Any] = { + "Audience": [self._cfg.audience], + "SigningAlgorithm": self._cfg.signing_algorithm, + "DurationSeconds": self._cfg.duration_seconds, + } + if self._cfg.tags: + request_kwargs["Tags"] = [ + {"Key": k, "Value": v} for k, v in self._cfg.tags.items() + ] + + response = self._sts.get_web_identity_token(**request_kwargs) + + jwt = response.get("WebIdentityToken") + if not isinstance(jwt, str) or not jwt: + raise ValueError( + "STS response missing WebIdentityToken; cannot mint OAUTHBEARER token." + ) + + expiration = response.get("Expiration") + if expiration is None: + raise ValueError( + "STS response missing Expiration; cannot compute token lifetime." + ) + # boto3 normalises the timestamp to a tz-aware UTC datetime; + # .timestamp() returns epoch seconds as a float. + expiry_epoch_seconds = expiration.timestamp() + + principal = ( + self._cfg.principal_name + if self._cfg.principal_name is not None + else _aws_jwt_subject_extractor.extract_sub(jwt) + ) + + # Always return a dict for the extensions slot — the C oauth_cb + # wrapper's PyArg_ParseTuple uses "O!" with PyDict_Type for that slot, + # which would reject None. Empty dict is the Pythonic equivalent of + # .NET's null-Extensions case. + extensions = dict(self._cfg.sasl_extensions) if self._cfg.sasl_extensions else {} + + return jwt, expiry_epoch_seconds, principal, extensions diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py new file mode 100644 index 000000000..8876c848e --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -0,0 +1,408 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_oauthbearer_config. + +Mirrors .NET's AwsOAuthBearerConfigTests, modulo Python conventions and the +Pythonic aws_debug subset decision (none/console only). +""" + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AWS_DEBUG_NONE, + AwsOAuthBearerConfig, +) + + +# ---- Required fields ---- + + +def test_parse_minimal_required_populates_region_and_audience(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("audience=https://a") + + +def test_parse_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + AwsOAuthBearerConfig.parse("region=us-east-1") + + +def test_parse_null_input_raises(): + with pytest.raises(TypeError): + AwsOAuthBearerConfig.parse(None) + + +def test_parse_empty_input_raises_for_missing_region(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("") + + +def test_parse_empty_value_on_required_key_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig.parse("region= audience=https://a") + + +@pytest.mark.parametrize("key", ["signing_algorithm", "sts_endpoint", "principal_name"]) +def test_parse_empty_value_on_optional_key_raises(key): + with pytest.raises(ValueError, match=key): + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a {key}=") + + +# ---- Defaults applied during parse ---- + + +def test_parse_no_signing_algorithm_defaults_to_es384(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.signing_algorithm == "ES384" + + +def test_parse_no_duration_defaults_to_300_seconds(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.duration_seconds == 300 + + +def test_parse_optional_fields_default_to_none(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.sts_endpoint is None + assert cfg.principal_name is None + assert cfg.sasl_extensions is None + assert cfg.tags is None + + +# ---- duration_seconds ---- + + +@pytest.mark.parametrize("seconds", [60, 300, 3600]) +def test_parse_duration_in_range_accepted(seconds): + cfg = AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a duration_seconds={seconds}" + ) + assert cfg.duration_seconds == seconds + + +@pytest.mark.parametrize("seconds", [0, 59, 3601, -10]) +def test_parse_duration_out_of_range_raises(seconds): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a duration_seconds={seconds}" + ) + + +def test_parse_duration_not_integer_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a duration_seconds=abc" + ) + + +# ---- signing_algorithm ---- + + +@pytest.mark.parametrize("alg", ["ES384", "RS256"]) +def test_parse_allowed_signing_algorithm_accepted(alg): + cfg = AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a signing_algorithm={alg}" + ) + assert cfg.signing_algorithm == alg + + +@pytest.mark.parametrize("alg", ["HS256", "es384", "RS512"]) +def test_parse_disallowed_signing_algorithm_raises(alg): + with pytest.raises(ValueError, match="signing_algorithm.*ES384.*RS256"): + AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a signing_algorithm={alg}" + ) + + +# ---- aws_debug (Python subset: none/console only) ---- + + +def test_parse_no_aws_debug_defaults_to_none(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.aws_debug == AWS_DEBUG_NONE + + +@pytest.mark.parametrize("value,expected", [ + ("none", AWS_DEBUG_NONE), + ("console", AWS_DEBUG_CONSOLE), +]) +def test_parse_aws_debug_values_accepted(value, expected): + cfg = AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a aws_debug={value}" + ) + assert cfg.aws_debug == expected + + +@pytest.mark.parametrize("value,expected", [ + ("Console", AWS_DEBUG_CONSOLE), + ("CONSOLE", AWS_DEBUG_CONSOLE), + ("NONE", AWS_DEBUG_NONE), +]) +def test_parse_aws_debug_case_insensitive_accepted(value, expected): + cfg = AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a aws_debug={value}" + ) + assert cfg.aws_debug == expected + + +@pytest.mark.parametrize("value", ["log4net", "systemdiagnostics", "Log4Net", "SystemDiagnostics"]) +def test_parse_aws_debug_dotnet_only_values_rejected_with_clear_message(value): + """log4net / systemdiagnostics are .NET sinks; Python rejects with a + platform-clarity message so cross-language users see the constraint.""" + with pytest.raises(ValueError, match="not supported on this platform"): + AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a aws_debug={value}" + ) + + +@pytest.mark.parametrize("value", ["verbose", "etw", "debug", "true", "foo"]) +def test_parse_aws_debug_invalid_value_raises(value): + with pytest.raises(ValueError, match="aws_debug.*none, console"): + AwsOAuthBearerConfig.parse( + f"region=us-east-1 audience=https://a aws_debug={value}" + ) + + +def test_parse_aws_debug_empty_value_raises(): + with pytest.raises(ValueError, match="aws_debug.*must not be empty"): + AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a aws_debug=" + ) + + +# ---- sts_endpoint, principal_name ---- + + +def test_parse_sts_endpoint_stored_verbatim(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a " + "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + ) + assert cfg.sts_endpoint == "https://sts-fips.us-east-1.amazonaws.com" + + +def test_parse_principal_name_stored_verbatim(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a principal_name=my-principal" + ) + assert cfg.principal_name == "my-principal" + + +# ---- sasl_extensions argument (typed property pass-through) ---- + + +def test_parse_extension_prefix_in_config_rejected_as_unknown_key(): + """sasl extensions are accepted via typed sasl.oauthbearer.extensions + property only; embedded extension_* keys are rejected.""" + with pytest.raises(ValueError, match="extension_logicalCluster"): + AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a extension_logicalCluster=lkc-abc" + ) + + +def test_parse_sasl_extensions_arg_stored_on_config(): + ext = {"logicalCluster": "lkc-abc", "identityPoolId": "pool-xyz"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a", ext + ) + assert cfg.sasl_extensions is ext + + +def test_parse_sasl_extensions_arg_null_keeps_sasl_extensions_null(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a", None + ) + assert cfg.sasl_extensions is None + + +# ---- tag_ ---- + + +def test_parse_single_tag_collected_into_tags(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a tag_team=platform" + ) + assert cfg.tags == {"team": "platform"} + + +def test_parse_multiple_tags_all_collected(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a " + "tag_team=platform " + "tag_environment=prod" + ) + assert cfg.tags == {"team": "platform", "environment": "prod"} + + +def test_parse_empty_tag_name_raises(): + with pytest.raises(ValueError, match="empty name"): + AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a tag_=value" + ) + + +def test_parse_empty_tag_value_accepted(): + # AWS allows tag values of 0 chars; mirror that. + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a tag_team=" + ) + assert cfg.tags == {"team": ""} + + +def test_parse_duplicate_tag_name_last_wins(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a tag_team=infra tag_team=platform" + ) + assert cfg.tags == {"team": "platform"} + + +def test_parse_exactly_max_tags_accepted(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(50)) + cfg = AwsOAuthBearerConfig.parse(" ".join(parts)) + assert len(cfg.tags) == 50 + + +def test_parse_over_max_tags_raises(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(51)) + with pytest.raises(ValueError, match="50"): + AwsOAuthBearerConfig.parse(" ".join(parts)) + + +# ---- Unknown keys ---- + + +def test_parse_unknown_key_raises(): + with pytest.raises(ValueError, match="Unknown key.*not_a_key"): + AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a not_a_key=foo" + ) + + +# ---- Whitespace / ordering ---- + + +def test_parse_tabs_and_multiple_spaces_tolerated(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1\taudience=https://a duration_seconds=600" + ) + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + + +def test_parse_leading_and_trailing_whitespace_tolerated(): + cfg = AwsOAuthBearerConfig.parse(" region=us-east-1 audience=https://a ") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_order_invariant(): + cfg = AwsOAuthBearerConfig.parse( + "duration_seconds=600 audience=https://a region=us-east-1 signing_algorithm=RS256" + ) + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + assert cfg.signing_algorithm == "RS256" + + +def test_parse_duplicate_key_last_wins(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a region=us-west-2" + ) + assert cfg.region == "us-west-2" + + +# ---- Malformed entries ---- + + +def test_parse_no_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("region us-east-1 audience=https://a") + + +def test_parse_leading_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("=value audience=https://a region=us-east-1") + + +# ---- Integration ---- + + +def test_parse_all_fields_together_all_populated_correctly(): + sasl_extensions = {"logicalCluster": "lkc-abc"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 " + "audience=https://confluent.cloud/oidc " + "duration_seconds=1800 " + "signing_algorithm=RS256 " + "sts_endpoint=https://sts.us-east-1.amazonaws.com " + "principal_name=test-principal " + "aws_debug=console " + "tag_team=platform", + sasl_extensions, + ) + + assert cfg.region == "us-east-1" + assert cfg.audience == "https://confluent.cloud/oidc" + assert cfg.duration_seconds == 1800 + assert cfg.signing_algorithm == "RS256" + assert cfg.sts_endpoint == "https://sts.us-east-1.amazonaws.com" + assert cfg.principal_name == "test-principal" + assert cfg.aws_debug == AWS_DEBUG_CONSOLE + assert cfg.sasl_extensions == {"logicalCluster": "lkc-abc"} + assert cfg.tags == {"team": "platform"} + + +# ---- Direct construction (bypassing parse) still validated ---- + + +def test_direct_construction_with_empty_region_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig(region="", audience="https://a") + + +def test_direct_construction_with_out_of_range_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig( + region="us-east-1", audience="https://a", duration_seconds=10 + ) + + +def test_direct_construction_with_bool_duration_rejected(): + """bool is-a int in Python; reject explicitly so True/False doesn't pass.""" + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + AwsOAuthBearerConfig( + region="us-east-1", audience="https://a", duration_seconds=True + ) + + +# ---- Surface invariants ---- + + +def test_aws_oauth_bearer_config_not_exposed_via_subpackage_init(): + """Public surface stays minimal — AwsOAuthBearerConfig is private to the + autowire layer (only `create_handler` is publicly importable).""" + import confluent_kafka.oauthbearer.aws as aws_pkg + assert not hasattr(aws_pkg, "AwsOAuthBearerConfig") diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py new file mode 100644 index 000000000..aef20aefa --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -0,0 +1,406 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_sts_token_provider. + +Mirrors .NET's AwsStsTokenProviderTests, modulo: + +- Python uses a tuple return (not a typed record) so the "no extensions → + null" assertion becomes "no extensions → empty dict" (the C oauth_cb + contract requires a dict, not None, at that slot). +- Python returns ``expiry_epoch_seconds`` (float seconds) not + ``LifetimeMs`` (long milliseconds) — the C wrapper multiplies by 1000. +""" + +import base64 +import datetime +import logging +from typing import Any, Dict, Optional +from unittest.mock import patch + +import pytest + +# botocore is a transitive of boto3; available in opt-in venv. +pytest.importorskip("boto3") +pytest.importorskip("botocore") + +from botocore.exceptions import ClientError + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import ( + AwsOAuthBearerConfig, +) +from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import ( + AwsStsTokenProvider, +) + + +# ---- Test helpers ---- + + +_ROLE_ARN = "arn:aws:iam::123:role/R" + + +def _base64url(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = _ROLE_ARN) -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" + + +_CANNED_JWT = _canned_jwt() +_CANNED_EXPIRY = datetime.datetime( + 2099, 4, 21, 6, 6, 47, 641_000, tzinfo=datetime.timezone.utc, +) + + +class FakeStsClient: + """Test double mirroring .NET's FakeStsClient. + + Records the last request kwargs and returns a canned response (or raises + a canned exception). Avoids the ceremony of ``botocore.stub.Stubber`` for + the simple cases here. + """ + + def __init__(self, responder=None, raises: Optional[BaseException] = None) -> None: + self.last_request: Optional[Dict[str, Any]] = None + self._responder = responder + self._raises = raises + + def get_web_identity_token(self, **kwargs: Any) -> Dict[str, Any]: + self.last_request = kwargs + if self._raises is not None: + raise self._raises + if self._responder is not None: + return self._responder(kwargs) + return _ok_response() + + +def _ok_response( + jwt: str = _CANNED_JWT, + expiration: datetime.datetime = _CANNED_EXPIRY, +) -> Dict[str, Any]: + return {"WebIdentityToken": jwt, "Expiration": expiration} + + +# ---- Constructor: checks ---- + + +def test_ctor_null_config_raises(): + with pytest.raises(TypeError): + AwsStsTokenProvider(None) + + +def test_ctor_valid_parsed_config_succeeds(): + """No AWS / no HTTP call at construction time (lazy credential chain).""" + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + # Pass a fake client so we don't try to instantiate a real boto3 client + # against a possibly-network-isolated test env. + provider = AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + assert provider is not None + + +# ---- Constructor: aws_debug ---- + + +def test_ctor_no_aws_debug_does_not_mutate_botocore_logger(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + # Capture botocore logger's level before & after construction; without + # an explicit aws_debug=console, we must not touch boto3.set_stream_logger. + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_not_called() + + +def test_ctor_aws_debug_none_does_not_mutate_botocore_logger(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a aws_debug=none" + ) + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_not_called() + + +def test_ctor_aws_debug_console_routes_botocore_logger_to_stream(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a aws_debug=console" + ) + with patch("boto3.set_stream_logger") as mock_setter: + AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + mock_setter.assert_called_once_with("botocore", logging.DEBUG) + + +# ---- token(): request shape ---- + + +def test_token_audience_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://my.audience" + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["Audience"] == ["https://my.audience"] + + +def test_token_signing_algorithm_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a signing_algorithm=RS256" + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["SigningAlgorithm"] == "RS256" + + +def test_token_duration_seconds_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a duration_seconds=900" + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["DurationSeconds"] == 900 + + +def test_token_default_duration_sends_300_seconds(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["DurationSeconds"] == 300 + + +def test_token_default_signing_algorithm_sends_es384(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert fake.last_request["SigningAlgorithm"] == "ES384" + + +def test_token_tags_passthrough(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a tag_team=platform tag_environment=prod" + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + + tags = fake.last_request["Tags"] + assert len(tags) == 2 + assert {"Key": "team", "Value": "platform"} in tags + assert {"Key": "environment", "Value": "prod"} in tags + + +def test_token_no_tags_omits_tags_field_from_request(): + """When no tags configured, omit the Tags field rather than send empty list.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + provider.token() + assert "Tags" not in fake.last_request + + +# ---- token(): response mapping ---- + + +def test_token_returns_mapped_fields(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + jwt, expiry, principal, extensions = provider.token() + + assert jwt == _CANNED_JWT + assert principal == _ROLE_ARN + assert expiry == _CANNED_EXPIRY.timestamp() + assert extensions == {} + + +def test_token_principal_name_override_wins_over_jwt_sub(): + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a principal_name=explicit-principal" + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, principal, _ = provider.token() + assert principal == "explicit-principal" + + +def test_token_sasl_extensions_passthrough(): + fake = FakeStsClient() + sasl_extensions = {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a", sasl_extensions, + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} + + +def test_token_no_extensions_configured_returns_empty_dict(): + """Python deviation from .NET: empty dict, not None, because the C + oauth_cb contract uses PyArg_ParseTuple "O!" with PyDict_Type at that + slot (rejects None). Empty dict is the Pythonic equivalent of the + .NET null-extensions case.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {} + assert isinstance(extensions, dict) + + +def test_token_expiry_is_epoch_seconds_float(): + """The C wrapper expects expiry as epoch seconds (float). It multiplies + by 1000 internally to get milliseconds for librdkafka.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, expiry, _, _ = provider.token() + assert isinstance(expiry, float) + assert expiry == _CANNED_EXPIRY.timestamp() + + +# ---- token(): error propagation ---- + + +def test_token_missing_expiration_raises(): + """STS response without Expiration → ValueError surfaced as + rd_kafka_oauthbearer_set_token_failure by the C wrapper.""" + fake = FakeStsClient( + responder=lambda kwargs: {"WebIdentityToken": _CANNED_JWT} + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError, match="Expiration"): + provider.token() + + +def test_token_missing_token_value_raises(): + fake = FakeStsClient( + responder=lambda kwargs: {"Expiration": _CANNED_EXPIRY} + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError, match="WebIdentityToken"): + provider.token() + + +def test_token_malformed_jwt_raises_value_error(): + fake = FakeStsClient( + responder=lambda kwargs: { + "WebIdentityToken": "not-a-jwt", + "Expiration": _CANNED_EXPIRY, + } + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError): + provider.token() + + +def test_token_sts_access_denied_propagates(): + fake = FakeStsClient( + raises=ClientError( + error_response={ + "Error": { + "Code": "AccessDenied", + "Message": "User is not authorized to perform: sts:GetWebIdentityToken", + } + }, + operation_name="GetWebIdentityToken", + ), + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ClientError) as exc_info: + provider.token() + assert exc_info.value.response["Error"]["Code"] == "AccessDenied" + + +def test_token_outbound_federation_disabled_propagates(): + fake = FakeStsClient( + raises=ClientError( + error_response={ + "Error": { + "Code": "OutboundWebIdentityFederationDisabledException", + "Message": "Outbound web identity federation is not enabled on this account.", + } + }, + operation_name="GetWebIdentityToken", + ), + ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ClientError) as exc_info: + provider.token() + assert ( + exc_info.value.response["Error"]["Code"] + == "OutboundWebIdentityFederationDisabledException" + ) + + +# ---- Surface invariants ---- + + +def test_aws_sts_token_provider_not_exposed_via_subpackage_init(): + """Public surface stays minimal — AwsStsTokenProvider is private.""" + import confluent_kafka.oauthbearer.aws as aws_pkg + assert not hasattr(aws_pkg, "AwsStsTokenProvider") + + +# ---- Lazy credential resolution ---- + + +def test_ctor_does_not_invoke_sts(): + """Construction must not make any STS call (lazy credential chain). + Verify by checking the fake client received no requests after __init__ + but before any .token() invocation.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + AwsStsTokenProvider(cfg, sts_client=fake) + assert fake.last_request is None + + +# ---- sts_endpoint plumbed to boto3.client ---- + + +def test_ctor_sts_endpoint_plumbed_to_boto3_client(): + """When sts_endpoint is set on config, boto3.Session().client('sts', ...) + receives an endpoint_url kwarg.""" + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a " + "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + ) + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + AwsStsTokenProvider(cfg) + mock_session.client.assert_called_once_with( + "sts", + region_name="us-east-1", + endpoint_url="https://sts-fips.us-east-1.amazonaws.com", + ) + + +def test_ctor_no_sts_endpoint_omits_endpoint_url_kwarg(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + AwsStsTokenProvider(cfg) + mock_session.client.assert_called_once_with( + "sts", region_name="us-east-1", + ) From 8ab03094bdad63c3ccf9131e9ec286688d65d3e4 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 17:49:47 +0530 Subject: [PATCH 04/21] Implement aws autowire contract --- .../oauthbearer/aws/aws_autowire.py | 102 ++++++- tests/oauthbearer/aws/test_aws_autowire.py | 267 ++++++++++++++++++ tests/oauthbearer/aws/test_contract.py | 148 ++++++++++ 3 files changed, 510 insertions(+), 7 deletions(-) create mode 100644 tests/oauthbearer/aws/test_aws_autowire.py create mode 100644 tests/oauthbearer/aws/test_contract.py diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py index 8b7060aa7..0ee5ad3fe 100644 --- a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -12,13 +12,101 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Frozen public entry-point for AWS IAM OAUTHBEARER autowire. +"""Public entry-point for AWS IAM OAUTHBEARER autowire. -Placeholder for Phase 1; ``create_handler`` lands in Phase 4. +Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.AwsAutoWire``. This is the +**only publicly importable name** in the optional subpackage. The C +dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` (Phase 5) reaches +this module via:: -Mirrors .NET's ``AwsAutoWire.cs``. The C dispatcher in -``src/confluent_kafka/src/confluent_kafka.c`` will reach this module via -``PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire")``. -The function signature of ``create_handler`` is part of the cross-module -ABI and bumping it is a major-version change on ``confluent-kafka``. + PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire") + +and resolves :func:`create_handler` by name. The function signature is a +**frozen cross-module contract**: + +* arity: 2 positional parameters +* names: ``sasl_oauthbearer_config``, ``sasl_oauthbearer_extensions`` +* types: ``str``, ``Optional[str]`` +* return: :data:`OAuthBearerCallback` + +Bumping any of these is a breaking change requiring a major version +increment on the ``confluent-kafka`` distribution. The frozen contract is +test-guarded by ``tests/oauthbearer/aws/test_contract.py``. + +The marker key/value check is performed in core (the C dispatcher); +:func:`create_handler` is invoked only when the caller has already decided +to autowire the AWS path. The function therefore unconditionally attempts +to build a handler and raises on any input it cannot parse. """ + +from typing import Callable, Dict, Optional, Tuple + +from . import _aws_sasl_extensions_parser +from ._aws_iam_marker import AWS_IAM_MARKER_KEY, AWS_IAM_MARKER_VALUE +from ._aws_oauthbearer_config import CONFIG_KEY, AwsOAuthBearerConfig +from ._aws_sts_token_provider import AwsStsTokenProvider + +__all__ = ["create_handler", "OAuthBearerCallback"] + + +#: Type alias for the callable returned by :func:`create_handler`. +#: +#: Tuple shape matches the existing ``oauth_cb`` contract enforced in C at +#: ``confluent_kafka.c`` around L2291 via +#: ``PyArg_ParseTuple(result, "sd|sO!", ...)`` — single ``str`` argument +#: (the ``sasl.oauthbearer.config`` value librdkafka passes back on every +#: refresh), returning ``(token_str, expiry_epoch_seconds, principal_str, +#: extensions_dict)``. +OAuthBearerCallback = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + + +def create_handler( + sasl_oauthbearer_config: str, + sasl_oauthbearer_extensions: Optional[str], +) -> OAuthBearerCallback: + """Build an OAUTHBEARER refresh callback from the two OAUTHBEARER config strings. + + Construction-time work: + + 1. Validates ``sasl_oauthbearer_config`` is a non-empty string. + 2. Parses ``sasl_oauthbearer_extensions`` (comma-separated ``key=value``) + via :mod:`._aws_sasl_extensions_parser` into an optional dict. + 3. Parses ``sasl_oauthbearer_config`` (whitespace-separated ``key=value``) + into a validated :class:`._aws_oauthbearer_config.AwsOAuthBearerConfig`. + 4. Constructs an :class:`._aws_sts_token_provider.AwsStsTokenProvider` + (no HTTP yet — credential resolution is lazy until the first + ``token()`` invocation). + 5. Returns the provider's bound :meth:`token` method as the callable. + + The returned callable is invoked by the C ``oauth_cb`` wrapper on every + OAUTHBEARER refresh; each call performs one STS ``GetWebIdentityToken`` + round-trip and returns a fresh 4-tuple suitable for + ``rd_kafka_oauthbearer_set_token``. + + :param sasl_oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + value (whitespace-separated ``key=value`` pairs). Must be non-empty. + :param sasl_oauthbearer_extensions: The verbatim + ``sasl.oauthbearer.extensions`` value (comma-separated ``key=value`` + pairs, RFC 7628 §3.1). May be ``None`` or empty when the user has + no extensions configured. + + :returns: A callable matching :data:`OAuthBearerCallback`. + + :raises ValueError: ``sasl_oauthbearer_config`` is ``None`` or empty; + the wire-grammar parse fails (unknown key, malformed token, missing + required field, range/enum violation, etc.). + :raises RuntimeError: AWS SDK reachability or initialisation failure + (e.g. unknown region, malformed ``sts_endpoint``). + """ + if not sasl_oauthbearer_config: + raise ValueError( + f"'{AWS_IAM_MARKER_KEY}={AWS_IAM_MARKER_VALUE}' is set but " + f"'{CONFIG_KEY}' is missing or empty. The AWS IAM autowire path " + f"requires region and audience to be supplied via " + f"{CONFIG_KEY} (e.g. \"region=us-east-1 audience=https://...\")." + ) + + sasl_extensions = _aws_sasl_extensions_parser.parse(sasl_oauthbearer_extensions) + config = AwsOAuthBearerConfig.parse(sasl_oauthbearer_config, sasl_extensions) + provider = AwsStsTokenProvider(config) + return provider.token diff --git a/tests/oauthbearer/aws/test_aws_autowire.py b/tests/oauthbearer/aws/test_aws_autowire.py new file mode 100644 index 000000000..7d84afaeb --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_autowire.py @@ -0,0 +1,267 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws.aws_autowire.create_handler. + +End-to-end orchestration tests. The frozen-signature contract guard lives +in test_contract.py. +""" + +import pytest + +pytest.importorskip("boto3") + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler + + +# ---- Input validation (defensive checks for direct callers) ---- + + +def test_create_handler_null_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler(None, None) + + +def test_create_handler_empty_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler("", None) + + +def test_create_handler_error_message_names_marker_and_config_key(): + with pytest.raises(ValueError) as exc_info: + create_handler(None, None) + # User should see both: the marker that's set AND the config key that's missing. + assert "sasl.oauthbearer.metadata.authentication.type" in str(exc_info.value) + assert "aws_iam" in str(exc_info.value) + assert "sasl.oauthbearer.config" in str(exc_info.value) + + +# ---- Parse delegation: errors from AwsOAuthBearerConfig.parse propagate ---- + + +def test_create_handler_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + create_handler("audience=https://a", None) + + +def test_create_handler_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + create_handler("region=us-east-1", None) + + +def test_create_handler_invalid_signing_algorithm_raises(): + with pytest.raises(ValueError, match="signing_algorithm"): + create_handler( + "region=us-east-1 audience=https://a signing_algorithm=HS256", + None, + ) + + +def test_create_handler_invalid_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds"): + create_handler( + "region=us-east-1 audience=https://a duration_seconds=10", None, + ) + + +def test_create_handler_unknown_key_raises(): + with pytest.raises(ValueError, match="not_a_key"): + create_handler( + "region=us-east-1 audience=https://a not_a_key=foo", None, + ) + + +def test_create_handler_invalid_extensions_grammar_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + create_handler( + "region=us-east-1 audience=https://a", "noEqualsHere", + ) + + +def test_create_handler_aws_debug_dotnet_only_value_raises_with_platform_hint(): + """log4net / systemdiagnostics are .NET-only sinks; surface that clearly.""" + with pytest.raises(ValueError, match="not supported on this platform"): + create_handler( + "region=us-east-1 audience=https://a aws_debug=log4net", None, + ) + + +# ---- Success cases — handler returned, no throw, no STS call yet ---- + + +def test_create_handler_marker_only_minimum_config_returns_handler(): + handler = create_handler( + "region=us-east-1 audience=https://a", None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_all_optional_fields_returns_handler(): + handler = create_handler( + "region=us-east-1 audience=https://a " + "duration_seconds=900 signing_algorithm=RS256 " + "sts_endpoint=https://sts.us-east-1.amazonaws.com " + "principal_name=my-principal " + "aws_debug=none " + "tag_team=platform tag_environment=prod", + None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_tag_config_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a tag_team=platform tag_environment=prod", + None, + ) + assert callable(handler) + + +def test_create_handler_principal_name_override_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a principal_name=explicit-principal", + None, + ) + assert callable(handler) + + +# ---- Extensions argument handling ---- + + +def test_create_handler_null_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1 audience=https://a", None, + ) + assert callable(handler) + + +def test_create_handler_empty_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1 audience=https://a", "", + ) + assert callable(handler) + + +def test_create_handler_single_extension_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a", "logicalCluster=lkc-abc", + ) + assert callable(handler) + + +def test_create_handler_multiple_extensions_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + assert callable(handler) + + +# ---- Returned callable invokes correctly with a real STS round-trip stubbed at the boto3 layer ---- +# We can't easily stub the boto3 client AwsStsTokenProvider creates internally without +# patching boto3.Session. End-to-end with-injected-client tests live in +# test_aws_sts_token_provider.py; here we cover the *closure* level (no STS call yet). + + +def test_create_handler_does_not_call_sts_at_construction(): + """create_handler must NOT make an STS call — credential resolution and + network I/O are deferred to the first invocation of the returned callable. + Verify by patching boto3.Session to detect any client.get_web_identity_token call. + """ + from unittest.mock import MagicMock, patch + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_session.client.return_value = mock_client + create_handler("region=us-east-1 audience=https://a", None) + # Session/client construction OK; get_web_identity_token must not have been called. + mock_client.get_web_identity_token.assert_not_called() + + +def test_create_handler_returned_callable_when_invoked_calls_sts(): + """When the returned callable is invoked, it triggers exactly one STS call.""" + from unittest.mock import MagicMock, patch + import datetime + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler("region=us-east-1 audience=https://a", None) + result = handler("ignored-passthrough-string") + + mock_client.get_web_identity_token.assert_called_once() + assert isinstance(result, tuple) + assert len(result) == 4 + token, expiry, principal, extensions = result + assert token == canned_response["WebIdentityToken"] + assert isinstance(expiry, float) + assert principal.startswith("arn:") + assert extensions == {} + + +def test_create_handler_returned_callable_round_trips_extensions(): + """Extensions configured via the typed property flow through to the + 4-tuple's extensions slot.""" + from unittest.mock import MagicMock, patch + import datetime + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler( + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + _, _, _, extensions = handler("") + + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +# ---- Test helpers ---- + + +def _base64url(data: bytes) -> str: + import base64 + return ( + base64.b64encode(data) + .decode("ascii") + .rstrip("=") + .replace("+", "-") + .replace("/", "_") + ) + + +def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" diff --git a/tests/oauthbearer/aws/test_contract.py b/tests/oauthbearer/aws/test_contract.py new file mode 100644 index 000000000..cfc9b0f2a --- /dev/null +++ b/tests/oauthbearer/aws/test_contract.py @@ -0,0 +1,148 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Frozen cross-module contract guard for the autowire entry-point. + +The C dispatcher (Phase 5) in ``src/confluent_kafka/src/confluent_kafka.c`` +resolves ``confluent_kafka.oauthbearer.aws.aws_autowire.create_handler`` via +``PyImport_ImportModule`` + ``PyObject_GetAttrString`` and calls it with two +string arguments. Any signature drift on the Python side breaks that +contract and requires a major version bump on the ``confluent-kafka`` +distribution. + +These tests guard against accidental drift: parameter names, type +annotations, return annotation, arity, and absence of defaults. +""" + +import inspect +from typing import Callable, Dict, Optional, Tuple + +import pytest + +pytest.importorskip("boto3") + +from confluent_kafka.oauthbearer.aws import aws_autowire +from confluent_kafka.oauthbearer.aws.aws_autowire import ( + OAuthBearerCallback, + create_handler, +) + + +# ---- Module surface ---- + + +def test_module_importable_at_canonical_path(): + """The C dispatcher does PyImport_ImportModule(...) with this exact path.""" + import importlib + mod = importlib.import_module("confluent_kafka.oauthbearer.aws.aws_autowire") + assert mod is aws_autowire + + +def test_create_handler_is_module_level_callable(): + assert callable(aws_autowire.create_handler) + assert aws_autowire.create_handler is create_handler + + +def test_oauthbearer_callback_type_alias_exported(): + assert hasattr(aws_autowire, "OAuthBearerCallback") + assert aws_autowire.OAuthBearerCallback is OAuthBearerCallback + + +def test_module_all_lists_public_surface(): + """__all__ should advertise only the two public names.""" + assert set(aws_autowire.__all__) == {"create_handler", "OAuthBearerCallback"} + + +# ---- create_handler frozen signature ---- + + +def test_create_handler_arity_is_two(): + sig = inspect.signature(create_handler) + assert len(sig.parameters) == 2 + + +def test_create_handler_parameter_names_are_frozen(): + sig = inspect.signature(create_handler) + names = list(sig.parameters.keys()) + assert names == ["sasl_oauthbearer_config", "sasl_oauthbearer_extensions"] + + +def test_create_handler_parameter_annotations_are_frozen(): + sig = inspect.signature(create_handler) + assert sig.parameters["sasl_oauthbearer_config"].annotation is str + assert ( + sig.parameters["sasl_oauthbearer_extensions"].annotation + == Optional[str] + ) + + +def test_create_handler_no_default_values(): + """Both parameters must be required positional — the C dispatcher always + passes two arguments (extensions may be the empty string or None, but + is always supplied).""" + sig = inspect.signature(create_handler) + assert sig.parameters["sasl_oauthbearer_config"].default is inspect.Parameter.empty + assert sig.parameters["sasl_oauthbearer_extensions"].default is inspect.Parameter.empty + + +def test_create_handler_return_annotation_is_oauthbearer_callback(): + sig = inspect.signature(create_handler) + assert sig.return_annotation is OAuthBearerCallback + + +def test_create_handler_parameters_accept_positional_arguments(): + """The C dispatcher calls via PyObject_CallFunction with positional args.""" + sig = inspect.signature(create_handler) + for name in ["sasl_oauthbearer_config", "sasl_oauthbearer_extensions"]: + kind = sig.parameters[name].kind + # POSITIONAL_OR_KEYWORD covers what PyObject_CallFunction provides. + assert kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + + +# ---- OAuthBearerCallback shape (the type returned by create_handler) ---- + + +def test_oauthbearer_callback_matches_c_oauth_cb_contract(): + """The C oauth_cb wrapper at confluent_kafka.c:2291 does: + + PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions) + + So the callable returned by create_handler must accept one string + argument (the sasl.oauthbearer.config pass-through) and return a tuple + of (str, float, str, Dict[str, str]). + """ + expected = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + assert OAuthBearerCallback == expected + + +# ---- create_handler docstring sanity ---- + + +def test_create_handler_docstring_present(): + """The frozen cross-module contract is documented in the function's + docstring so it's visible to anyone reading the source. Block a future + contributor from accidentally stripping the docstring.""" + assert create_handler.__doc__ is not None + assert "frozen" in aws_autowire.__doc__.lower() + + +def test_create_handler_docstring_references_key_contract_terms(): + doc = create_handler.__doc__ or "" + # Names of the two arguments must appear so users grepping for them + # find the documentation. + assert "sasl.oauthbearer.config" in doc + assert "sasl.oauthbearer.extensions" in doc From 975d1a5b1089ccd4b127f20a59cbd241f4671ddc Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 19:32:14 +0530 Subject: [PATCH 05/21] C-extension dispatcher resolve_aws_oauthbearer_marker --- src/confluent_kafka/src/confluent_kafka.c | 284 +++++++++++++ tests/oauthbearer/aws/test_aws_iam_marker.py | 32 ++ tests/oauthbearer/aws/test_dispatch.py | 422 +++++++++++++++++++ 3 files changed, 738 insertions(+) create mode 100644 tests/oauthbearer/aws/test_dispatch.py diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 0a2c58495..68b18654d 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -2585,6 +2585,281 @@ static void common_conf_set_software(rd_kafka_conf_t *conf) { * Returns a conf object on success or NULL on failure in which case * an exception has been raised. */ +/** + * @brief Detect the aws_iam OAUTHBEARER autowire marker in the user's + * config dict and replace it with an oauth_cb callable sourced from + * the optional confluent_kafka.oauthbearer.aws subpackage. + * + * User contract (all three keys required when marker is set): + * sasl.oauthbearer.method = "oidc" + * sasl.oauthbearer.metadata.authentication.type = "aws_iam" + * sasl.oauthbearer.config = "region=... audience=..." + * + * Plus optional: + * sasl.oauthbearer.extensions = "key=val,key=val" + * + * IMPORTANT: These literals MUST match + * src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py. + * tests/oauthbearer/aws/test_aws_iam_marker.py guards against drift. + * + * Flow: + * 1. Skip if the marker key is absent or its value is not "aws_iam". + * (Other values like "azure_imds" flow through to librdkafka unchanged.) + * 2. If oauth_cb is already set, strip the marker and skip autowire — + * the explicit handler wins (precedence rule). + * 3. Require sasl.oauthbearer.method == "oidc". The AWS IAM path runs as + * a high-level-client refresh callback inside librdkafka's OIDC + * subsystem (parallel to azure_imds); without method=oidc the + * configuration is rejected by design. + * 4. Require non-empty sasl.oauthbearer.config (carries region, audience, + * duration_seconds, etc.). + * 5. Lazy-import confluent_kafka.oauthbearer.aws.aws_autowire. If the + * optional 'oauthbearer-aws' extra isn't installed, the import chain + * bottoms out in a missing boto3; rewrite that ModuleNotFoundError + * into a friendly install hint while preserving the original via + * __cause__ (Python exception-chaining). + * 6. Call aws_autowire.create_handler(config_str, extensions_str_or_None). + * The returned Python callable replaces oauth_cb in confdict; the + * existing config-iteration loop below picks it up and registers the + * librdkafka refresh callback. + * 7+. Strip the marker key + auto-set THREE sentinel OIDC fields that + * librdkafka's config-finalize demands when method=oidc. The full set: + * - sasl.oauthbearer.token.endpoint.url + * - sasl.oauthbearer.client.id + * - sasl.oauthbearer.client.secret + * + * Why all three are needed: + * librdkafka has TWO mandatory-field checks gated on the marker value: + * + * (a) rdkafka_conf.c:4130 demands token.endpoint.url unless + * metadata_authentication.type == AZURE_IMDS or AWS_IAM. With + * the marker stripped, this check fires unless we provide a URL. + * + * (b) rdkafka_conf.c:4155 — finalize_oauthbearer_oidc_grant_type + * ONLY runs when metadata_authentication.type == NONE. With the + * marker stripped, it runs and demands client.id + client.secret + * (for the default CLIENT_CREDENTIALS grant type). + * + * Why both old AND new librdkafkas tolerate this: + * - librdkafka < AWS-IAM-aware: doesn't know `aws_iam`, rejects the + * marker value at rd_kafka_conf_set() time. We must strip first. + * The 3 sentinels then satisfy its OIDC mandatory checks. + * - librdkafka >= AWS-IAM-aware: would honour the marker (bypass + * token.endpoint.url and skip the grant-type check), but we strip + * it for backward compatibility with older librdkafkas. The 3 + * sentinels keep this path working too. + * + * The sentinels are NEVER used at runtime. librdkafka's config + * finalize at rdkafka_conf.c:4166-4169 explicitly skips the built-in + * OIDC token fetcher when a refresh callback is registered: + * "Enable background thread for the builtin OIDC handler, + * unless a refresh callback has been set." + * Our autowire registers an oauth_cb. The sentinel URL uses RFC 2606 + * `.invalid` TLD so even an accidental fetch would fail at DNS. + * + * TODO: REMOVE STEPS 8+9 (strip + sentinels) once librdkafka + * v2.14.2-aws-iam.2-dev (or the official release that supersedes it) + * becomes the bundled MIN_VER floor. Leave the marker in place; the + * AWS-IAM-aware librdkafka knows `aws_iam`, bypasses the + * token.endpoint.url check, and skips the grant-type check entirely. + * All three sentinels become unnecessary at that point. The removal + * MUST be done in the same PR that bumps the librdkafka floor. + * Project memory: project_aws_iam_python_alignment.md decision #6. + * + * @returns 0 on success (no-op or autowire complete), -1 on error + * (PyErr_* is set; caller goto outer_err). + */ +static int resolve_aws_oauthbearer_marker(PyObject *confdict) { + static const char MARKER_KEY[] = + "sasl.oauthbearer.metadata.authentication.type"; + static const char MARKER_VALUE[] = "aws_iam"; + static const char METHOD_KEY[] = "sasl.oauthbearer.method"; + static const char METHOD_OIDC_VALUE[] = "oidc"; + static const char CONFIG_KEY[] = "sasl.oauthbearer.config"; + static const char EXTENSIONS_KEY[] = "sasl.oauthbearer.extensions"; + static const char OAUTH_CB_KEY[] = "oauth_cb"; + static const char AUTOWIRE_MODULE[] = + "confluent_kafka.oauthbearer.aws.aws_autowire"; + static const char CREATE_HANDLER[] = "create_handler"; + static const char METHOD_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' requires " + "'sasl.oauthbearer.method=oidc'. The AWS IAM path runs as a " + "high-level-client refresh callback inside librdkafka's OIDC " + "subsystem (parallel to azure_imds); without method=oidc the " + "configuration is rejected by design."; + static const char CONFIG_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' is set " + "but 'sasl.oauthbearer.config' is missing or empty. The AWS IAM " + "autowire path requires region and audience to be supplied via " + "sasl.oauthbearer.config " + "(e.g. \"region=us-east-1 audience=https://...\")."; + static const char FRIENDLY_IMPORT_ERR[] = + "Config 'sasl.oauthbearer.metadata.authentication.type=aws_iam' " + "requires the optional 'oauthbearer-aws' extra. Install with:\n" + " pip install 'confluent-kafka[oauthbearer-aws]'"; + + PyObject *marker; + PyObject *cb; + PyObject *method; + PyObject *cfg_str; + PyObject *ext_str; + PyObject *mod; + PyObject *func; + PyObject *callback; + const char *marker_c; + const char *method_c; + + /* 1. Marker present and equals "aws_iam"? */ + marker = PyDict_GetItemString(confdict, MARKER_KEY); + if (!marker || !PyUnicode_Check(marker)) { + return 0; + } + marker_c = PyUnicode_AsUTF8(marker); + if (!marker_c) { + /* Non-ASCII / malformed unicode — treat as not-our-value. */ + PyErr_Clear(); + return 0; + } + if (strcmp(marker_c, MARKER_VALUE) != 0) { + return 0; + } + + /* 2. Explicit oauth_cb wins. Skip ahead to strip+sentinels — librdkafka + * still needs the sentinels because we strip the marker (so its + * AWS_IAM bypass paths don't fire). The user's oauth_cb takes over + * the actual refresh path; the sentinels are just there to keep + * librdkafka's config-finalize happy. */ + cb = PyDict_GetItemString(confdict, OAUTH_CB_KEY); + if (cb && cb != Py_None) { + goto strip_and_sentinels; + } + + /* 3. Require sasl.oauthbearer.method = "oidc". */ + method = PyDict_GetItemString(confdict, METHOD_KEY); + if (!method || !PyUnicode_Check(method)) { + PyErr_SetString(PyExc_ValueError, METHOD_REQUIREMENT_ERR); + return -1; + } + method_c = PyUnicode_AsUTF8(method); + if (!method_c || strcmp(method_c, METHOD_OIDC_VALUE) != 0) { + if (!method_c) { + PyErr_Clear(); + } + PyErr_SetString(PyExc_ValueError, METHOD_REQUIREMENT_ERR); + return -1; + } + + /* 4. Require non-empty sasl.oauthbearer.config. */ + cfg_str = PyDict_GetItemString(confdict, CONFIG_KEY); + if (!cfg_str || !PyUnicode_Check(cfg_str) || + PyUnicode_GET_LENGTH(cfg_str) == 0) { + PyErr_SetString(PyExc_ValueError, CONFIG_REQUIREMENT_ERR); + return -1; + } + + /* 5. sasl.oauthbearer.extensions is optional. */ + ext_str = PyDict_GetItemString(confdict, EXTENSIONS_KEY); + + /* 6. Lazy import; rewrite ImportError friendly. */ + mod = PyImport_ImportModule(AUTOWIRE_MODULE); + if (!mod) { + PyObject *cause_type = NULL; + PyObject *cause_value = NULL; + PyObject *cause_tb = NULL; + PyObject *new_exc; + + PyErr_Fetch(&cause_type, &cause_value, &cause_tb); + PyErr_NormalizeException(&cause_type, &cause_value, &cause_tb); + + new_exc = PyObject_CallFunction( + PyExc_ImportError, "s", FRIENDLY_IMPORT_ERR); + if (new_exc) { + if (cause_value) { + Py_INCREF(cause_value); + PyException_SetCause(new_exc, cause_value); + } + PyErr_SetObject(PyExc_ImportError, new_exc); + Py_DECREF(new_exc); + } + Py_XDECREF(cause_type); + Py_XDECREF(cause_value); + Py_XDECREF(cause_tb); + return -1; + } + + /* 7. create_handler(cfg_str, ext_str or None) -> Callable */ + func = PyObject_GetAttrString(mod, CREATE_HANDLER); + Py_DECREF(mod); + if (!func) { + return -1; + } + callback = PyObject_CallFunction( + func, "OO", cfg_str, ext_str ? ext_str : Py_None); + Py_DECREF(func); + if (!callback) { + return -1; + } + if (PyDict_SetItemString(confdict, OAUTH_CB_KEY, callback) == -1) { + Py_DECREF(callback); + return -1; + } + Py_DECREF(callback); + +strip_and_sentinels: + /* 8. Strip the marker. UNCONDITIONAL — see TODO at top of function. + * Reached from both the autowire-fires path AND the + * explicit-oauth_cb-wins path (goto from step 2). */ + if (PyDict_DelItemString(confdict, MARKER_KEY) == -1) { + return -1; + } + + /* 9. Auto-set 3 sentinel OIDC fields that librdkafka demands when + * method=oidc. Required because we stripped the marker in step 8 + * (so librdkafka's AWS_IAM bypass paths never fire). The sentinels + * are never used at runtime — the oauth_cb (autowired or user- + * supplied) takes over the token fetch entirely. See TODO at top + * of function for full reasoning. User-supplied values are + * respected (no clobber). */ + { + static const struct { + const char *key; + const char *value; + } SENTINELS[] = { + {"sasl.oauthbearer.token.endpoint.url", + "https://aws-iam-autowire.invalid/"}, + {"sasl.oauthbearer.client.id", "aws-iam-autowire"}, + {"sasl.oauthbearer.client.secret", "aws-iam-autowire"}, + }; + size_t i; + for (i = 0; + i < sizeof(SENTINELS) / sizeof(SENTINELS[0]); + i++) { + PyObject *existing; + PyObject *val; + + existing = PyDict_GetItemString(confdict, + SENTINELS[i].key); + if (existing) { + /* User already supplied a value — respect it. */ + continue; + } + val = PyUnicode_FromString(SENTINELS[i].value); + if (!val) { + return -1; + } + if (PyDict_SetItemString(confdict, SENTINELS[i].key, + val) == -1) { + Py_DECREF(val); + return -1; + } + Py_DECREF(val); + } + } + + return 0; +} + + rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, Handle *h, PyObject *args, @@ -2703,6 +2978,15 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, PyDict_DelItemString(confdict, "default.topic.config"); } + /* AWS IAM OAUTHBEARER autowire: when the user sets the marker + * sasl.oauthbearer.metadata.authentication.type=aws_iam, replace it + * with an oauth_cb sourced from the optional oauthbearer-aws extra. + * No-op when the marker is absent. See resolve_aws_oauthbearer_marker + * above for the full flow and the unconditional-strip TODO. */ + if (resolve_aws_oauthbearer_marker(confdict) == -1) { + goto outer_err; + } + /* Convert config dict to config key-value pairs. */ while (PyDict_Next(confdict, &pos, &ko, &vo)) { PyObject *ks; diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py index bbf6fd705..1b6cc8ccb 100644 --- a/tests/oauthbearer/aws/test_aws_iam_marker.py +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -40,3 +40,35 @@ def test_marker_value_is_locked_value(): Same constraint as :func:`test_marker_key_is_locked_value`. """ assert AWS_IAM_MARKER_VALUE == "aws_iam" + + +# ---- End-to-end drift guard (active from Phase 5 onward) ---- + + +def test_c_dispatcher_recognises_python_authoritative_marker(): + """Drift-guard end-to-end check. + + If the C-side literal strings in confluent_kafka.c drift away from the + Python constants in _aws_iam_marker.py, the dispatcher would no longer + recognize a Producer built with these constants, the precondition check + would not fire, and the ValueError below would NOT raise. The test + therefore fails loudly on drift. + + We deliberately omit `sasl.oauthbearer.method=oidc` so the precondition + check fires — that's the cheapest reliable signal that the dispatcher + saw our marker. + """ + import pytest + from confluent_kafka import Producer + + with pytest.raises(ValueError, match="method=oidc"): + Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + AWS_IAM_MARKER_KEY: AWS_IAM_MARKER_VALUE, + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + # Deliberately omitting sasl.oauthbearer.method to trigger the + # dispatcher's precondition check. If the dispatcher's C-side + # marker literals differ from AWS_IAM_MARKER_KEY/VALUE, the + # dispatcher won't fire and this ValueError won't raise. + }) diff --git a/tests/oauthbearer/aws/test_dispatch.py b/tests/oauthbearer/aws/test_dispatch.py new file mode 100644 index 000000000..143379f52 --- /dev/null +++ b/tests/oauthbearer/aws/test_dispatch.py @@ -0,0 +1,422 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the C-extension dispatcher resolve_aws_oauthbearer_marker(). + +The dispatcher fires inside common_conf_setup() (src/confluent_kafka.c) on +every Producer / Consumer / AdminClient construction (and transitively +AIOProducer / AIOConsumer, which wrap the sync clients). + +These tests exercise the C dispatcher indirectly via client construction: +no direct path to the helper exists from Python — that's the design. +""" + +import base64 +import datetime +import importlib +import sys +from typing import Any, Dict, Optional +from unittest.mock import MagicMock, patch + +import pytest + +pytest.importorskip("boto3") + +import confluent_kafka # noqa: E402 +from confluent_kafka.admin import AdminClient # noqa: E402 + + +# ---- Test helpers ---- + + +def _base64url(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" + + +def _canned_response() -> Dict[str, Any]: + return { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime( + 2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc, + ), + } + + +@pytest.fixture +def mocked_boto3(): + """Patch boto3.Session so the autowired provider's STS client is a mock. + The actual STS call doesn't happen at client construction (lazy), but the + Session/client construction does — without this, boto3 would try to + resolve real credentials.""" + with patch("boto3.Session") as session_cls: + session = session_cls.return_value + client = MagicMock() + client.get_web_identity_token.return_value = _canned_response() + session.client.return_value = client + yield client + + +def _minimal_aws_iam_config(extra: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + cfg = { + "bootstrap.servers": "broker.invalid:9092", + # security.protocol=SASL_SSL is required for librdkafka to actually + # engage the OAUTHBEARER refresh path. Without it, sasl.mechanisms is + # inert (librdkafka logs CONFWARN and the refresh_cb never fires). + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + if extra: + cfg.update(extra) + return cfg + + +# ---- 1. Marker absent → dispatcher is a no-op (regression check) ---- + + +def test_marker_absent_producer_constructs_unchanged(): + p = confluent_kafka.Producer({"bootstrap.servers": "localhost:9092"}) + p.flush(timeout=0.1) + + +def test_marker_absent_consumer_constructs_unchanged(): + c = confluent_kafka.Consumer({ + "bootstrap.servers": "localhost:9092", + "group.id": "test-group", + }) + c.close() + + +def test_marker_absent_admin_client_constructs_unchanged(): + a = AdminClient({"bootstrap.servers": "localhost:9092"}) + assert a is not None + + +def test_marker_absent_does_not_import_aws_modules(): + """If the marker is absent the dispatcher must NOT touch the optional + subpackage — boto3 import must not happen.""" + # Clear any prior imports of the AWS subpackage so we can detect a fresh import. + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + del sys.modules[name] + confluent_kafka.Producer({"bootstrap.servers": "localhost:9092"}) + for name in [ + "confluent_kafka.oauthbearer.aws.aws_autowire", + "confluent_kafka.oauthbearer.aws._aws_sts_token_provider", + ]: + assert name not in sys.modules, ( + f"Dispatcher imported {name} when no marker was set — " + "this means the no-op short-circuit broke." + ) + + +# ---- 2. Other marker values pass through verbatim ---- + + +def test_other_marker_value_passes_through_unchanged(): + """azure_imds / unknown values are not our concern — librdkafka handles them.""" + # azure_imds is a librdkafka-recognised value; we expect Producer construction + # to NOT raise our ValueError. (librdkafka itself may complain about the + # value, but our dispatcher must not.) + with pytest.raises(Exception) as exc_info: + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "azure_imds", + }) + # Whatever librdkafka raises, it must NOT be our method-requirement + # error or our config-requirement error. + msg = str(exc_info.value) + assert "AWS IAM" not in msg + assert "aws_iam" not in msg + + +# ---- 3. method=oidc requirement ---- + + +def test_marker_without_method_raises(): + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + + +def test_marker_with_method_default_raises(): + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "default", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + + +def test_marker_with_method_oidc_uppercase_raises(): + """Strict matching — librdkafka's method values are lowercase canonical.""" + with pytest.raises(ValueError, match="method=oidc"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "OIDC", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + + +# ---- 4. sasl.oauthbearer.config requirement ---- + + +def test_marker_with_method_oidc_but_missing_config_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + }) + + +def test_marker_with_method_oidc_but_empty_config_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "", + }) + + +# ---- 5. Happy path: marker + method=oidc + valid config + boto3 stubbed ---- + + +def test_happy_path_producer_constructs_with_marker(mocked_boto3): + p = confluent_kafka.Producer(_minimal_aws_iam_config()) + assert p is not None + p.flush(timeout=0.1) + + +def test_happy_path_consumer_constructs_with_marker(mocked_boto3): + c = confluent_kafka.Consumer( + _minimal_aws_iam_config({"group.id": "test-group"}) + ) + assert c is not None + c.close() + + +def test_happy_path_admin_client_constructs_with_marker(mocked_boto3): + a = AdminClient(_minimal_aws_iam_config()) + assert a is not None + + +async def test_happy_path_aio_producer_constructs_with_marker(mocked_boto3): + """AIO wraps sync — same dispatcher fires via the inner cimpl.Producer. + + AIOProducer.__init__ calls asyncio.get_running_loop(), so this test + must run inside an async context (pyproject.toml's asyncio_mode=auto + handles that for async def test functions).""" + from confluent_kafka.aio import AIOProducer + p = AIOProducer(_minimal_aws_iam_config()) + assert p is not None + + +async def test_happy_path_aio_consumer_constructs_with_marker(mocked_boto3): + """Same async-context requirement as AIOProducer.""" + from confluent_kafka.aio import AIOConsumer + c = AIOConsumer(_minimal_aws_iam_config({"group.id": "test-group"})) + assert c is not None + + +# ---- 6. Explicit oauth_cb wins (precedence rule) ---- + + +def test_explicit_oauth_cb_wins_over_marker(): + """When user supplies their own oauth_cb, our dispatcher strips the marker + and yields. boto3 is NOT touched.""" + sentinel_called = [] + + def user_oauth_cb(config_str): + sentinel_called.append(config_str) + return ("user-token", 9999999999.0, "user-principal", {}) + + # Clear sys.modules so we can verify aws_autowire was NOT imported. + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + del sys.modules[name] + + p = confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + "oauth_cb": user_oauth_cb, + }) + assert p is not None + p.flush(timeout=0.1) + # The autowire module must not have been imported. + assert "confluent_kafka.oauthbearer.aws.aws_autowire" not in sys.modules + + +# ---- 7. Parser errors surface from the autowire path ---- + + +def test_marker_with_invalid_config_grammar_raises(mocked_boto3): + """Config parser ValueError surfaces through the dispatcher.""" + with pytest.raises(ValueError, match="Unknown key.*not_a_key"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a not_a_key=foo", + }) + + +def test_marker_with_invalid_signing_algorithm_raises(mocked_boto3): + with pytest.raises(ValueError, match="signing_algorithm"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a signing_algorithm=HS256", + }) + + +def test_marker_with_invalid_extensions_grammar_raises(mocked_boto3): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + "sasl.oauthbearer.extensions": "malformed-no-equals", + }) + + +# ---- 8. Friendly ImportError when the optional extra is missing ---- + + +@pytest.fixture +def boto3_absent(monkeypatch): + """Simulates an opt-out environment: boto3 import fails + relevant + aws.* submodules cleared from sys.modules so the C dispatcher's + PyImport_ImportModule re-executes the module body and hits the + boto3=None gate.""" + # Force any subsequent `import boto3` to raise ImportError. + monkeypatch.setitem(sys.modules, "boto3", None) + # Clear the cached aws submodules so PyImport_ImportModule re-executes + # their top-level statements (including `import boto3`). + for name in list(sys.modules): + if name.startswith("confluent_kafka.oauthbearer.aws"): + monkeypatch.delitem(sys.modules, name, raising=False) + yield + # monkeypatch reverts on teardown. + + +def test_marker_with_missing_extra_raises_friendly_import_error(boto3_absent): + """When boto3 isn't available, the dispatcher catches the + ModuleNotFoundError from the import chain and rewrites it into a + friendly install hint. __cause__ chain preserves the original.""" + with pytest.raises(ImportError) as exc_info: + confluent_kafka.Producer({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + msg = str(exc_info.value) + assert "oauthbearer-aws" in msg + assert "pip install" in msg + assert "aws_iam" in msg + # __cause__ preserves the original failure for diagnostic tools. + assert exc_info.value.__cause__ is not None + + +def test_friendly_import_error_on_consumer_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + confluent_kafka.Consumer({ + "bootstrap.servers": "broker.invalid:9092", + "group.id": "g", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + + +def test_friendly_import_error_on_admin_client_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + AdminClient({ + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + }) + + +# ---- 9. Marker is stripped before native handoff ---- + + +def test_marker_stripped_after_autowire(mocked_boto3): + """The C dispatcher must strip the marker from confdict before the + config-iteration loop sees it. Today's bundled librdkafka doesn't know + the 'aws_iam' enum value and would reject it at rd_kafka_conf_set time. + If the strip stops working, this test fails with librdkafka's + 'invalid value' error rather than Producer constructing cleanly.""" + p = confluent_kafka.Producer(_minimal_aws_iam_config()) + # If we got here, the strip succeeded — librdkafka never saw the marker. + assert p is not None + p.flush(timeout=0.1) + + +# ---- 10. Extensions plumbing through the dispatcher ---- + + +def test_marker_with_extensions_plumbs_through(mocked_boto3): + """The dispatcher reads sasl.oauthbearer.extensions and passes it as the + second arg to create_handler. Verify by checking the returned callable + yields the extensions in its 4-tuple result.""" + p = confluent_kafka.Producer(_minimal_aws_iam_config({ + "sasl.oauthbearer.extensions": "logicalCluster=lkc-123", + })) + assert p is not None + p.flush(timeout=0.1) + + +def test_marker_with_empty_extensions_treated_as_absent(mocked_boto3): + """Empty extensions string → autowire treats as None (no extensions).""" + p = confluent_kafka.Producer(_minimal_aws_iam_config({ + "sasl.oauthbearer.extensions": "", + })) + assert p is not None + p.flush(timeout=0.1) From f6558def5bd871d4059e52f73251bc9dbc80eb8d Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 27 May 2026 19:52:21 +0530 Subject: [PATCH 06/21] real-AWS integration test for the autowire path --- tests/oauthbearer/aws/test_real_sts.py | 409 +++++++++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 tests/oauthbearer/aws/test_real_sts.py diff --git a/tests/oauthbearer/aws/test_real_sts.py b/tests/oauthbearer/aws/test_real_sts.py new file mode 100644 index 000000000..bc0bde643 --- /dev/null +++ b/tests/oauthbearer/aws/test_real_sts.py @@ -0,0 +1,409 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Phase 6 — real-AWS integration tests for the autowire path. + +Gated on ``RUN_AWS_STS_REAL=1``. Default ``pytest tests/oauthbearer/aws`` +runs every other Phase 1-5 case but skips this file's tests so CI doesn't +hit the AWS API. + +Why this file lives under ``tests/oauthbearer/aws/`` rather than +``tests/integration/``: ``tests/integration/conftest.py`` eagerly imports +``trivup.clusters.KafkaCluster``, which is a heavyweight broker-only +dependency. The STS provider has no broker dependency — the env-var gate +is sufficient isolation. Same gate-placement call as the 29-April POC. + +Cross-language invariant: on the shared EC2 test box +(``ktrue-iam-sts-test-role`` in ``eu-north-1``, account ``708975691912``, +audience ``https://api.example.com``) Go / .NET / JS / librdkafka all +mint a **1256-byte JWT** with principal +``arn:aws:iam::708975691912:role/ktrue-iam-sts-test-role``. The Python +port preserves that invariant. + +Run instructions:: + + # On EC2 with role attached, audience-trust-policy enabled: + export RUN_AWS_STS_REAL=1 + export AWS_REGION=eu-north-1 + export AWS_STS_TEST_AUDIENCE=https://api.example.com + /tmp/ckp-optin/bin/pytest -v -s tests/oauthbearer/aws/test_real_sts.py + +The ``-s`` flag preserves the diagnostic ``print(...)`` lines that capture +the JWT length, principal, and timing — important for cross-language +parity evidence in PR descriptions. +""" + +import base64 +import json +import os +import re +import time + +import pytest + +# boto3 is loaded transitively via the autowire path. Skip the whole file +# in opt-out venvs so collection doesn't blow up. +pytest.importorskip("boto3") + +# Gate the entire module behind RUN_AWS_STS_REAL=1 so CI doesn't accidentally +# hit STS. Individual @pytest.mark.skipif decorators would work too but the +# module-level pytestmark is cleaner and less error-prone. +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_AWS_STS_REAL") != "1", + reason="Set RUN_AWS_STS_REAL=1 and provide AWS credentials to run.", +) + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 + + +# ============================================================================= +# Configuration sourced from env vars so the test can be re-pointed at a +# different role / audience without code changes. +# ============================================================================= + +_REGION = os.environ.get("AWS_REGION", "eu-north-1") +_AUDIENCE = os.environ.get("AWS_STS_TEST_AUDIENCE", "https://api.example.com") +_DURATION = os.environ.get("DURATION_SECONDS", "300") + +# Cross-language invariant: 1256 bytes on the shared test role + audience. +# Allow a tolerance band for variation across audiences (URL length affects +# payload size). Catches order-of-magnitude bugs without flaking on tiny +# audience-string changes. +_JWT_LENGTH_MIN = int(os.environ.get("JWT_LENGTH_MIN", "1100")) +_JWT_LENGTH_MAX = int(os.environ.get("JWT_LENGTH_MAX", "1500")) + +_JWT_PATTERN = re.compile(r"^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+$") +_ARN_PATTERN = re.compile(r"^arn:aws:(?:iam|sts)::\d+:[^/]+/.+$") + + +def _config_string(**parts) -> str: + """Render a sasl.oauthbearer.config wire string from key=value parts.""" + return " ".join(f"{k}={v}" for k, v in parts.items()) + + +def _default_config(**extra) -> str: + """Build a sasl.oauthbearer.config string with the defaults plus extras.""" + parts = {"region": _REGION, "audience": _AUDIENCE, "duration_seconds": _DURATION} + parts.update(extra) + return _config_string(**parts) + + +def _decode_jwt_payload(jwt: str) -> dict: + """Decode the JWT payload (middle segment) into a Python dict. + + Helper for tests that need to inspect specific claims (tags, audience, + algorithm) end-to-end. + """ + payload_b64url = jwt.split(".")[1] + # base64url → base64 with padding + padding = "=" * (-len(payload_b64url) % 4) + standard_b64 = (payload_b64url + padding).replace("-", "+").replace("_", "/") + return json.loads(base64.b64decode(standard_b64).decode("utf-8")) + + +# ============================================================================= +# create_handler() — the cross-module entry-point contract +# ============================================================================= + + +def test_create_handler_mints_valid_jwt(): + handler = create_handler(_default_config(), None) + token, expiry, principal, extensions = handler("") + + # 3-segment base64url JWT. + assert _JWT_PATTERN.match(token), f"not a 3-segment JWT: {token[:60]}..." + + # Length within cross-language tolerance. + assert _JWT_LENGTH_MIN <= len(token) <= _JWT_LENGTH_MAX, ( + f"JWT length {len(token)} outside [{_JWT_LENGTH_MIN}, " + f"{_JWT_LENGTH_MAX}] (Go/.NET/JS/librdkafka observe 1256 on the shared " + f"test box). Set JWT_LENGTH_MIN/MAX env vars if running on a different " + f"role/audience." + ) + + # Expiry within the requested window. + now = time.time() + assert expiry > now, "expiry must be in the future" + assert expiry < now + 10 * 60, ( + "expiry must be within 10 min for DurationSeconds <= 300" + ) + + # Principal is a bare role ARN (AWS STS GetWebIdentityToken convention). + assert _ARN_PATTERN.match(principal), f"unexpected principal: {principal!r}" + + # Default config has no extensions. + assert extensions == {} + + # Diagnostic for the run log so EC2 evidence is captured in CI / PR. + print( + f"\n[autowire-real] jwt_length={len(token)} " + f"principal={principal} expires_in={int(expiry - now)}s" + ) + + +def test_create_handler_jwt_length_matches_cross_language(): + """Cross-language byte-exact JWT length parity check. + + Different roles produce different JWT lengths because AWS STS bakes + role-attached ``principal_tags`` into the payload (e.g. provisioning- + system tags like ``divvy_owner`` add ~70 bytes). The 1256-byte length + seen on Go/.NET/JS/librdkafka in earlier validation was specific to + one tag-less role. + + This test only runs when ``EXPECTED_TOKEN_LEN`` is explicitly set — + its purpose is to confirm byte-for-byte parity against a baseline you've + measured by running any sibling client (Go / .NET / JS) against the + same role + audience. + """ + expected_str = os.environ.get("EXPECTED_TOKEN_LEN") + if not expected_str: + pytest.skip( + "EXPECTED_TOKEN_LEN not set — cross-language byte-exact parity " + "check only runs when you supply a known-good baseline. Run a " + "sibling client (Go/.NET/JS) against the same role+audience, " + "note the JWT length, and set EXPECTED_TOKEN_LEN to that value." + ) + + expected = int(expected_str) + handler = create_handler(_default_config(), None) + token, *_ = handler("") + assert len(token) == expected, ( + f"expected {expected} bytes (cross-language match), got {len(token)}" + ) + + +def test_create_handler_principal_matches_role_arn(): + """JWT 'sub' claim is the bare role ARN. AWS STS GetWebIdentityToken + convention (no session prefix on bare role tokens).""" + handler = create_handler(_default_config(), None) + _, _, principal, _ = handler("") + assert principal.startswith("arn:aws:iam::"), f"not an IAM role ARN: {principal!r}" + assert ":role/" in principal, f"not a role ARN: {principal!r}" + + +def test_create_handler_repeats_mint_distinct_tokens(): + """Provider does not cache; STS mints a fresh JWT each invocation.""" + handler = create_handler(_default_config(), None) + _, exp1, _, _ = handler("") + time.sleep(1.0) + _, exp2, _, _ = handler("") + assert exp2 > exp1, "second token's expiry must be later than the first" + + +def test_create_handler_honours_duration_seconds(): + """Custom duration_seconds=60 produces an expiry ~1 min out, not the 300s + default. + + The shared test role's IAM policy caps DurationSeconds at 300 (per .NET + PR validation §2g and the JS hybrid plan). So we go *below* the default + rather than above to prove the parameter is honored end-to-end. Same + value JS uses for the same reason. + """ + handler = create_handler(_default_config(duration_seconds="60"), None) + _, expiry, _, _ = handler("") + seconds_remaining = expiry - time.time() + assert 50 <= seconds_remaining <= 80, ( + f"duration_seconds=60 → expected ~60s window, got {seconds_remaining:.0f}s" + ) + + +def test_create_handler_honours_signing_algorithm_rs256(): + """signing_algorithm=RS256 produces a JWT with alg=RS256 in the header. + + Cross-language parity check: the .NET test asserts the same header.alg + round-trip. Catches any SDK-side algorithm-shaping surprises. + """ + handler = create_handler(_default_config(signing_algorithm="RS256"), None) + token, *_ = handler("") + # Decode the JWT header (first segment). + header_b64url = token.split(".")[0] + padding = "=" * (-len(header_b64url) % 4) + header_json = json.loads( + base64.b64decode( + (header_b64url + padding).replace("-", "+").replace("_", "/") + ).decode("utf-8") + ) + assert header_json.get("alg") == "RS256", ( + f"header.alg expected RS256, got {header_json.get('alg')!r}" + ) + + +def test_create_handler_honours_principal_name_override(): + """Setting principal_name=... in the wire grammar overrides the JWT 'sub' + extraction at the autowire layer.""" + handler = create_handler( + _default_config(principal_name="custom-principal"), None, + ) + _, _, principal, _ = handler("") + assert principal == "custom-principal", ( + f"principal_name override not honoured: got {principal!r}" + ) + + +def test_create_handler_round_trips_sasl_extensions(): + """The typed sasl.oauthbearer.extensions property is parsed separately + (comma-separated) and surfaces in the returned extensions dict. + + Note: extensions are forwarded to the broker — not to STS — so they + don't appear in the JWT. This test asserts the dispatcher-to-tuple + round-trip, not the JWT payload.""" + handler = create_handler( + _default_config(), + "logicalCluster=lkc-abc,identityPoolId=pool-xyz", + ) + _, _, _, extensions = handler("") + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-xyz", + } + + +def test_create_handler_tag_claims_flow_to_sts(): + """tag_= in the wire grammar flows into the STS Tags + parameter without breaking the request. + + The primary evidence is **token minting succeeding** — if our Tags + parameter caused STS to reject the request (e.g. malformed shape or + over the 50-tag cap), we'd see a ClientError exception. Token returned + cleanly means dispatcher → provider → boto3 → AWS STS handoff worked. + + Whether AWS STS surfaces our custom tags in the JWT payload's + ``principal_tags`` claim is a SEPARATE question depending on the role's + IAM trust policy — specifically whether ``sts:TagSession`` is allowed. + Without that permission, STS silently drops user tags from the JWT. + The diagnostic prints below show what STS actually embedded so users + can spot the case where tags are silently dropped. + + Set ``ASSERT_TAGS_IN_JWT=1`` to require that our tags appear in the + JWT — only meaningful on roles with ``sts:TagSession`` allowed. + """ + handler = create_handler( + _default_config(tag_team="platform", tag_environment="prod"), None, + ) + # If our Tags param breaks the STS call, this raises. + token, *_ = handler("") + payload = _decode_jwt_payload(token) + + # AWS STS surfaces tags either under "principal_tags" at the JWT root + # OR nested under "https://sts.amazonaws.com/" → "principal_tags". + found_tags = {} + if "principal_tags" in payload: + found_tags.update(payload["principal_tags"]) + for nested_value in payload.values(): + if isinstance(nested_value, dict) and "principal_tags" in nested_value: + found_tags.update(nested_value["principal_tags"]) + + our_tags_present = "team" in found_tags and "environment" in found_tags + print(f"\n[autowire-real] JWT payload keys: {list(payload.keys())}") + print(f"[autowire-real] tags surfaced in JWT: {found_tags}") + print( + f"[autowire-real] our injected tags appear in JWT: {our_tags_present}" + + ("" if our_tags_present else " (role likely missing sts:TagSession permission)") + ) + + # Token minted successfully — Tags param didn't cause STS rejection. + assert len(token) > 100, "Token should be a valid JWT" + + # Strict assertion only when the user explicitly opts in via env var. + if os.environ.get("ASSERT_TAGS_IN_JWT") == "1": + assert our_tags_present, ( + f"ASSERT_TAGS_IN_JWT=1 set but our tags aren't in the JWT. " + f"Role likely missing sts:TagSession permission. " + f"Found tags: {found_tags}" + ) + + +def test_create_handler_jwt_audience_matches_request(): + """The JWT payload's 'aud' claim matches the audience we asked AWS for. + + Defensive — catches any audience-shaping surprises (AWS should not munge + the audience string). Mirrors the .NET integration test's assertion. + """ + handler = create_handler(_default_config(), None) + token, *_ = handler("") + payload = _decode_jwt_payload(token) + assert payload.get("aud") == _AUDIENCE, ( + f"JWT aud {payload.get('aud')!r} doesn't match requested {_AUDIENCE!r}" + ) + + +# ============================================================================= +# Full dispatcher flow — Producer with marker +# ============================================================================= + + +def test_producer_with_marker_succeeds_against_real_sts(): + """End-to-end: construct a Producer with the marker set; the C dispatcher + in common_conf_setup() invokes create_handler() which builds a real + provider; librdkafka's background thread invokes the autowired callback; + STS mints a real token; rd_kafka_oauthbearer_set_token marks the token + as set; wait_for_oauth_token_set returns successfully and the Producer + constructor returns. + + The bootstrap.servers points at a non-resolvable host so we don't waste + a real broker connection — the OAUTHBEARER refresh path doesn't need + one. It fires immediately on background-thread startup. + """ + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + } + ) + elapsed = time.time() - t0 + + # Producer construction returns once the autowire callback has set a + # real token (typically <2s on EC2). wait_for_oauth_token_set timeout + # is 10s — anything over that means the autowire chain didn't fire. + assert elapsed < 5.0, ( + f"Producer construction took {elapsed:.1f}s — autowire path slow " + f"or broken (10s timeout = OAuth refresh never set the token)" + ) + print(f"\n[autowire-real] Producer constructed via dispatcher in {elapsed:.2f}s") + + +def test_producer_with_marker_and_extensions_succeeds(): + """Same end-to-end path with sasl.oauthbearer.extensions set on the + typed property. Proves the C dispatcher reads the extensions string + and passes it as the second arg to create_handler.""" + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + "sasl.oauthbearer.extensions": "logicalCluster=lkc-abc", + } + ) + elapsed = time.time() - t0 + assert elapsed < 5.0, ( + f"Producer with extensions took {elapsed:.1f}s — autowire path " + f"slow or broken" + ) + print( + f"\n[autowire-real] Producer (with extensions) constructed via " + f"dispatcher in {elapsed:.2f}s" + ) From 96c252346f8628ce8eca3d6d790752526be9f61c Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Thu, 4 Jun 2026 11:03:24 +0530 Subject: [PATCH 07/21] Fix comments --- .../oauthbearer/aws/_aws_iam_marker.py | 7 ++- .../aws/_aws_jwt_subject_extractor.py | 18 +----- .../aws/_aws_oauthbearer_config.py | 27 +-------- .../aws/_aws_sasl_extensions_parser.py | 11 +--- .../aws/_aws_sts_token_provider.py | 37 +----------- .../oauthbearer/aws/aws_autowire.py | 57 ++++++------------- tests/_util/test_kv_string_parser.py | 5 +- tests/oauthbearer/aws/test_aws_iam_marker.py | 31 +--------- .../aws/test_aws_jwt_subject_extractor.py | 12 ++-- .../aws/test_aws_oauthbearer_config.py | 6 +- .../aws/test_aws_sasl_extensions_parser.py | 5 +- .../aws/test_aws_sts_token_provider.py | 5 +- tests/oauthbearer/aws/test_contract.py | 7 --- 13 files changed, 39 insertions(+), 189 deletions(-) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py index 5bd95ac95..7a1cf69ad 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py @@ -14,7 +14,8 @@ """Marker constants identifying the AWS IAM OAUTHBEARER autowire path. -Mirrors .NET's ``Confluent.Kafka.Internal.OAuthBearer.Aws.AwsIamMarker``. +The config-key/value pair that activates the AWS OAUTHBEARER autowire path. + The C dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` keeps its own literal copies of these values for compile-time use; the drift-guard test in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the C-side @@ -27,8 +28,8 @@ __all__ = ["AWS_IAM_MARKER_KEY", "AWS_IAM_MARKER_VALUE"] -#: librdkafka config key that activates the AWS IAM autowire path when set -#: to :data:`AWS_IAM_MARKER_VALUE`. See the README for the full user contract. +#: Config key that activates the AWS IAM autowire path when set +#: to :data:`AWS_IAM_MARKER_VALUE`. AWS_IAM_MARKER_KEY: str = "sasl.oauthbearer.metadata.authentication.type" #: On-wire value of :data:`AWS_IAM_MARKER_KEY` that selects AWS IAM diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py index 7c268110b..6596aec00 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py @@ -14,16 +14,7 @@ """Internal: extracts the ``sub`` claim from an unverified AWS-minted JWT. -Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsJwtSubjectExtractor``. - -No signature verification — AWS STS already signed the JWT and the broker -performs the cryptographic validation. We only decode the unprotected -payload segment to read the ``sub`` claim (the role ARN), which becomes the -``principal`` field handed to ``rd_kafka_oauthbearer_set_token``. - -Strict base64 decoding (``validate=True``) is used so stray non-alphabet -characters raise instead of being silently dropped — ``urlsafe_b64decode`` -is lenient and would let malformed payloads slip past JSON parsing. +No signature verification — STS signs, broker validates. """ import base64 @@ -88,12 +79,6 @@ def extract_sub(jwt: str) -> str: def _decode_base64url_segment(segment: str) -> bytes: - """Base64url-decode a JWT segment to bytes. - - Restores '=' padding to the next 4-byte boundary, swaps '-' → '+' and - '_' → '/', then defers to :func:`base64.b64decode` with strict - ``validate=True``. - """ if len(segment) == 0: raise ValueError("JWT payload segment is empty.") @@ -106,7 +91,6 @@ def _decode_base64url_segment(segment: str) -> bytes: elif remainder == 3: s += "=" else: - # remainder == 1 → not a valid base64url length raise ValueError("JWT payload segment has invalid base64url length.") try: diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py index 79dbf50cd..ff33ac555 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -14,10 +14,6 @@ """Internal: validated ``sasl.oauthbearer.config`` dataclass + parser. -Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsOAuthBearerConfig``. -Encapsulates the wire-grammar parser and the validated, immutable typed view -of the AWS path's ``sasl.oauthbearer.config`` value. - The full grammar (whitespace-separated ``key=value`` pairs, no quoting): region= (required) @@ -57,17 +53,16 @@ ] -#: Top-level librdkafka config key carrying the AWS-path wire-grammar string. +#: Config key carrying the AWS-path wire-grammar string. CONFIG_KEY: str = "sasl.oauthbearer.config" -#: Default JWT signing algorithm. Matches .NET / cross-language default. +#: Default JWT signing algorithm. DEFAULT_SIGNING_ALGORITHM: str = "ES384" #: Signing algorithms accepted by AWS STS ``GetWebIdentityToken``. ALLOWED_SIGNING_ALGORITHMS = ("ES384", "RS256") #: Minimum / default / maximum token lifetime AWS STS allows -#: (60s and 3600s are AWS-enforced bounds; 300s is the .NET-aligned default). MIN_DURATION_SECONDS: int = 60 MAX_DURATION_SECONDS: int = 3600 DEFAULT_DURATION_SECONDS: int = 300 @@ -83,12 +78,6 @@ AWS_DEBUG_CONSOLE: str = "console" #: ``aws_debug`` values accepted by the Python client. -#: -#: Python takes the **Pythonic subset** decision (locked 2026-05-26): only -#: ``none`` (no-op) and ``console`` (routes botocore logs to stderr via -#: :func:`boto3.set_stream_logger`). The .NET-only sinks ``log4net`` and -#: ``systemdiagnostics`` are rejected with a "not supported on this platform" -#: error so users moving between language clients get a clear signal. ALLOWED_AWS_DEBUG_VALUES = (AWS_DEBUG_NONE, AWS_DEBUG_CONSOLE) #: ``aws_debug`` values that exist in .NET but are not supported in Python. @@ -107,10 +96,6 @@ "aws_debug", }) -# Keys whose value must NOT be the empty string (``region= audience=...``). -# Empty values are tolerated for ``tag_`` (mirrors AWS allowing empty -# tag values) and for ``duration_seconds`` (which separately fails integer -# parsing with a clearer error). _NON_EMPTY_KEYS = frozenset({ "region", "audience", @@ -123,13 +108,7 @@ @dataclass(frozen=True) class AwsOAuthBearerConfig: - """Validated, immutable view of the AWS path's ``sasl.oauthbearer.config``. - - Construct via :meth:`parse` rather than directly — the classmethod is the - only path that exercises the wire-grammar parser. The dataclass's - :meth:`__post_init__` validates final state regardless of construction - path, so direct construction with bad values still raises. - """ + """Immutable view of the AWS path's ``sasl.oauthbearer.config``.""" region: str audience: str diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py index fc8514f5a..05d133e2e 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sasl_extensions_parser.py @@ -14,11 +14,8 @@ """Internal: parser for the ``sasl.oauthbearer.extensions`` config property. -Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsSaslExtensionsParser``. - The ``sasl.oauthbearer.extensions`` config carries RFC 7628 §3.1 SASL -extensions as a comma-separated ``key=value`` list. Forwarded verbatim to -the broker alongside the JWT — not part of the JWT itself. +extensions as a comma-separated ``key=value`` list. """ from typing import Dict, Optional @@ -28,16 +25,14 @@ __all__ = ["CONFIG_KEY", "parse"] -#: Top-level librdkafka config key carrying the SASL extensions list. +#: Config key carrying the SASL extensions list. CONFIG_KEY: str = "sasl.oauthbearer.extensions" def parse(raw: Optional[str]) -> Optional[Dict[str, str]]: """Parse the verbatim ``sasl.oauthbearer.extensions`` value into a dict. - Grammar: comma-separated ``key=value`` tokens. Whitespace around each - token is trimmed (mirrors .NET / librdkafka). Empty tokens (e.g. - ``"a=1,,b=2,"``) are tolerated and skipped. + The grammar mirrors the cross-language convention for consistency. Returns ``None`` for ``None`` / empty input so the autowire layer can short-circuit without constructing an empty dict. diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py index fac6715ac..a2f72c59e 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -12,29 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Internal: AWS STS ``GetWebIdentityToken``-based OAUTHBEARER token provider. - -Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.Internal.AwsStsTokenProvider``. - -This is the file whose top-level ``import boto3`` is the actual gate that -makes the entire ``oauthbearer-aws`` extra opt-in: opt-out users have no -``boto3`` installed, so transitively importing this module from -:mod:`...aws_autowire` raises ``ModuleNotFoundError`` and the C dispatcher -rewrites it as a friendly install hint at client construction time. - -The class :class:`AwsStsTokenProvider` is constructed once per autowired -client by :func:`...aws_autowire.create_handler` and its bound ``token`` -method is installed as the ``oauth_cb`` Python callable. librdkafka invokes -this method from its background thread on every token refresh; the return -4-tuple matches the C extension's ``oauth_cb`` contract (see -``confluent_kafka.c`` around ``L2291``: ``PyArg_ParseTuple(result, "sd|sO!", -...)``). - -Credential resolution is **lazy**: ``__init__`` constructs the boto3 -``Session`` and STS client objects without making any HTTP calls. The first -``token()`` invocation triggers boto3's default credential chain (env → -shared config → IMDS → ECS → IRSA → SSO). -""" +"""Internal: Fetches OAUTHBEARER tokens via AWS STS GetWebIdentityToken.""" import logging from typing import Any, Dict, Optional, Tuple @@ -57,18 +35,7 @@ class AwsStsTokenProvider: - """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``. - - The instance method :meth:`token` is shaped to slot directly into the - ``oauth_cb`` C contract (4-tuple of ``(token, expiry_seconds, principal, - extensions_dict)``). - - Construction is lightweight and side-effect-light: the boto3 STS client - is built eagerly but no network call is made until :meth:`token` runs. - The ``aws_debug=console`` side-effect (process-wide - :func:`boto3.set_stream_logger` configuration) does fire at construction - when configured. - """ + """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``.""" def __init__( self, diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py index 0ee5ad3fe..d84c27cb9 100644 --- a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -14,15 +14,24 @@ """Public entry-point for AWS IAM OAUTHBEARER autowire. -Mirrors .NET's ``Confluent.Kafka.OAuthBearer.Aws.AwsAutoWire``. This is the -**only publicly importable name** in the optional subpackage. The C -dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` (Phase 5) reaches -this module via:: +This is the **only publicly importable name** in the optional subpackage. +End users do not call :func:`create_handler` directly — the C dispatcher in +``src/confluent_kafka/src/confluent_kafka.c`` reaches it via:: PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire") -and resolves :func:`create_handler` by name. The function signature is a -**frozen cross-module contract**: +and resolves :func:`create_handler` by name. The marker-key check is +performed in core; :func:`create_handler` is invoked only when the C +dispatcher has decided to autowire the AWS path. + +User-facing contract — four config keys:: + + "sasl.oauthbearer.method": "oidc" + "sasl.oauthbearer.metadata.authentication.type": "aws_iam" + "sasl.oauthbearer.config": "region=... audience=..." + "sasl.oauthbearer.extensions": "key=val,..." # optional + +Frozen cross-module contract of :func:`create_handler`: * arity: 2 positional parameters * names: ``sasl_oauthbearer_config``, ``sasl_oauthbearer_extensions`` @@ -30,13 +39,8 @@ * return: :data:`OAuthBearerCallback` Bumping any of these is a breaking change requiring a major version -increment on the ``confluent-kafka`` distribution. The frozen contract is -test-guarded by ``tests/oauthbearer/aws/test_contract.py``. - -The marker key/value check is performed in core (the C dispatcher); -:func:`create_handler` is invoked only when the caller has already decided -to autowire the AWS path. The function therefore unconditionally attempts -to build a handler and raises on any input it cannot parse. +increment on the ``confluent-kafka`` distribution. Test-guarded by +``tests/oauthbearer/aws/test_contract.py``. """ from typing import Callable, Dict, Optional, Tuple @@ -48,41 +52,14 @@ __all__ = ["create_handler", "OAuthBearerCallback"] - -#: Type alias for the callable returned by :func:`create_handler`. -#: -#: Tuple shape matches the existing ``oauth_cb`` contract enforced in C at -#: ``confluent_kafka.c`` around L2291 via -#: ``PyArg_ParseTuple(result, "sd|sO!", ...)`` — single ``str`` argument -#: (the ``sasl.oauthbearer.config`` value librdkafka passes back on every -#: refresh), returning ``(token_str, expiry_epoch_seconds, principal_str, -#: extensions_dict)``. OAuthBearerCallback = Callable[[str], Tuple[str, float, str, Dict[str, str]]] - def create_handler( sasl_oauthbearer_config: str, sasl_oauthbearer_extensions: Optional[str], ) -> OAuthBearerCallback: """Build an OAUTHBEARER refresh callback from the two OAUTHBEARER config strings. - Construction-time work: - - 1. Validates ``sasl_oauthbearer_config`` is a non-empty string. - 2. Parses ``sasl_oauthbearer_extensions`` (comma-separated ``key=value``) - via :mod:`._aws_sasl_extensions_parser` into an optional dict. - 3. Parses ``sasl_oauthbearer_config`` (whitespace-separated ``key=value``) - into a validated :class:`._aws_oauthbearer_config.AwsOAuthBearerConfig`. - 4. Constructs an :class:`._aws_sts_token_provider.AwsStsTokenProvider` - (no HTTP yet — credential resolution is lazy until the first - ``token()`` invocation). - 5. Returns the provider's bound :meth:`token` method as the callable. - - The returned callable is invoked by the C ``oauth_cb`` wrapper on every - OAUTHBEARER refresh; each call performs one STS ``GetWebIdentityToken`` - round-trip and returns a fresh 4-tuple suitable for - ``rd_kafka_oauthbearer_set_token``. - :param sasl_oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` value (whitespace-separated ``key=value`` pairs). Must be non-empty. :param sasl_oauthbearer_extensions: The verbatim diff --git a/tests/_util/test_kv_string_parser.py b/tests/_util/test_kv_string_parser.py index f11d99c84..984139f2a 100644 --- a/tests/_util/test_kv_string_parser.py +++ b/tests/_util/test_kv_string_parser.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for confluent_kafka._util.kv_string_parser. - -Mirrors .NET's KvStringParserTests. -""" +"""Tests for confluent_kafka._util.kv_string_parser.""" import pytest diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py index 1b6cc8ccb..a40c78f53 100644 --- a/tests/oauthbearer/aws/test_aws_iam_marker.py +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Drift-guard tests for the AWS IAM marker constants. - -Phase 2 covers the Python-side half of the drift guard: the literal values -must not move. The end-to-end half (asserting the C dispatcher's literal -strings match these) lands in Phase 5 once the C dispatcher exists. -""" +"""Drift-guard tests for the AWS IAM marker constants.""" from confluent_kafka.oauthbearer.aws._aws_iam_marker import ( AWS_IAM_MARKER_KEY, @@ -26,38 +21,14 @@ def test_marker_key_is_locked_value(): - """The marker key is part of the cross-language wire contract. - - Bumping it would silently break .NET / Go / JS / Python parity. Any change - here MUST be coordinated as a major version bump across all four clients. - """ assert AWS_IAM_MARKER_KEY == "sasl.oauthbearer.metadata.authentication.type" def test_marker_value_is_locked_value(): - """The marker value is part of the cross-language wire contract. - - Same constraint as :func:`test_marker_key_is_locked_value`. - """ assert AWS_IAM_MARKER_VALUE == "aws_iam" -# ---- End-to-end drift guard (active from Phase 5 onward) ---- - - def test_c_dispatcher_recognises_python_authoritative_marker(): - """Drift-guard end-to-end check. - - If the C-side literal strings in confluent_kafka.c drift away from the - Python constants in _aws_iam_marker.py, the dispatcher would no longer - recognize a Producer built with these constants, the precondition check - would not fire, and the ValueError below would NOT raise. The test - therefore fails loudly on drift. - - We deliberately omit `sasl.oauthbearer.method=oidc` so the precondition - check fires — that's the cheapest reliable signal that the dispatcher - saw our marker. - """ import pytest from confluent_kafka import Producer diff --git a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py index e6fad53eb..1e6078818 100644 --- a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py +++ b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor. - -Mirrors .NET's AwsJwtSubjectExtractorTests. -""" +"""Tests for confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor.""" import base64 @@ -28,7 +25,9 @@ def _base64url_encode(data: bytes) -> str: - """Mirror of .NET's AwsTestHelpers.Base64UrlEncode.""" + """Base64url-encodes a byte array: standard base64, trim '=' padding, + swap '+' → '-' and '/' → '_'. + """ return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") @@ -112,7 +111,6 @@ def test_extract_sub_other_claims_ignored(): @pytest.mark.parametrize("jwt", [_REAL_ES384_JWT, _REAL_RS256_JWT]) def test_extract_sub_real_sts_jwt_returns_expected_arn(jwt): - """Cross-language wire-shape parity with .NET / Go / JS / librdkafka.""" assert extract_sub(jwt) == _EXPECTED_ROLE_ARN @@ -130,8 +128,6 @@ def test_extract_sub_padded_base64url_also_works(): # Encode WITHOUT stripping padding. payload = base64.b64encode(b'{"sub":"abc"}').decode("ascii").replace("+", "-").replace("/", "_") jwt = f"{header}.{payload}." - # Trailing empty segment makes len(parts) != 3 if payload contains '=' — fix segment count. - # The .NET test uses a 3-segment shape with empty signature. Mirror that: jwt_3seg = f"{header}.{payload}.sig" assert extract_sub(jwt_3seg) == "abc" diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py index 8876c848e..6629a611c 100644 --- a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for confluent_kafka.oauthbearer.aws._aws_oauthbearer_config. - -Mirrors .NET's AwsOAuthBearerConfigTests, modulo Python conventions and the -Pythonic aws_debug subset decision (none/console only). -""" +"""Tests for confluent_kafka.oauthbearer.aws._aws_oauthbearer_config.""" import pytest diff --git a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py index 9b1b46d49..7305c5432 100644 --- a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py +++ b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser. - -Mirrors .NET's AwsSaslExtensionsParserTests. -""" +"""Tests for confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser.""" import pytest diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py index aef20aefa..fd8e9122f 100644 --- a/tests/oauthbearer/aws/test_aws_sts_token_provider.py +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -14,8 +14,6 @@ """Tests for confluent_kafka.oauthbearer.aws._aws_sts_token_provider. -Mirrors .NET's AwsStsTokenProviderTests, modulo: - - Python uses a tuple return (not a typed record) so the "no extensions → null" assertion becomes "no extensions → empty dict" (the C oauth_cb contract requires a dict, not None, at that slot). @@ -107,9 +105,8 @@ def test_ctor_null_config_raises(): def test_ctor_valid_parsed_config_succeeds(): """No AWS / no HTTP call at construction time (lazy credential chain).""" cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") - # Pass a fake client so we don't try to instantiate a real boto3 client - # against a possibly-network-isolated test env. provider = AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + # Does not throw; does not call AWS (lazy credential chain). assert provider is not None diff --git a/tests/oauthbearer/aws/test_contract.py b/tests/oauthbearer/aws/test_contract.py index cfc9b0f2a..23d06da06 100644 --- a/tests/oauthbearer/aws/test_contract.py +++ b/tests/oauthbearer/aws/test_contract.py @@ -14,13 +14,6 @@ """Frozen cross-module contract guard for the autowire entry-point. -The C dispatcher (Phase 5) in ``src/confluent_kafka/src/confluent_kafka.c`` -resolves ``confluent_kafka.oauthbearer.aws.aws_autowire.create_handler`` via -``PyImport_ImportModule`` + ``PyObject_GetAttrString`` and calls it with two -string arguments. Any signature drift on the Python side breaks that -contract and requires a major version bump on the ``confluent-kafka`` -distribution. - These tests guard against accidental drift: parameter names, type annotations, return annotation, arity, and absence of defaults. """ From 3e7da4def6e6c9dec51e919a463f8eb65980ad6c Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Mon, 8 Jun 2026 11:52:37 +0530 Subject: [PATCH 08/21] Fix black/isort + flake8 --- .../aws/_aws_jwt_subject_extractor.py | 21 +- .../aws/_aws_oauthbearer_config.py | 114 +++----- .../aws/_aws_sts_token_provider.py | 13 +- .../oauthbearer/aws/aws_autowire.py | 1 + tests/_util/test_kv_string_parser.py | 1 - tests/oauthbearer/aws/test_aws_autowire.py | 41 +-- tests/oauthbearer/aws/test_aws_iam_marker.py | 23 +- .../aws/test_aws_jwt_subject_extractor.py | 20 +- .../aws/test_aws_oauthbearer_config.py | 131 +++------ .../aws/test_aws_sasl_extensions_parser.py | 1 - .../aws/test_aws_sts_token_provider.py | 71 ++--- tests/oauthbearer/aws/test_contract.py | 14 +- tests/oauthbearer/aws/test_dispatch.py | 259 ++++++++++-------- tests/oauthbearer/aws/test_real_sts.py | 50 +--- 14 files changed, 330 insertions(+), 430 deletions(-) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py index 6596aec00..31395e337 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py @@ -39,31 +39,22 @@ def extract_sub(jwt: str) -> str: if jwt == "": raise ValueError("JWT is empty.") if len(jwt) > _MAX_TOKEN_LENGTH_CHARS: - raise ValueError( - f"JWT length {len(jwt)} exceeds maximum allowed " - f"({_MAX_TOKEN_LENGTH_CHARS})." - ) + raise ValueError(f"JWT length {len(jwt)} exceeds maximum allowed " f"({_MAX_TOKEN_LENGTH_CHARS}).") parts = jwt.split(".") if len(parts) != 3: - raise ValueError( - f"JWT must have exactly 3 '.'-separated segments; got {len(parts)}." - ) + raise ValueError(f"JWT must have exactly 3 '.'-separated segments; got {len(parts)}.") payload_bytes = _decode_base64url_segment(parts[1]) try: payload_string = payload_bytes.decode("utf-8") except UnicodeDecodeError as exc: - raise ValueError( - f"JWT payload is not valid UTF-8: {exc}" - ) from exc + raise ValueError(f"JWT payload is not valid UTF-8: {exc}") from exc try: token = json.loads(payload_string) except json.JSONDecodeError as exc: - raise ValueError( - f"JWT payload is not valid JSON: {exc}" - ) from exc + raise ValueError(f"JWT payload is not valid JSON: {exc}") from exc if not isinstance(token, dict): raise ValueError("JWT payload is not a JSON object.") @@ -96,6 +87,4 @@ def _decode_base64url_segment(segment: str) -> bytes: try: return base64.b64decode(s.encode("ascii"), validate=True) except (binascii.Error, UnicodeEncodeError, ValueError) as exc: - raise ValueError( - f"JWT payload segment is not valid base64url: {exc}" - ) from exc + raise ValueError(f"JWT payload segment is not valid base64url: {exc}") from exc diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py index ff33ac555..c0a8db0cb 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -86,24 +86,28 @@ # Recognised non-tag keys for the wire grammar. Anything else (other than # ``tag_``) raises "Unknown key" during :meth:`AwsOAuthBearerConfig.parse`. -_RECOGNISED_KEYS = frozenset({ - "region", - "audience", - "duration_seconds", - "signing_algorithm", - "sts_endpoint", - "principal_name", - "aws_debug", -}) - -_NON_EMPTY_KEYS = frozenset({ - "region", - "audience", - "signing_algorithm", - "sts_endpoint", - "principal_name", - "aws_debug", -}) +_RECOGNISED_KEYS = frozenset( + { + "region", + "audience", + "duration_seconds", + "signing_algorithm", + "sts_endpoint", + "principal_name", + "aws_debug", + } +) + +_NON_EMPTY_KEYS = frozenset( + { + "region", + "audience", + "signing_algorithm", + "sts_endpoint", + "principal_name", + "aws_debug", + } +) @dataclass(frozen=True) @@ -123,41 +127,27 @@ class AwsOAuthBearerConfig: def __post_init__(self) -> None: """Final-state validation. Raises :class:`ValueError` on bad input.""" if not isinstance(self.region, str) or self.region == "": - raise ValueError( - f"{CONFIG_KEY} 'region' must not be empty." - ) + raise ValueError(f"{CONFIG_KEY} 'region' must not be empty.") if not isinstance(self.audience, str) or self.audience == "": - raise ValueError( - f"{CONFIG_KEY} 'audience' must not be empty." - ) + raise ValueError(f"{CONFIG_KEY} 'audience' must not be empty.") if self.signing_algorithm not in ALLOWED_SIGNING_ALGORITHMS: raise ValueError( - f"{CONFIG_KEY} 'signing_algorithm' must be 'ES384' or 'RS256'; " - f"got {self.signing_algorithm!r}." + f"{CONFIG_KEY} 'signing_algorithm' must be 'ES384' or 'RS256'; " f"got {self.signing_algorithm!r}." ) # bool is-a int in Python — reject explicitly so True/False doesn't # slip through as 1/0. - if (not isinstance(self.duration_seconds, int) - or isinstance(self.duration_seconds, bool)): - raise ValueError( - f"{CONFIG_KEY} 'duration_seconds' must be an integer." - ) - if not (MIN_DURATION_SECONDS - <= self.duration_seconds - <= MAX_DURATION_SECONDS): + if not isinstance(self.duration_seconds, int) or isinstance(self.duration_seconds, bool): + raise ValueError(f"{CONFIG_KEY} 'duration_seconds' must be an integer.") + if not (MIN_DURATION_SECONDS <= self.duration_seconds <= MAX_DURATION_SECONDS): raise ValueError( f"{CONFIG_KEY} 'duration_seconds' must be between " f"{MIN_DURATION_SECONDS} and {MAX_DURATION_SECONDS} inclusive; " f"got {self.duration_seconds}." ) if self.sts_endpoint is not None and self.sts_endpoint == "": - raise ValueError( - f"{CONFIG_KEY} 'sts_endpoint' must not be empty." - ) + raise ValueError(f"{CONFIG_KEY} 'sts_endpoint' must not be empty.") if self.principal_name is not None and self.principal_name == "": - raise ValueError( - f"{CONFIG_KEY} 'principal_name' must not be empty." - ) + raise ValueError(f"{CONFIG_KEY} 'principal_name' must not be empty.") if self.aws_debug not in ALLOWED_AWS_DEBUG_VALUES: if self.aws_debug in NET_ONLY_AWS_DEBUG_VALUES: raise ValueError( @@ -166,22 +156,13 @@ def __post_init__(self) -> None: f"{list(ALLOWED_AWS_DEBUG_VALUES)} only; " f"'log4net' and 'systemdiagnostics' are .NET-only sinks." ) - raise ValueError( - f"{CONFIG_KEY} 'aws_debug' must be one of: " - f"none, console. Got {self.aws_debug!r}." - ) + raise ValueError(f"{CONFIG_KEY} 'aws_debug' must be one of: " f"none, console. Got {self.aws_debug!r}.") if self.tags is not None: if not isinstance(self.tags, dict): - raise ValueError( - f"{CONFIG_KEY} 'tags' must be a dict." - ) + raise ValueError(f"{CONFIG_KEY} 'tags' must be a dict.") if len(self.tags) > MAX_TAGS: - raise ValueError( - f"{CONFIG_KEY} has {len(self.tags)} tags; AWS allows at " - f"most {MAX_TAGS}." - ) - if (self.sasl_extensions is not None - and not isinstance(self.sasl_extensions, dict)): + raise ValueError(f"{CONFIG_KEY} has {len(self.tags)} tags; AWS allows at " f"most {MAX_TAGS}.") + if self.sasl_extensions is not None and not isinstance(self.sasl_extensions, dict): raise ValueError("sasl_extensions, if set, must be a dict.") @classmethod @@ -223,9 +204,7 @@ def parse( context_label=CONFIG_KEY, ): if key in _NON_EMPTY_KEYS and value == "": - raise ValueError( - f"{CONFIG_KEY} {key!r} must not be empty." - ) + raise ValueError(f"{CONFIG_KEY} {key!r} must not be empty.") if key == "region": region = value @@ -237,10 +216,7 @@ def parse( try: duration_seconds = int(value) except ValueError as exc: - raise ValueError( - f"{CONFIG_KEY} 'duration_seconds' must be an integer; " - f"got {value!r}." - ) from exc + raise ValueError(f"{CONFIG_KEY} 'duration_seconds' must be an integer; " f"got {value!r}.") from exc elif key == "sts_endpoint": sts_endpoint = value elif key == "principal_name": @@ -251,27 +227,19 @@ def parse( # straightforward. aws_debug = value.lower() elif key.startswith(TAG_KEY_PREFIX): - tag_name = key[len(TAG_KEY_PREFIX):] + tag_name = key[len(TAG_KEY_PREFIX) :] if tag_name == "": - raise ValueError( - f"{CONFIG_KEY} tag key {key!r} has empty name." - ) + raise ValueError(f"{CONFIG_KEY} tag key {key!r} has empty name.") if tags is None: tags = {} tags[tag_name] = value # last-wins on duplicate tag names else: - raise ValueError( - f"Unknown key {key!r} in {CONFIG_KEY}." - ) + raise ValueError(f"Unknown key {key!r} in {CONFIG_KEY}.") if region is None: - raise ValueError( - f"'region' is required in {CONFIG_KEY}." - ) + raise ValueError(f"'region' is required in {CONFIG_KEY}.") if audience is None: - raise ValueError( - f"'audience' is required in {CONFIG_KEY}." - ) + raise ValueError(f"'audience' is required in {CONFIG_KEY}.") return cls( region=region, diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py index a2f72c59e..107dd4af0 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -22,7 +22,6 @@ from . import _aws_jwt_subject_extractor from ._aws_oauthbearer_config import ( AWS_DEBUG_CONSOLE, - AWS_DEBUG_NONE, AwsOAuthBearerConfig, ) @@ -110,23 +109,17 @@ def token( "DurationSeconds": self._cfg.duration_seconds, } if self._cfg.tags: - request_kwargs["Tags"] = [ - {"Key": k, "Value": v} for k, v in self._cfg.tags.items() - ] + request_kwargs["Tags"] = [{"Key": k, "Value": v} for k, v in self._cfg.tags.items()] response = self._sts.get_web_identity_token(**request_kwargs) jwt = response.get("WebIdentityToken") if not isinstance(jwt, str) or not jwt: - raise ValueError( - "STS response missing WebIdentityToken; cannot mint OAUTHBEARER token." - ) + raise ValueError("STS response missing WebIdentityToken; cannot mint OAUTHBEARER token.") expiration = response.get("Expiration") if expiration is None: - raise ValueError( - "STS response missing Expiration; cannot compute token lifetime." - ) + raise ValueError("STS response missing Expiration; cannot compute token lifetime.") # boto3 normalises the timestamp to a tz-aware UTC datetime; # .timestamp() returns epoch seconds as a float. expiry_epoch_seconds = expiration.timestamp() diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py index d84c27cb9..a1955904e 100644 --- a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -54,6 +54,7 @@ OAuthBearerCallback = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + def create_handler( sasl_oauthbearer_config: str, sasl_oauthbearer_extensions: Optional[str], diff --git a/tests/_util/test_kv_string_parser.py b/tests/_util/test_kv_string_parser.py index 984139f2a..48700cc17 100644 --- a/tests/_util/test_kv_string_parser.py +++ b/tests/_util/test_kv_string_parser.py @@ -18,7 +18,6 @@ from confluent_kafka._util.kv_string_parser import parse_kv - # ---- Null guards ---- diff --git a/tests/oauthbearer/aws/test_aws_autowire.py b/tests/oauthbearer/aws/test_aws_autowire.py index 7d84afaeb..5c7967c05 100644 --- a/tests/oauthbearer/aws/test_aws_autowire.py +++ b/tests/oauthbearer/aws/test_aws_autowire.py @@ -22,8 +22,7 @@ pytest.importorskip("boto3") -from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler - +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 # ---- Input validation (defensive checks for direct callers) ---- @@ -71,21 +70,24 @@ def test_create_handler_invalid_signing_algorithm_raises(): def test_create_handler_invalid_duration_raises(): with pytest.raises(ValueError, match="duration_seconds"): create_handler( - "region=us-east-1 audience=https://a duration_seconds=10", None, + "region=us-east-1 audience=https://a duration_seconds=10", + None, ) def test_create_handler_unknown_key_raises(): with pytest.raises(ValueError, match="not_a_key"): create_handler( - "region=us-east-1 audience=https://a not_a_key=foo", None, + "region=us-east-1 audience=https://a not_a_key=foo", + None, ) def test_create_handler_invalid_extensions_grammar_raises(): with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): create_handler( - "region=us-east-1 audience=https://a", "noEqualsHere", + "region=us-east-1 audience=https://a", + "noEqualsHere", ) @@ -93,7 +95,8 @@ def test_create_handler_aws_debug_dotnet_only_value_raises_with_platform_hint(): """log4net / systemdiagnostics are .NET-only sinks; surface that clearly.""" with pytest.raises(ValueError, match="not supported on this platform"): create_handler( - "region=us-east-1 audience=https://a aws_debug=log4net", None, + "region=us-east-1 audience=https://a aws_debug=log4net", + None, ) @@ -102,7 +105,8 @@ def test_create_handler_aws_debug_dotnet_only_value_raises_with_platform_hint(): def test_create_handler_marker_only_minimum_config_returns_handler(): handler = create_handler( - "region=us-east-1 audience=https://a", None, + "region=us-east-1 audience=https://a", + None, ) assert handler is not None assert callable(handler) @@ -143,21 +147,24 @@ def test_create_handler_principal_name_override_handler_ready(): def test_create_handler_null_extensions_treats_as_absent(): handler = create_handler( - "region=us-east-1 audience=https://a", None, + "region=us-east-1 audience=https://a", + None, ) assert callable(handler) def test_create_handler_empty_extensions_treats_as_absent(): handler = create_handler( - "region=us-east-1 audience=https://a", "", + "region=us-east-1 audience=https://a", + "", ) assert callable(handler) def test_create_handler_single_extension_handler_ready(): handler = create_handler( - "region=us-east-1 audience=https://a", "logicalCluster=lkc-abc", + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc", ) assert callable(handler) @@ -182,6 +189,7 @@ def test_create_handler_does_not_call_sts_at_construction(): Verify by patching boto3.Session to detect any client.get_web_identity_token call. """ from unittest.mock import MagicMock, patch + with patch("boto3.Session") as mock_session_cls: mock_session = mock_session_cls.return_value mock_client = MagicMock() @@ -193,8 +201,8 @@ def test_create_handler_does_not_call_sts_at_construction(): def test_create_handler_returned_callable_when_invoked_calls_sts(): """When the returned callable is invoked, it triggers exactly one STS call.""" - from unittest.mock import MagicMock, patch import datetime + from unittest.mock import MagicMock, patch canned_response = { "WebIdentityToken": _canned_jwt(), @@ -222,8 +230,8 @@ def test_create_handler_returned_callable_when_invoked_calls_sts(): def test_create_handler_returned_callable_round_trips_extensions(): """Extensions configured via the typed property flow through to the 4-tuple's extensions slot.""" - from unittest.mock import MagicMock, patch import datetime + from unittest.mock import MagicMock, patch canned_response = { "WebIdentityToken": _canned_jwt(), @@ -252,13 +260,8 @@ def test_create_handler_returned_callable_round_trips_extensions(): def _base64url(data: bytes) -> str: import base64 - return ( - base64.b64encode(data) - .decode("ascii") - .rstrip("=") - .replace("+", "-") - .replace("/", "_") - ) + + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py index a40c78f53..079247619 100644 --- a/tests/oauthbearer/aws/test_aws_iam_marker.py +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -30,16 +30,19 @@ def test_marker_value_is_locked_value(): def test_c_dispatcher_recognises_python_authoritative_marker(): import pytest + from confluent_kafka import Producer with pytest.raises(ValueError, match="method=oidc"): - Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - AWS_IAM_MARKER_KEY: AWS_IAM_MARKER_VALUE, - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - # Deliberately omitting sasl.oauthbearer.method to trigger the - # dispatcher's precondition check. If the dispatcher's C-side - # marker literals differ from AWS_IAM_MARKER_KEY/VALUE, the - # dispatcher won't fire and this ValueError won't raise. - }) + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + AWS_IAM_MARKER_KEY: AWS_IAM_MARKER_VALUE, + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + # Deliberately omitting sasl.oauthbearer.method to trigger the + # dispatcher's precondition check. If the dispatcher's C-side + # marker literals differ from AWS_IAM_MARKER_KEY/VALUE, the + # dispatcher won't fire and this ValueError won't raise. + } + ) diff --git a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py index 1e6078818..724fb3584 100644 --- a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py +++ b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py @@ -20,7 +20,6 @@ from confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor import extract_sub - # ---- Test helpers ---- @@ -102,10 +101,7 @@ def test_extract_sub_assumed_role_arn_returned(): def test_extract_sub_other_claims_ignored(): - jwt = _make_jwt( - '{"iss":"https://x","sub":"arn:aws:iam::1:role/R",' - '"aud":"a","exp":1,"iat":0,"jti":"j"}' - ) + jwt = _make_jwt('{"iss":"https://x","sub":"arn:aws:iam::1:role/R",' '"aud":"a","exp":1,"iat":0,"jti":"j"}') assert extract_sub(jwt) == "arn:aws:iam::1:role/R" @@ -127,7 +123,6 @@ def test_extract_sub_padded_base64url_also_works(): header = _base64url_encode(b'{"alg":"none"}') # Encode WITHOUT stripping padding. payload = base64.b64encode(b'{"sub":"abc"}').decode("ascii").replace("+", "-").replace("/", "_") - jwt = f"{header}.{payload}." jwt_3seg = f"{header}.{payload}.sig" assert extract_sub(jwt_3seg) == "abc" @@ -140,11 +135,14 @@ def test_extract_sub_url_safe_chars_handled(): assert extract_sub(jwt) == "x" -@pytest.mark.parametrize("payload_json,expected_sub", [ - ('{"sub":"a"}', "a"), # → encoded length % 4 = 3 (one '=' padded) - ('{"sub":"ab"}', "ab"), # → encoded length % 4 = 0 (no padding) - ('{"sub":"abc"}', "abc"), # → encoded length % 4 = 2 (two '==' padded) -]) +@pytest.mark.parametrize( + "payload_json,expected_sub", + [ + ('{"sub":"a"}', "a"), # → encoded length % 4 = 3 (one '=' padded) + ('{"sub":"ab"}', "ab"), # → encoded length % 4 = 0 (no padding) + ('{"sub":"abc"}', "abc"), # → encoded length % 4 = 2 (two '==' padded) + ], +) def test_extract_sub_padding_branches_all_hit_decodes_correctly(payload_json, expected_sub): """Explicit per-branch coverage of the % 4 padding switch.""" assert extract_sub(_make_jwt(payload_json)) == expected_sub diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py index 6629a611c..8868d66d9 100644 --- a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -22,7 +22,6 @@ AwsOAuthBearerConfig, ) - # ---- Required fields ---- @@ -89,25 +88,19 @@ def test_parse_optional_fields_default_to_none(): @pytest.mark.parametrize("seconds", [60, 300, 3600]) def test_parse_duration_in_range_accepted(seconds): - cfg = AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a duration_seconds={seconds}" - ) + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a duration_seconds={seconds}") assert cfg.duration_seconds == seconds @pytest.mark.parametrize("seconds", [0, 59, 3601, -10]) def test_parse_duration_out_of_range_raises(seconds): with pytest.raises(ValueError, match="duration_seconds.*must be between"): - AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a duration_seconds={seconds}" - ) + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a duration_seconds={seconds}") def test_parse_duration_not_integer_raises(): with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): - AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a duration_seconds=abc" - ) + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a duration_seconds=abc") # ---- signing_algorithm ---- @@ -115,18 +108,14 @@ def test_parse_duration_not_integer_raises(): @pytest.mark.parametrize("alg", ["ES384", "RS256"]) def test_parse_allowed_signing_algorithm_accepted(alg): - cfg = AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a signing_algorithm={alg}" - ) + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a signing_algorithm={alg}") assert cfg.signing_algorithm == alg @pytest.mark.parametrize("alg", ["HS256", "es384", "RS512"]) def test_parse_disallowed_signing_algorithm_raises(alg): with pytest.raises(ValueError, match="signing_algorithm.*ES384.*RS256"): - AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a signing_algorithm={alg}" - ) + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a signing_algorithm={alg}") # ---- aws_debug (Python subset: none/console only) ---- @@ -137,26 +126,28 @@ def test_parse_no_aws_debug_defaults_to_none(): assert cfg.aws_debug == AWS_DEBUG_NONE -@pytest.mark.parametrize("value,expected", [ - ("none", AWS_DEBUG_NONE), - ("console", AWS_DEBUG_CONSOLE), -]) +@pytest.mark.parametrize( + "value,expected", + [ + ("none", AWS_DEBUG_NONE), + ("console", AWS_DEBUG_CONSOLE), + ], +) def test_parse_aws_debug_values_accepted(value, expected): - cfg = AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a aws_debug={value}" - ) + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") assert cfg.aws_debug == expected -@pytest.mark.parametrize("value,expected", [ - ("Console", AWS_DEBUG_CONSOLE), - ("CONSOLE", AWS_DEBUG_CONSOLE), - ("NONE", AWS_DEBUG_NONE), -]) +@pytest.mark.parametrize( + "value,expected", + [ + ("Console", AWS_DEBUG_CONSOLE), + ("CONSOLE", AWS_DEBUG_CONSOLE), + ("NONE", AWS_DEBUG_NONE), + ], +) def test_parse_aws_debug_case_insensitive_accepted(value, expected): - cfg = AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a aws_debug={value}" - ) + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") assert cfg.aws_debug == expected @@ -165,24 +156,18 @@ def test_parse_aws_debug_dotnet_only_values_rejected_with_clear_message(value): """log4net / systemdiagnostics are .NET sinks; Python rejects with a platform-clarity message so cross-language users see the constraint.""" with pytest.raises(ValueError, match="not supported on this platform"): - AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a aws_debug={value}" - ) + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") @pytest.mark.parametrize("value", ["verbose", "etw", "debug", "true", "foo"]) def test_parse_aws_debug_invalid_value_raises(value): with pytest.raises(ValueError, match="aws_debug.*none, console"): - AwsOAuthBearerConfig.parse( - f"region=us-east-1 audience=https://a aws_debug={value}" - ) + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") def test_parse_aws_debug_empty_value_raises(): with pytest.raises(ValueError, match="aws_debug.*must not be empty"): - AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a aws_debug=" - ) + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a aws_debug=") # ---- sts_endpoint, principal_name ---- @@ -190,16 +175,13 @@ def test_parse_aws_debug_empty_value_raises(): def test_parse_sts_endpoint_stored_verbatim(): cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a " - "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + "region=us-east-1 audience=https://a " "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" ) assert cfg.sts_endpoint == "https://sts-fips.us-east-1.amazonaws.com" def test_parse_principal_name_stored_verbatim(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a principal_name=my-principal" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a principal_name=my-principal") assert cfg.principal_name == "my-principal" @@ -210,23 +192,17 @@ def test_parse_extension_prefix_in_config_rejected_as_unknown_key(): """sasl extensions are accepted via typed sasl.oauthbearer.extensions property only; embedded extension_* keys are rejected.""" with pytest.raises(ValueError, match="extension_logicalCluster"): - AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a extension_logicalCluster=lkc-abc" - ) + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a extension_logicalCluster=lkc-abc") def test_parse_sasl_extensions_arg_stored_on_config(): ext = {"logicalCluster": "lkc-abc", "identityPoolId": "pool-xyz"} - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a", ext - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a", ext) assert cfg.sasl_extensions is ext def test_parse_sasl_extensions_arg_null_keeps_sasl_extensions_null(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a", None - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a", None) assert cfg.sasl_extensions is None @@ -234,40 +210,28 @@ def test_parse_sasl_extensions_arg_null_keeps_sasl_extensions_null(): def test_parse_single_tag_collected_into_tags(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a tag_team=platform" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=platform") assert cfg.tags == {"team": "platform"} def test_parse_multiple_tags_all_collected(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a " - "tag_team=platform " - "tag_environment=prod" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a " "tag_team=platform " "tag_environment=prod") assert cfg.tags == {"team": "platform", "environment": "prod"} def test_parse_empty_tag_name_raises(): with pytest.raises(ValueError, match="empty name"): - AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a tag_=value" - ) + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_=value") def test_parse_empty_tag_value_accepted(): # AWS allows tag values of 0 chars; mirror that. - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a tag_team=" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=") assert cfg.tags == {"team": ""} def test_parse_duplicate_tag_name_last_wins(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a tag_team=infra tag_team=platform" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=infra tag_team=platform") assert cfg.tags == {"team": "platform"} @@ -290,18 +254,14 @@ def test_parse_over_max_tags_raises(): def test_parse_unknown_key_raises(): with pytest.raises(ValueError, match="Unknown key.*not_a_key"): - AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a not_a_key=foo" - ) + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a not_a_key=foo") # ---- Whitespace / ordering ---- def test_parse_tabs_and_multiple_spaces_tolerated(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1\taudience=https://a duration_seconds=600" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1\taudience=https://a duration_seconds=600") assert cfg.region == "us-east-1" assert cfg.audience == "https://a" assert cfg.duration_seconds == 600 @@ -314,9 +274,7 @@ def test_parse_leading_and_trailing_whitespace_tolerated(): def test_parse_order_invariant(): - cfg = AwsOAuthBearerConfig.parse( - "duration_seconds=600 audience=https://a region=us-east-1 signing_algorithm=RS256" - ) + cfg = AwsOAuthBearerConfig.parse("duration_seconds=600 audience=https://a region=us-east-1 signing_algorithm=RS256") assert cfg.region == "us-east-1" assert cfg.audience == "https://a" assert cfg.duration_seconds == 600 @@ -324,9 +282,7 @@ def test_parse_order_invariant(): def test_parse_duplicate_key_last_wins(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a region=us-west-2" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a region=us-west-2") assert cfg.region == "us-west-2" @@ -381,17 +337,13 @@ def test_direct_construction_with_empty_region_raises(): def test_direct_construction_with_out_of_range_duration_raises(): with pytest.raises(ValueError, match="duration_seconds.*must be between"): - AwsOAuthBearerConfig( - region="us-east-1", audience="https://a", duration_seconds=10 - ) + AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=10) def test_direct_construction_with_bool_duration_rejected(): """bool is-a int in Python; reject explicitly so True/False doesn't pass.""" with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): - AwsOAuthBearerConfig( - region="us-east-1", audience="https://a", duration_seconds=True - ) + AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=True) # ---- Surface invariants ---- @@ -401,4 +353,5 @@ def test_aws_oauth_bearer_config_not_exposed_via_subpackage_init(): """Public surface stays minimal — AwsOAuthBearerConfig is private to the autowire layer (only `create_handler` is publicly importable).""" import confluent_kafka.oauthbearer.aws as aws_pkg + assert not hasattr(aws_pkg, "AwsOAuthBearerConfig") diff --git a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py index 7305c5432..40f72043e 100644 --- a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py +++ b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py @@ -18,7 +18,6 @@ from confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser import parse - # ---- Null / empty input → None ---- diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py index fd8e9122f..ee99535e6 100644 --- a/tests/oauthbearer/aws/test_aws_sts_token_provider.py +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -33,15 +33,10 @@ pytest.importorskip("boto3") pytest.importorskip("botocore") -from botocore.exceptions import ClientError - -from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import ( - AwsOAuthBearerConfig, -) -from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import ( - AwsStsTokenProvider, -) +from botocore.exceptions import ClientError # noqa: E402 +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import AwsOAuthBearerConfig # noqa: E402 +from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import AwsStsTokenProvider # noqa: E402 # ---- Test helpers ---- @@ -61,7 +56,14 @@ def _canned_jwt(sub: str = _ROLE_ARN) -> str: _CANNED_JWT = _canned_jwt() _CANNED_EXPIRY = datetime.datetime( - 2099, 4, 21, 6, 6, 47, 641_000, tzinfo=datetime.timezone.utc, + 2099, + 4, + 21, + 6, + 6, + 47, + 641_000, + tzinfo=datetime.timezone.utc, ) @@ -123,18 +125,14 @@ def test_ctor_no_aws_debug_does_not_mutate_botocore_logger(): def test_ctor_aws_debug_none_does_not_mutate_botocore_logger(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a aws_debug=none" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a aws_debug=none") with patch("boto3.set_stream_logger") as mock_setter: AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) mock_setter.assert_not_called() def test_ctor_aws_debug_console_routes_botocore_logger_to_stream(): - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a aws_debug=console" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a aws_debug=console") with patch("boto3.set_stream_logger") as mock_setter: AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) mock_setter.assert_called_once_with("botocore", logging.DEBUG) @@ -145,9 +143,7 @@ def test_ctor_aws_debug_console_routes_botocore_logger_to_stream(): def test_token_audience_passthrough(): fake = FakeStsClient() - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://my.audience" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://my.audience") provider = AwsStsTokenProvider(cfg, sts_client=fake) provider.token() assert fake.last_request["Audience"] == ["https://my.audience"] @@ -155,9 +151,7 @@ def test_token_audience_passthrough(): def test_token_signing_algorithm_passthrough(): fake = FakeStsClient() - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a signing_algorithm=RS256" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a signing_algorithm=RS256") provider = AwsStsTokenProvider(cfg, sts_client=fake) provider.token() assert fake.last_request["SigningAlgorithm"] == "RS256" @@ -165,9 +159,7 @@ def test_token_signing_algorithm_passthrough(): def test_token_duration_seconds_passthrough(): fake = FakeStsClient() - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a duration_seconds=900" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a duration_seconds=900") provider = AwsStsTokenProvider(cfg, sts_client=fake) provider.token() assert fake.last_request["DurationSeconds"] == 900 @@ -191,9 +183,7 @@ def test_token_default_signing_algorithm_sends_es384(): def test_token_tags_passthrough(): fake = FakeStsClient() - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a tag_team=platform tag_environment=prod" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=platform tag_environment=prod") provider = AwsStsTokenProvider(cfg, sts_client=fake) provider.token() @@ -229,9 +219,7 @@ def test_token_returns_mapped_fields(): def test_token_principal_name_override_wins_over_jwt_sub(): fake = FakeStsClient() - cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a principal_name=explicit-principal" - ) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a principal_name=explicit-principal") provider = AwsStsTokenProvider(cfg, sts_client=fake) _, _, principal, _ = provider.token() assert principal == "explicit-principal" @@ -241,7 +229,8 @@ def test_token_sasl_extensions_passthrough(): fake = FakeStsClient() sasl_extensions = {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a", sasl_extensions, + "region=us-east-1 audience=https://a", + sasl_extensions, ) provider = AwsStsTokenProvider(cfg, sts_client=fake) _, _, _, extensions = provider.token() @@ -278,9 +267,7 @@ def test_token_expiry_is_epoch_seconds_float(): def test_token_missing_expiration_raises(): """STS response without Expiration → ValueError surfaced as rd_kafka_oauthbearer_set_token_failure by the C wrapper.""" - fake = FakeStsClient( - responder=lambda kwargs: {"WebIdentityToken": _CANNED_JWT} - ) + fake = FakeStsClient(responder=lambda kwargs: {"WebIdentityToken": _CANNED_JWT}) cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") provider = AwsStsTokenProvider(cfg, sts_client=fake) with pytest.raises(ValueError, match="Expiration"): @@ -288,9 +275,7 @@ def test_token_missing_expiration_raises(): def test_token_missing_token_value_raises(): - fake = FakeStsClient( - responder=lambda kwargs: {"Expiration": _CANNED_EXPIRY} - ) + fake = FakeStsClient(responder=lambda kwargs: {"Expiration": _CANNED_EXPIRY}) cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") provider = AwsStsTokenProvider(cfg, sts_client=fake) with pytest.raises(ValueError, match="WebIdentityToken"): @@ -345,10 +330,7 @@ def test_token_outbound_federation_disabled_propagates(): provider = AwsStsTokenProvider(cfg, sts_client=fake) with pytest.raises(ClientError) as exc_info: provider.token() - assert ( - exc_info.value.response["Error"]["Code"] - == "OutboundWebIdentityFederationDisabledException" - ) + assert exc_info.value.response["Error"]["Code"] == "OutboundWebIdentityFederationDisabledException" # ---- Surface invariants ---- @@ -357,6 +339,7 @@ def test_token_outbound_federation_disabled_propagates(): def test_aws_sts_token_provider_not_exposed_via_subpackage_init(): """Public surface stays minimal — AwsStsTokenProvider is private.""" import confluent_kafka.oauthbearer.aws as aws_pkg + assert not hasattr(aws_pkg, "AwsStsTokenProvider") @@ -380,8 +363,7 @@ def test_ctor_sts_endpoint_plumbed_to_boto3_client(): """When sts_endpoint is set on config, boto3.Session().client('sts', ...) receives an endpoint_url kwarg.""" cfg = AwsOAuthBearerConfig.parse( - "region=us-east-1 audience=https://a " - "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + "region=us-east-1 audience=https://a " "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" ) with patch("boto3.Session") as mock_session_cls: mock_session = mock_session_cls.return_value @@ -399,5 +381,6 @@ def test_ctor_no_sts_endpoint_omits_endpoint_url_kwarg(): mock_session = mock_session_cls.return_value AwsStsTokenProvider(cfg) mock_session.client.assert_called_once_with( - "sts", region_name="us-east-1", + "sts", + region_name="us-east-1", ) diff --git a/tests/oauthbearer/aws/test_contract.py b/tests/oauthbearer/aws/test_contract.py index 23d06da06..9d5a02472 100644 --- a/tests/oauthbearer/aws/test_contract.py +++ b/tests/oauthbearer/aws/test_contract.py @@ -25,12 +25,8 @@ pytest.importorskip("boto3") -from confluent_kafka.oauthbearer.aws import aws_autowire -from confluent_kafka.oauthbearer.aws.aws_autowire import ( - OAuthBearerCallback, - create_handler, -) - +from confluent_kafka.oauthbearer.aws import aws_autowire # noqa: E402 +from confluent_kafka.oauthbearer.aws.aws_autowire import OAuthBearerCallback, create_handler # noqa: E402 # ---- Module surface ---- @@ -38,6 +34,7 @@ def test_module_importable_at_canonical_path(): """The C dispatcher does PyImport_ImportModule(...) with this exact path.""" import importlib + mod = importlib.import_module("confluent_kafka.oauthbearer.aws.aws_autowire") assert mod is aws_autowire @@ -74,10 +71,7 @@ def test_create_handler_parameter_names_are_frozen(): def test_create_handler_parameter_annotations_are_frozen(): sig = inspect.signature(create_handler) assert sig.parameters["sasl_oauthbearer_config"].annotation is str - assert ( - sig.parameters["sasl_oauthbearer_extensions"].annotation - == Optional[str] - ) + assert sig.parameters["sasl_oauthbearer_extensions"].annotation == Optional[str] def test_create_handler_no_default_values(): diff --git a/tests/oauthbearer/aws/test_dispatch.py b/tests/oauthbearer/aws/test_dispatch.py index 143379f52..a5feb852b 100644 --- a/tests/oauthbearer/aws/test_dispatch.py +++ b/tests/oauthbearer/aws/test_dispatch.py @@ -24,7 +24,6 @@ import base64 import datetime -import importlib import sys from typing import Any, Dict, Optional from unittest.mock import MagicMock, patch @@ -36,7 +35,6 @@ import confluent_kafka # noqa: E402 from confluent_kafka.admin import AdminClient # noqa: E402 - # ---- Test helpers ---- @@ -54,7 +52,13 @@ def _canned_response() -> Dict[str, Any]: return { "WebIdentityToken": _canned_jwt(), "Expiration": datetime.datetime( - 2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc, + 2099, + 4, + 21, + 6, + 6, + 47, + tzinfo=datetime.timezone.utc, ), } @@ -99,10 +103,12 @@ def test_marker_absent_producer_constructs_unchanged(): def test_marker_absent_consumer_constructs_unchanged(): - c = confluent_kafka.Consumer({ - "bootstrap.servers": "localhost:9092", - "group.id": "test-group", - }) + c = confluent_kafka.Consumer( + { + "bootstrap.servers": "localhost:9092", + "group.id": "test-group", + } + ) c.close() @@ -124,8 +130,7 @@ def test_marker_absent_does_not_import_aws_modules(): "confluent_kafka.oauthbearer.aws._aws_sts_token_provider", ]: assert name not in sys.modules, ( - f"Dispatcher imported {name} when no marker was set — " - "this means the no-op short-circuit broke." + f"Dispatcher imported {name} when no marker was set — " "this means the no-op short-circuit broke." ) @@ -138,12 +143,14 @@ def test_other_marker_value_passes_through_unchanged(): # to NOT raise our ValueError. (librdkafka itself may complain about the # value, but our dispatcher must not.) with pytest.raises(Exception) as exc_info: - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "azure_imds", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "azure_imds", + } + ) # Whatever librdkafka raises, it must NOT be our method-requirement # error or our config-requirement error. msg = str(exc_info.value) @@ -156,35 +163,41 @@ def test_other_marker_value_passes_through_unchanged(): def test_marker_without_method_raises(): with pytest.raises(ValueError, match="method=oidc"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) def test_marker_with_method_default_raises(): with pytest.raises(ValueError, match="method=oidc"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "default", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "default", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) def test_marker_with_method_oidc_uppercase_raises(): """Strict matching — librdkafka's method values are lowercase canonical.""" with pytest.raises(ValueError, match="method=oidc"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "OIDC", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "OIDC", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) # ---- 4. sasl.oauthbearer.config requirement ---- @@ -192,23 +205,27 @@ def test_marker_with_method_oidc_uppercase_raises(): def test_marker_with_method_oidc_but_missing_config_raises(): with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + } + ) def test_marker_with_method_oidc_but_empty_config_raises(): with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "", + } + ) # ---- 5. Happy path: marker + method=oidc + valid config + boto3 stubbed ---- @@ -221,9 +238,7 @@ def test_happy_path_producer_constructs_with_marker(mocked_boto3): def test_happy_path_consumer_constructs_with_marker(mocked_boto3): - c = confluent_kafka.Consumer( - _minimal_aws_iam_config({"group.id": "test-group"}) - ) + c = confluent_kafka.Consumer(_minimal_aws_iam_config({"group.id": "test-group"})) assert c is not None c.close() @@ -240,6 +255,7 @@ async def test_happy_path_aio_producer_constructs_with_marker(mocked_boto3): must run inside an async context (pyproject.toml's asyncio_mode=auto handles that for async def test functions).""" from confluent_kafka.aio import AIOProducer + p = AIOProducer(_minimal_aws_iam_config()) assert p is not None @@ -247,6 +263,7 @@ async def test_happy_path_aio_producer_constructs_with_marker(mocked_boto3): async def test_happy_path_aio_consumer_constructs_with_marker(mocked_boto3): """Same async-context requirement as AIOProducer.""" from confluent_kafka.aio import AIOConsumer + c = AIOConsumer(_minimal_aws_iam_config({"group.id": "test-group"})) assert c is not None @@ -268,15 +285,17 @@ def user_oauth_cb(config_str): if name.startswith("confluent_kafka.oauthbearer.aws"): del sys.modules[name] - p = confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "security.protocol": "SASL_SSL", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - "oauth_cb": user_oauth_cb, - }) + p = confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + "oauth_cb": user_oauth_cb, + } + ) assert p is not None p.flush(timeout=0.1) # The autowire module must not have been imported. @@ -289,36 +308,42 @@ def user_oauth_cb(config_str): def test_marker_with_invalid_config_grammar_raises(mocked_boto3): """Config parser ValueError surfaces through the dispatcher.""" with pytest.raises(ValueError, match="Unknown key.*not_a_key"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a not_a_key=foo", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a not_a_key=foo", + } + ) def test_marker_with_invalid_signing_algorithm_raises(mocked_boto3): with pytest.raises(ValueError, match="signing_algorithm"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a signing_algorithm=HS256", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a signing_algorithm=HS256", + } + ) def test_marker_with_invalid_extensions_grammar_raises(mocked_boto3): with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - "sasl.oauthbearer.extensions": "malformed-no-equals", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + "sasl.oauthbearer.extensions": "malformed-no-equals", + } + ) # ---- 8. Friendly ImportError when the optional extra is missing ---- @@ -346,13 +371,15 @@ def test_marker_with_missing_extra_raises_friendly_import_error(boto3_absent): ModuleNotFoundError from the import chain and rewrites it into a friendly install hint. __cause__ chain preserves the original.""" with pytest.raises(ImportError) as exc_info: - confluent_kafka.Producer({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + confluent_kafka.Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) msg = str(exc_info.value) assert "oauthbearer-aws" in msg assert "pip install" in msg @@ -363,25 +390,29 @@ def test_marker_with_missing_extra_raises_friendly_import_error(boto3_absent): def test_friendly_import_error_on_consumer_too(boto3_absent): with pytest.raises(ImportError, match="oauthbearer-aws"): - confluent_kafka.Consumer({ - "bootstrap.servers": "broker.invalid:9092", - "group.id": "g", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + confluent_kafka.Consumer( + { + "bootstrap.servers": "broker.invalid:9092", + "group.id": "g", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) def test_friendly_import_error_on_admin_client_too(boto3_absent): with pytest.raises(ImportError, match="oauthbearer-aws"): - AdminClient({ - "bootstrap.servers": "broker.invalid:9092", - "sasl.mechanisms": "OAUTHBEARER", - "sasl.oauthbearer.method": "oidc", - "sasl.oauthbearer.metadata.authentication.type": "aws_iam", - "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", - }) + AdminClient( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) # ---- 9. Marker is stripped before native handoff ---- @@ -406,17 +437,25 @@ def test_marker_with_extensions_plumbs_through(mocked_boto3): """The dispatcher reads sasl.oauthbearer.extensions and passes it as the second arg to create_handler. Verify by checking the returned callable yields the extensions in its 4-tuple result.""" - p = confluent_kafka.Producer(_minimal_aws_iam_config({ - "sasl.oauthbearer.extensions": "logicalCluster=lkc-123", - })) + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "logicalCluster=lkc-123", + } + ) + ) assert p is not None p.flush(timeout=0.1) def test_marker_with_empty_extensions_treated_as_absent(mocked_boto3): """Empty extensions string → autowire treats as None (no extensions).""" - p = confluent_kafka.Producer(_minimal_aws_iam_config({ - "sasl.oauthbearer.extensions": "", - })) + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "", + } + ) + ) assert p is not None p.flush(timeout=0.1) diff --git a/tests/oauthbearer/aws/test_real_sts.py b/tests/oauthbearer/aws/test_real_sts.py index bc0bde643..83dab6705 100644 --- a/tests/oauthbearer/aws/test_real_sts.py +++ b/tests/oauthbearer/aws/test_real_sts.py @@ -66,7 +66,6 @@ from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 - # ============================================================================= # Configuration sourced from env vars so the test can be re-pointed at a # different role / audience without code changes. @@ -135,9 +134,7 @@ def test_create_handler_mints_valid_jwt(): # Expiry within the requested window. now = time.time() assert expiry > now, "expiry must be in the future" - assert expiry < now + 10 * 60, ( - "expiry must be within 10 min for DurationSeconds <= 300" - ) + assert expiry < now + 10 * 60, "expiry must be within 10 min for DurationSeconds <= 300" # Principal is a bare role ARN (AWS STS GetWebIdentityToken convention). assert _ARN_PATTERN.match(principal), f"unexpected principal: {principal!r}" @@ -146,10 +143,7 @@ def test_create_handler_mints_valid_jwt(): assert extensions == {} # Diagnostic for the run log so EC2 evidence is captured in CI / PR. - print( - f"\n[autowire-real] jwt_length={len(token)} " - f"principal={principal} expires_in={int(expiry - now)}s" - ) + print(f"\n[autowire-real] jwt_length={len(token)} " f"principal={principal} expires_in={int(expiry - now)}s") def test_create_handler_jwt_length_matches_cross_language(): @@ -178,9 +172,7 @@ def test_create_handler_jwt_length_matches_cross_language(): expected = int(expected_str) handler = create_handler(_default_config(), None) token, *_ = handler("") - assert len(token) == expected, ( - f"expected {expected} bytes (cross-language match), got {len(token)}" - ) + assert len(token) == expected, f"expected {expected} bytes (cross-language match), got {len(token)}" def test_create_handler_principal_matches_role_arn(): @@ -213,9 +205,7 @@ def test_create_handler_honours_duration_seconds(): handler = create_handler(_default_config(duration_seconds="60"), None) _, expiry, _, _ = handler("") seconds_remaining = expiry - time.time() - assert 50 <= seconds_remaining <= 80, ( - f"duration_seconds=60 → expected ~60s window, got {seconds_remaining:.0f}s" - ) + assert 50 <= seconds_remaining <= 80, f"duration_seconds=60 → expected ~60s window, got {seconds_remaining:.0f}s" def test_create_handler_honours_signing_algorithm_rs256(): @@ -230,25 +220,20 @@ def test_create_handler_honours_signing_algorithm_rs256(): header_b64url = token.split(".")[0] padding = "=" * (-len(header_b64url) % 4) header_json = json.loads( - base64.b64decode( - (header_b64url + padding).replace("-", "+").replace("_", "/") - ).decode("utf-8") - ) - assert header_json.get("alg") == "RS256", ( - f"header.alg expected RS256, got {header_json.get('alg')!r}" + base64.b64decode((header_b64url + padding).replace("-", "+").replace("_", "/")).decode("utf-8") ) + assert header_json.get("alg") == "RS256", f"header.alg expected RS256, got {header_json.get('alg')!r}" def test_create_handler_honours_principal_name_override(): """Setting principal_name=... in the wire grammar overrides the JWT 'sub' extraction at the autowire layer.""" handler = create_handler( - _default_config(principal_name="custom-principal"), None, + _default_config(principal_name="custom-principal"), + None, ) _, _, principal, _ = handler("") - assert principal == "custom-principal", ( - f"principal_name override not honoured: got {principal!r}" - ) + assert principal == "custom-principal", f"principal_name override not honoured: got {principal!r}" def test_create_handler_round_trips_sasl_extensions(): @@ -289,7 +274,8 @@ def test_create_handler_tag_claims_flow_to_sts(): JWT — only meaningful on roles with ``sts:TagSession`` allowed. """ handler = create_handler( - _default_config(tag_team="platform", tag_environment="prod"), None, + _default_config(tag_team="platform", tag_environment="prod"), + None, ) # If our Tags param breaks the STS call, this raises. token, *_ = handler("") @@ -333,9 +319,7 @@ def test_create_handler_jwt_audience_matches_request(): handler = create_handler(_default_config(), None) token, *_ = handler("") payload = _decode_jwt_payload(token) - assert payload.get("aud") == _AUDIENCE, ( - f"JWT aud {payload.get('aud')!r} doesn't match requested {_AUDIENCE!r}" - ) + assert payload.get("aud") == _AUDIENCE, f"JWT aud {payload.get('aud')!r} doesn't match requested {_AUDIENCE!r}" # ============================================================================= @@ -399,11 +383,5 @@ def test_producer_with_marker_and_extensions_succeeds(): } ) elapsed = time.time() - t0 - assert elapsed < 5.0, ( - f"Producer with extensions took {elapsed:.1f}s — autowire path " - f"slow or broken" - ) - print( - f"\n[autowire-real] Producer (with extensions) constructed via " - f"dispatcher in {elapsed:.2f}s" - ) + assert elapsed < 5.0, f"Producer with extensions took {elapsed:.1f}s — autowire path " f"slow or broken" + print(f"\n[autowire-real] Producer (with extensions) constructed via " f"dispatcher in {elapsed:.2f}s") From 92c908d4f260dada2592d70d1da97593708e2f21 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Mon, 8 Jun 2026 13:48:26 +0530 Subject: [PATCH 09/21] Add boto3 version check --- .../aws/_aws_sts_token_provider.py | 41 +++++++++ .../oauthbearer/aws/aws_autowire.py | 3 + .../aws/test_aws_sts_token_provider.py | 85 ++++++++++++++++++- 3 files changed, 128 insertions(+), 1 deletion(-) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py index 107dd4af0..0347db5ac 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -32,6 +32,41 @@ # credential-chain / signing diagnostic logs to stderr at DEBUG level. _BOTOCORE_LOGGER_NAME = "botocore" +#: Minimum boto3 version required by the AWS IAM path. +#: ``requirements/requirements-oauthbearer-aws.txt``. +MINIMUM_BOTO3_VERSION = "1.42.25" + + +def _version_tuple(version: str) -> Tuple[int, ...]: + """Parse a dotted version string into a tuple of its leading integers. + + Tolerant of pre-release suffixes (``"1.42.0rc1"`` -> ``(1, 42, 0)``) so the + comparison considers only the numeric ``major.minor.micro`` components. + """ + parts = [] + for segment in version.split("."): + digits = "" + for ch in segment: + if not ch.isdigit(): + break + digits += ch + parts.append(int(digits) if digits else 0) + return tuple(parts) + + +def _require_boto3_version() -> None: + """Raise :class:`ImportError` if the installed boto3 predates + :data:`MINIMUM_BOTO3_VERSION`. + + """ + if _version_tuple(boto3.__version__) < _version_tuple(MINIMUM_BOTO3_VERSION): + raise ImportError( + f"The AWS IAM OAUTHBEARER path requires boto3>={MINIMUM_BOTO3_VERSION} " + f"(for the STS GetWebIdentityToken operation), but found boto3 " + f"{boto3.__version__}. Upgrade with: " + f"pip install -U 'confluent-kafka[oauthbearer-aws]'." + ) + class AwsStsTokenProvider: """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``.""" @@ -48,6 +83,9 @@ def __init__( client directly instead of constructing a real boto3 STS client. Production callers pass ``None``. :raises TypeError: ``config`` is ``None``. + :raises ImportError: the installed boto3 is older than + :data:`MINIMUM_BOTO3_VERSION` (checked only on the real-client + path, i.e. when ``sts_client`` is ``None``). """ if config is None: raise TypeError("config must not be None") @@ -58,6 +96,9 @@ def __init__( if sts_client is not None: self._sts = sts_client else: + # Fail fast with a clear message if boto3 is present but too old for + # the STS GetWebIdentityToken operation. + _require_boto3_version() # boto3.Session() with explicit region_name short-circuits the # AWS_DEFAULT_REGION env lookup; client() still re-asserts the # region for STS endpoint resolution. diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py index a1955904e..c33046196 100644 --- a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -73,6 +73,9 @@ def create_handler( :raises ValueError: ``sasl_oauthbearer_config`` is ``None`` or empty; the wire-grammar parse fails (unknown key, malformed token, missing required field, range/enum violation, etc.). + :raises ImportError: the installed boto3 predates the minimum required by + the AWS IAM path (boto3 is present but too old for STS + ``GetWebIdentityToken``). :raises RuntimeError: AWS SDK reachability or initialisation failure (e.g. unknown region, malformed ``sts_endpoint``). """ diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py index ee99535e6..9953b0778 100644 --- a/tests/oauthbearer/aws/test_aws_sts_token_provider.py +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -24,6 +24,7 @@ import base64 import datetime import logging +from pathlib import Path from typing import Any, Dict, Optional from unittest.mock import patch @@ -36,7 +37,12 @@ from botocore.exceptions import ClientError # noqa: E402 from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import AwsOAuthBearerConfig # noqa: E402 -from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import AwsStsTokenProvider # noqa: E402 +from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import ( # noqa: E402 + MINIMUM_BOTO3_VERSION, + AwsStsTokenProvider, + _require_boto3_version, + _version_tuple, +) # ---- Test helpers ---- @@ -384,3 +390,80 @@ def test_ctor_no_sts_endpoint_omits_endpoint_url_kwarg(): "sts", region_name="us-east-1", ) + + +# ---- boto3 version floor check ---- + + +def test_version_tuple_parses_numeric_components(): + assert _version_tuple("1.42.25") == (1, 42, 25) + assert _version_tuple("1.42.97") == (1, 42, 97) + assert _version_tuple("2.0.0") == (2, 0, 0) + + +def test_version_tuple_tolerates_prerelease_suffix(): + assert _version_tuple("1.42.0rc1") == (1, 42, 0) + assert _version_tuple("1.42.25.dev0") == (1, 42, 25, 0) + + +def test_require_boto3_version_passes_on_installed_version(): + # The opt-in test venv installs boto3 >= the floor, so this must not raise. + _require_boto3_version() + + +def test_require_boto3_version_below_floor_raises(monkeypatch): + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.40.0") + with pytest.raises(ImportError, match="requires boto3>=") as exc_info: + _require_boto3_version() + # message names both the required floor and the offending installed version + assert MINIMUM_BOTO3_VERSION in str(exc_info.value) + assert "1.40.0" in str(exc_info.value) + + +def test_require_boto3_version_at_floor_ok(monkeypatch): + import boto3 + + monkeypatch.setattr(boto3, "__version__", MINIMUM_BOTO3_VERSION) + _require_boto3_version() # exactly the floor satisfies ">=" + + +def test_ctor_old_boto3_raises_before_building_client(monkeypatch): + """The floor check is wired into __init__ on the real-client path and fails + fast — before any boto3.Session/client is constructed.""" + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.40.0") + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + with patch("boto3.Session") as mock_session_cls: + with pytest.raises(ImportError, match="GetWebIdentityToken"): + AwsStsTokenProvider(cfg) # no sts_client → real path → check runs + mock_session_cls.assert_not_called() # raised before building a client + + +def test_ctor_injected_client_skips_version_check(monkeypatch): + """Injecting an sts_client (test seam) bypasses the boto3 floor check — the + real boto3 isn't used for the STS call in that case.""" + import boto3 + + monkeypatch.setattr(boto3, "__version__", "1.0.0") # ancient, but skipped + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=FakeStsClient()) + assert provider is not None + + +def test_minimum_boto3_version_matches_requirements_file(): + """Guard against MINIMUM_BOTO3_VERSION drifting from the pinned floor in + requirements/requirements-oauthbearer-aws.txt.""" + req = Path(__file__).resolve().parents[3] / "requirements" / "requirements-oauthbearer-aws.txt" + floor = None + for line in req.read_text().splitlines(): + stripped = line.strip() + if stripped.startswith("boto3") and ">=" in stripped: + floor = stripped.split(">=", 1)[1].strip() + break + assert floor == MINIMUM_BOTO3_VERSION, ( + f"requirements-oauthbearer-aws.txt pins boto3>={floor} but " + f"MINIMUM_BOTO3_VERSION is {MINIMUM_BOTO3_VERSION}; keep them in sync." + ) From 05144a9c152131e5fbe34c55cb69a2393adc6a08 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Mon, 8 Jun 2026 14:02:27 +0530 Subject: [PATCH 10/21] Fix black formatting in kv_string_parser.py --- src/confluent_kafka/_util/kv_string_parser.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/confluent_kafka/_util/kv_string_parser.py b/src/confluent_kafka/_util/kv_string_parser.py index 2c84ba86e..bdb8df7d9 100644 --- a/src/confluent_kafka/_util/kv_string_parser.py +++ b/src/confluent_kafka/_util/kv_string_parser.py @@ -74,8 +74,6 @@ def parse_kv( idx = token.find("=") if idx <= 0: what = f"'{context_label}'" if context_label else "key=value" - raise ValueError( - f"Malformed {what} entry '{token}' (expected key=value)." - ) + raise ValueError(f"Malformed {what} entry '{token}' (expected key=value).") - yield token[:idx], token[idx + 1:] + yield token[:idx], token[idx + 1 :] From d6ce11c6fbb6714d58cd765d8f6dffed834068ca Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Mon, 8 Jun 2026 14:27:53 +0530 Subject: [PATCH 11/21] Treat unsupported aws_debug values uniformly; drop .NET framing --- .../aws/_aws_oauthbearer_config.py | 22 +++++-------------- .../aws/_aws_sts_token_provider.py | 5 ++--- tests/oauthbearer/aws/test_aws_autowire.py | 6 ++--- .../aws/test_aws_oauthbearer_config.py | 16 +++++--------- 4 files changed, 16 insertions(+), 33 deletions(-) diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py index c0a8db0cb..cac80fbdc 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_oauthbearer_config.py @@ -22,7 +22,7 @@ signing_algorithm=ES384|RS256 (default: ES384) sts_endpoint= (optional, FIPS / VPC) principal_name= (optional, override JWT 'sub') - aws_debug=none|console (default: none, Pythonic subset) + aws_debug=none|console (default: none) tag_= (zero or more JWT custom claims, max 50) SASL extensions arrive separately via :data:`sasl_extensions` (parsed from @@ -48,7 +48,6 @@ "AWS_DEBUG_NONE", "AWS_DEBUG_CONSOLE", "ALLOWED_AWS_DEBUG_VALUES", - "NET_ONLY_AWS_DEBUG_VALUES", "AwsOAuthBearerConfig", ] @@ -77,12 +76,9 @@ AWS_DEBUG_NONE: str = "none" AWS_DEBUG_CONSOLE: str = "console" -#: ``aws_debug`` values accepted by the Python client. +#: ``aws_debug`` values accepted by the Python client: ``none`` and ``console``. ALLOWED_AWS_DEBUG_VALUES = (AWS_DEBUG_NONE, AWS_DEBUG_CONSOLE) -#: ``aws_debug`` values that exist in .NET but are not supported in Python. -NET_ONLY_AWS_DEBUG_VALUES = ("log4net", "systemdiagnostics") - # Recognised non-tag keys for the wire grammar. Anything else (other than # ``tag_``) raises "Unknown key" during :meth:`AwsOAuthBearerConfig.parse`. @@ -149,14 +145,7 @@ def __post_init__(self) -> None: if self.principal_name is not None and self.principal_name == "": raise ValueError(f"{CONFIG_KEY} 'principal_name' must not be empty.") if self.aws_debug not in ALLOWED_AWS_DEBUG_VALUES: - if self.aws_debug in NET_ONLY_AWS_DEBUG_VALUES: - raise ValueError( - f"{CONFIG_KEY} 'aws_debug={self.aws_debug}' is not " - f"supported on this platform. The Python client supports " - f"{list(ALLOWED_AWS_DEBUG_VALUES)} only; " - f"'log4net' and 'systemdiagnostics' are .NET-only sinks." - ) - raise ValueError(f"{CONFIG_KEY} 'aws_debug' must be one of: " f"none, console. Got {self.aws_debug!r}.") + raise ValueError(f"{CONFIG_KEY} 'aws_debug' must be one of: none, console. Got {self.aws_debug!r}.") if self.tags is not None: if not isinstance(self.tags, dict): raise ValueError(f"{CONFIG_KEY} 'tags' must be a dict.") @@ -222,9 +211,8 @@ def parse( elif key == "principal_name": principal_name = value elif key == "aws_debug": - # Normalize case (matches .NET's case-insensitive parsing) so - # downstream comparisons against ALLOWED_AWS_DEBUG_VALUES are - # straightforward. + # Normalize case so downstream comparisons against + # ALLOWED_AWS_DEBUG_VALUES are straightforward. aws_debug = value.lower() elif key.startswith(TAG_KEY_PREFIX): tag_name = key[len(TAG_KEY_PREFIX) :] diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py index 0347db5ac..995ddf62d 100644 --- a/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py +++ b/src/confluent_kafka/oauthbearer/aws/_aws_sts_token_provider.py @@ -112,9 +112,8 @@ def __init__( def _apply_aws_debug(aws_debug: str) -> None: """Apply the ``aws_debug`` side-effect to botocore's logger. - Process-wide effect, intentionally — mirrors .NET's - ``AWSConfigs.LoggingConfig.LogTo`` behaviour. When the user opts in - with ``aws_debug=console``, every boto3 client in the process gets + Process-wide effect, intentionally. When the user opts in with + ``aws_debug=console``, every boto3 client in the process gets DEBUG-level stderr logs. ``aws_debug=none`` is a no-op so any logging the user has configured elsewhere is preserved. """ diff --git a/tests/oauthbearer/aws/test_aws_autowire.py b/tests/oauthbearer/aws/test_aws_autowire.py index 5c7967c05..6e82a5dbe 100644 --- a/tests/oauthbearer/aws/test_aws_autowire.py +++ b/tests/oauthbearer/aws/test_aws_autowire.py @@ -91,9 +91,9 @@ def test_create_handler_invalid_extensions_grammar_raises(): ) -def test_create_handler_aws_debug_dotnet_only_value_raises_with_platform_hint(): - """log4net / systemdiagnostics are .NET-only sinks; surface that clearly.""" - with pytest.raises(ValueError, match="not supported on this platform"): +def test_create_handler_aws_debug_invalid_value_raises(): + """An unsupported aws_debug value is rejected — the client accepts none/console only.""" + with pytest.raises(ValueError, match="aws_debug.*none, console"): create_handler( "region=us-east-1 audience=https://a aws_debug=log4net", None, diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py index 8868d66d9..afaaffbc1 100644 --- a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -118,7 +118,7 @@ def test_parse_disallowed_signing_algorithm_raises(alg): AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a signing_algorithm={alg}") -# ---- aws_debug (Python subset: none/console only) ---- +# ---- aws_debug (none/console only) ---- def test_parse_no_aws_debug_defaults_to_none(): @@ -151,15 +151,11 @@ def test_parse_aws_debug_case_insensitive_accepted(value, expected): assert cfg.aws_debug == expected -@pytest.mark.parametrize("value", ["log4net", "systemdiagnostics", "Log4Net", "SystemDiagnostics"]) -def test_parse_aws_debug_dotnet_only_values_rejected_with_clear_message(value): - """log4net / systemdiagnostics are .NET sinks; Python rejects with a - platform-clarity message so cross-language users see the constraint.""" - with pytest.raises(ValueError, match="not supported on this platform"): - AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") - - -@pytest.mark.parametrize("value", ["verbose", "etw", "debug", "true", "foo"]) +# Every unsupported aws_debug value is rejected with the same uniform error. +@pytest.mark.parametrize( + "value", + ["verbose", "etw", "debug", "true", "foo", "log4net", "systemdiagnostics", "Log4Net", "SystemDiagnostics"], +) def test_parse_aws_debug_invalid_value_raises(value): with pytest.raises(ValueError, match="aws_debug.*none, console"): AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") From f6293b23d1e6ec57982042abe9c915a97a6703ab Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Mon, 8 Jun 2026 16:00:20 +0530 Subject: [PATCH 12/21] Add example for aws_iam --- examples/oauth_oidc_ccloud_aws_iam.py | 227 ++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 examples/oauth_oidc_ccloud_aws_iam.py diff --git a/examples/oauth_oidc_ccloud_aws_iam.py b/examples/oauth_oidc_ccloud_aws_iam.py new file mode 100644 index 000000000..9e6b86ea1 --- /dev/null +++ b/examples/oauth_oidc_ccloud_aws_iam.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end example for AWS IAM OAUTHBEARER authentication. + +Creates a unique topic, produces messages for ``--run-for`` seconds, and +consumes them back — exercising the autowire path across AdminClient, +Producer, and Consumer. + +Activation is config-only: setting +``sasl.oauthbearer.metadata.authentication.type=aws_iam`` is enough. The C +extension detects the marker and wires up the OAUTHBEARER refresh callback — +no ``import confluent_kafka.oauthbearer.aws`` is needed at the call site. + +With ``--duration-seconds 60`` (the AWS-STS minimum) and the default +``--run-for 120``, the ``debug=security`` log stream shows librdkafka refresh +the token mid-run (it refreshes at ~80% of the token lifetime). + +Install: + pip install 'confluent-kafka[oauthbearer-aws]' + +Prerequisites: + 1. Runs on AWS compute (EC2 / EKS / ECS / Fargate / Lambda) with an IAM role + attached — boto3's default credential chain resolves it, no static keys. + 2. The role's trust policy allows ``sts:GetWebIdentityToken`` for the audience. + 3. ``aws iam enable-outbound-web-identity-federation`` has been run once on + the account by an administrator. + 4. The role has produce + consume + create-topic rights on the cluster. + +To run: + python oauth_oidc_ccloud_aws_iam.py \\ + -b pkc-xxxx.aws.confluent.cloud:9092 \\ + --region us-east-1 \\ + --audience https://confluent.cloud/oidc \\ + --extensions logicalCluster=lkc-abc,identityPoolId=pool-xyz +""" + +import argparse +import logging +import time +import uuid + +from confluent_kafka import Consumer, Producer +from confluent_kafka.admin import AdminClient, NewTopic +from confluent_kafka.serialization import StringSerializer + + +def common_config(args): + """SASL config shared by Producer, Consumer, and AdminClient.""" + conf = { + 'bootstrap.servers': args.bootstrap_servers, + 'security.protocol': 'SASL_SSL', + 'sasl.mechanisms': 'OAUTHBEARER', + # The four AWS IAM autowire keys: method=oidc is required (the AWS path + # runs inside librdkafka's OIDC subsystem); the marker triggers + # autowiring; the config string is the AWS wire grammar + # (whitespace-separated key=value — region and audience required). + 'sasl.oauthbearer.method': 'oidc', + 'sasl.oauthbearer.metadata.authentication.type': 'aws_iam', + 'sasl.oauthbearer.config': f'region={args.region} ' + f'audience={args.audience} ' + f'duration_seconds={args.duration_seconds}', + # Surfaces the SASL handshake + OAUTHBEARER refresh events on stderr. + 'debug': 'security', + } + + # Optional SASL extensions (RFC 7628) forwarded verbatim to the broker, + # e.g. logicalCluster=lkc-abc,identityPoolId=pool-xyz + if args.extensions: + conf['sasl.oauthbearer.extensions'] = args.extensions + + return conf + + +def consumer_config(args, group_id): + cfg = common_config(args) + cfg['group.id'] = group_id + cfg['auto.offset.reset'] = 'earliest' + cfg['enable.auto.offset.store'] = False # commit offsets manually + return cfg + + +def create_topic(admin_conf, topic_name, num_partitions=1, replication_factor=3): + """Create the topic (RF=3 is the Confluent Cloud default).""" + admin = AdminClient(admin_conf) + futures = admin.create_topics( + [ + NewTopic(topic_name, num_partitions=num_partitions, replication_factor=replication_factor), + ] + ) + for topic, future in futures.items(): + try: + future.result() + print(f"[admin] Topic '{topic}' created " f"({num_partitions} partition(s), RF={replication_factor})") + except Exception as exc: + print(f"[admin] Failed to create topic '{topic}': {exc}") + raise + + +def delivery_report(err, msg): + if err is not None: + print(f"[producer] Delivery failed: {err}") + return + print( + f"[producer] Produced to {msg.topic()} [{msg.partition()}] " + f"at offset {msg.offset()}: {msg.value().decode('utf-8')}" + ) + + +def main(args): + # Unique topic + group per run so the example is self-contained. + topic_name = f"aws-iam-{uuid.uuid4()}" + group_id = f"aws-iam-consumer-{uuid.uuid4()}" + + p_conf = common_config(args) + c_conf = consumer_config(args, group_id) + a_conf = common_config(args) + + logging.basicConfig(level=logging.INFO) + + print("\n=== AWS IAM OAUTHBEARER end-to-end example ===") + print(f"bootstrap.servers: {args.bootstrap_servers}") + print(f"region: {args.region}") + print(f"audience: {args.audience}") + print(f"duration_seconds: {args.duration_seconds} " f"(auto-refresh at ~{int(args.duration_seconds * 0.8)}s)") + print(f"run-for: {args.run_for}s") + print(f"topic (generated): {topic_name}") + print(f"group.id (generated): {group_id}\n") + + create_topic(a_conf, topic_name) + + producer = Producer(p_conf) + consumer = Consumer(c_conf) + consumer.subscribe([topic_name]) + serializer = StringSerializer('utf_8') + + start = time.time() + end_at = start + args.run_for + produced = 0 + consumed = 0 + + print( + f"[loop] Producing/consuming for {args.run_for}s — " + f"watch the debug=security logs for token-refresh events.\n" + ) + + try: + while time.time() < end_at: + elapsed = time.time() - start + msg = f"hello-from-aws-iam T+{elapsed:.1f}s" + + producer.produce( + topic_name, + value=serializer(msg), + on_delivery=delivery_report, + ) + producer.poll(0) + produced += 1 + + received = consumer.poll(1.0) + if received is None: + pass # poll timeout, no message yet + elif received.error() is not None: + print(f"[consumer] error: {received.error()}") + else: + consumer.store_offsets(received) + consumed += 1 + print( + f"[consumer] Received from " + f"{received.topic()} [{received.partition()}] " + f"at offset {received.offset()}: " + f"{received.value().decode('utf-8')}" + ) + + time.sleep(args.interval) + except KeyboardInterrupt: + print("\n[main] Interrupted — flushing.") + finally: + print(f"\n[summary] Produced {produced}, consumed {consumed} " f"in {time.time() - start:.1f}s. Flushing...") + producer.flush(timeout=10) + consumer.close() + print("[summary] Done.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='End-to-end OAUTHBEARER example via AWS IAM autowire ' '(produce + consume + admin).', + ) + parser.add_argument('-b', dest='bootstrap_servers', required=True, help='Bootstrap broker(s) (host[:port])') + parser.add_argument('--region', required=True, help='AWS region (e.g. us-east-1)') + parser.add_argument( + '--audience', + required=True, + help='OIDC audience claim the broker expects ' '(e.g. https://confluent.cloud/oidc)', + ) + parser.add_argument( + '--extensions', + default=None, + help='Optional sasl.oauthbearer.extensions value ' '(comma-separated key=value pairs)', + ) + parser.add_argument( + '--duration-seconds', + dest='duration_seconds', + type=int, + default=60, + help='STS DurationSeconds (default 60 = AWS minimum); ' 'librdkafka auto-refreshes at ~80%% of it.', + ) + parser.add_argument( + '--run-for', dest='run_for', type=int, default=120, help='Run duration in seconds (default 120).' + ) + parser.add_argument('--interval', type=float, default=5.0, help='Seconds between produce calls (default 5).') + + main(parser.parse_args()) From 8da99237e2c5968ffdf137654e948424a29b0e1c Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Tue, 9 Jun 2026 13:29:17 +0530 Subject: [PATCH 13/21] Add requirements-oauthbearer-aws entry in requirements-tests-install --- requirements/requirements-tests-install.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements-tests-install.txt b/requirements/requirements-tests-install.txt index ff5b3c348..95f6beb24 100644 --- a/requirements/requirements-tests-install.txt +++ b/requirements/requirements-tests-install.txt @@ -4,4 +4,5 @@ -r requirements-avro.txt -r requirements-protobuf.txt -r requirements-json.txt +-r requirements-oauthbearer-aws.txt tests/trivup/trivup-0.14.0.tar.gz \ No newline at end of file From 491ba0edcd888003adc2ff73b910f1a0fe63e50a Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Tue, 9 Jun 2026 15:32:20 +0530 Subject: [PATCH 14/21] Add missing entries requirements-oauthbearer-aws.txt --- requirements/requirements-all.txt | 3 ++- tests/oauthbearer/aws/test_real_sts.py | 4 ++-- tools/source-package-verification.sh | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index d2514d3a7..0acf81a06 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -7,4 +7,5 @@ -r requirements-examples.txt -r requirements-tests.txt -r requirements-docs.txt --r requirements-soaktest.txt \ No newline at end of file +-r requirements-soaktest.txt +-r requirements-oauthbearer-aws.txt \ No newline at end of file diff --git a/tests/oauthbearer/aws/test_real_sts.py b/tests/oauthbearer/aws/test_real_sts.py index 83dab6705..465713bb6 100644 --- a/tests/oauthbearer/aws/test_real_sts.py +++ b/tests/oauthbearer/aws/test_real_sts.py @@ -35,7 +35,7 @@ # On EC2 with role attached, audience-trust-policy enabled: export RUN_AWS_STS_REAL=1 - export AWS_REGION=eu-north-1 + export AWS_STS_TEST_REGION=eu-north-1 # AWS_REGION also accepted (fallback) export AWS_STS_TEST_AUDIENCE=https://api.example.com /tmp/ckp-optin/bin/pytest -v -s tests/oauthbearer/aws/test_real_sts.py @@ -71,7 +71,7 @@ # different role / audience without code changes. # ============================================================================= -_REGION = os.environ.get("AWS_REGION", "eu-north-1") +_REGION = os.environ.get("AWS_STS_TEST_REGION") or os.environ.get("AWS_REGION", "eu-north-1") _AUDIENCE = os.environ.get("AWS_STS_TEST_AUDIENCE", "https://api.example.com") _DURATION = os.environ.get("DURATION_SECONDS", "300") diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index 919d3d26a..e18c918d7 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -71,7 +71,7 @@ if [[ $OS_NAME == linux && $ARCH == x64 ]]; then python3 -c " import importlib.metadata, sys extras = set(importlib.metadata.metadata('confluent-kafka').get_all('Provides-Extra') or []) -required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules'} +required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules', 'oauthbearer-aws'} missing = required - extras if missing: print(f'Failing: package does not provide extras: {missing}', file=sys.stderr) From df5f96a5bf982b19d1356dfb69f651e5cb6af551 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Tue, 9 Jun 2026 19:40:03 +0530 Subject: [PATCH 15/21] Update the version in semaphore.yml --- .semaphore/semaphore.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml index 7e21667f0..f7a90e65c 100644 --- a/.semaphore/semaphore.yml +++ b/.semaphore/semaphore.yml @@ -8,7 +8,7 @@ execution_time_limit: global_job_config: env_vars: - name: LIBRDKAFKA_VERSION - value: v2.14.0 + value: v2.14.2-aws-iam.2-dev prologue: commands: - checkout From 31ae2700531d39338dcb8d6f06b600a0060219d2 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 10 Jun 2026 08:15:35 +0530 Subject: [PATCH 16/21] version update to v2.14.2-aws-iam.2-dev --- CHANGELOG.md | 11 ++++++++++- examples/docker/Dockerfile.alpine | 2 +- pyproject.toml | 2 +- src/confluent_kafka/src/confluent_kafka.h | 2 +- tests/soak/setup_all_versions.py | 4 ++-- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 91d8d4b90..2b871c2cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ # Confluent Python Client for Apache Kafka - CHANGELOG -## v2.xx.x +## v2.14.2-aws-iam.2-dev + +### Enhancements + +- Add AWS IAM OAUTHBEARER authentication via optional `oauthbearer-aws` extra. Set + `sasl.oauthbearer.method=oidc` + `sasl.oauthbearer.metadata.authentication.type=aws_iam` + + `sasl.oauthbearer.config="region=... audience=..."` to autowire a refresh + callback that mints fresh JWTs via AWS STS `GetWebIdentityToken` on every + token refresh, using boto3's default credential chain. Cross-language wire + parity with .NET / JS / Go. See `examples/oauth_oidc_ccloud_aws_iam.py`. ### Fixes diff --git a/examples/docker/Dockerfile.alpine b/examples/docker/Dockerfile.alpine index 6f2add9e6..c0c7c0095 100644 --- a/examples/docker/Dockerfile.alpine +++ b/examples/docker/Dockerfile.alpine @@ -30,7 +30,7 @@ FROM alpine:3.12 COPY . /usr/src/confluent-kafka-python -ENV LIBRDKAFKA_VERSION="v2.14.0" +ENV LIBRDKAFKA_VERSION="v2.14.2-aws-iam.2-dev" ENV KCAT_VERSION="master" ENV CKP_VERSION="master" diff --git a/pyproject.toml b/pyproject.toml index da7de97fb..e787d6a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "confluent-kafka" -version = "2.14.0" +version = "2.14.2-aws-iam.2-dev" description = "Confluent's Python client for Apache Kafka" classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 2fd276133..94883b0d7 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,7 +38,7 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.0" +#define CFL_VERSION_STR "2.14.2-aws-iam.2-dev" /** * Minimum required librdkafka version. This is checked both during diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 9cdf4f2dc..12372f3bb 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -5,7 +5,7 @@ PYTHON_SOAK_TEST_BRANCH = 'master' LIBRDKAFKA_VERSIONS = [ - '2.14.0', + '2.14.2-aws-iam.2-dev', '2.13.2', '2.13.0', '2.12.1', @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.0', + '2.14.2-aws-iam.2-dev', '2.13.2', '2.13.0', '2.12.1', From 9069e8933729ffbe4d130b8dbfc9187b1d669dce Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 10 Jun 2026 09:07:01 +0530 Subject: [PATCH 17/21] Fix the tag version string --- pyproject.toml | 2 +- src/confluent_kafka/src/confluent_kafka.h | 2 +- tests/soak/setup_all_versions.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e787d6a33..de2da9fed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "confluent-kafka" -version = "2.14.2-aws-iam.2-dev" +version = "2.14.2+aws-iam.2-dev" description = "Confluent's Python client for Apache Kafka" classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 94883b0d7..5a28be555 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,7 +38,7 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.2-aws-iam.2-dev" +#define CFL_VERSION_STR "2.14.2+aws-iam.2-dev" /** * Minimum required librdkafka version. This is checked both during diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 12372f3bb..82b59a0b4 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.2-aws-iam.2-dev', + '2.14.2+aws-iam.2-dev', '2.13.2', '2.13.0', '2.12.1', From d3699c7696f3a0363ba827597b1a2334b3e9b7d5 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Wed, 10 Jun 2026 12:27:50 +0530 Subject: [PATCH 18/21] Change the version to v2.14.2.dev2 --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- src/confluent_kafka/src/confluent_kafka.h | 2 +- tests/soak/setup_all_versions.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c47106f0..117408c61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Confluent Python Client for Apache Kafka - CHANGELOG -## v2.14.2-aws-iam.2-dev +## v2.14.2.dev2 ### Enhancements diff --git a/pyproject.toml b/pyproject.toml index de2da9fed..605404843 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "confluent-kafka" -version = "2.14.2+aws-iam.2-dev" +version = "2.14.2.dev2" description = "Confluent's Python client for Apache Kafka" classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 5a28be555..fc17e4a4f 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,7 +38,7 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.2+aws-iam.2-dev" +#define CFL_VERSION_STR "2.14.2.dev2" /** * Minimum required librdkafka version. This is checked both during diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 82b59a0b4..71f38dfa3 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.2+aws-iam.2-dev', + '2.14.2.dev2', '2.13.2', '2.13.0', '2.12.1', From bab88598b298cb1f8918d92fdab719c18504a648 Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Thu, 11 Jun 2026 07:56:10 +0530 Subject: [PATCH 19/21] Remove strip + sentinel --- src/confluent_kafka/src/confluent_kafka.c | 138 ++++++---------------- src/confluent_kafka/src/confluent_kafka.h | 17 ++- tests/oauthbearer/aws/test_dispatch.py | 24 ++-- 3 files changed, 62 insertions(+), 117 deletions(-) diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 68b18654d..9ae916019 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -2605,8 +2605,8 @@ static void common_conf_set_software(rd_kafka_conf_t *conf) { * Flow: * 1. Skip if the marker key is absent or its value is not "aws_iam". * (Other values like "azure_imds" flow through to librdkafka unchanged.) - * 2. If oauth_cb is already set, strip the marker and skip autowire — - * the explicit handler wins (precedence rule). + * 2. If oauth_cb is already set, the explicit handler wins (precedence + * rule) — return immediately, leaving the marker in place. * 3. Require sasl.oauthbearer.method == "oidc". The AWS IAM path runs as * a high-level-client refresh callback inside librdkafka's OIDC * subsystem (parallel to azure_imds); without method=oidc the @@ -2622,49 +2622,28 @@ static void common_conf_set_software(rd_kafka_conf_t *conf) { * The returned Python callable replaces oauth_cb in confdict; the * existing config-iteration loop below picks it up and registers the * librdkafka refresh callback. - * 7+. Strip the marker key + auto-set THREE sentinel OIDC fields that - * librdkafka's config-finalize demands when method=oidc. The full set: - * - sasl.oauthbearer.token.endpoint.url - * - sasl.oauthbearer.client.id - * - sasl.oauthbearer.client.secret - * - * Why all three are needed: - * librdkafka has TWO mandatory-field checks gated on the marker value: - * - * (a) rdkafka_conf.c:4130 demands token.endpoint.url unless - * metadata_authentication.type == AZURE_IMDS or AWS_IAM. With - * the marker stripped, this check fires unless we provide a URL. - * - * (b) rdkafka_conf.c:4155 — finalize_oauthbearer_oidc_grant_type - * ONLY runs when metadata_authentication.type == NONE. With the - * marker stripped, it runs and demands client.id + client.secret - * (for the default CLIENT_CREDENTIALS grant type). - * - * Why both old AND new librdkafkas tolerate this: - * - librdkafka < AWS-IAM-aware: doesn't know `aws_iam`, rejects the - * marker value at rd_kafka_conf_set() time. We must strip first. - * The 3 sentinels then satisfy its OIDC mandatory checks. - * - librdkafka >= AWS-IAM-aware: would honour the marker (bypass - * token.endpoint.url and skip the grant-type check), but we strip - * it for backward compatibility with older librdkafkas. The 3 - * sentinels keep this path working too. - * - * The sentinels are NEVER used at runtime. librdkafka's config - * finalize at rdkafka_conf.c:4166-4169 explicitly skips the built-in - * OIDC token fetcher when a refresh callback is registered: - * "Enable background thread for the builtin OIDC handler, - * unless a refresh callback has been set." - * Our autowire registers an oauth_cb. The sentinel URL uses RFC 2606 - * `.invalid` TLD so even an accidental fetch would fail at DNS. - * - * TODO: REMOVE STEPS 8+9 (strip + sentinels) once librdkafka - * v2.14.2-aws-iam.2-dev (or the official release that supersedes it) - * becomes the bundled MIN_VER floor. Leave the marker in place; the - * AWS-IAM-aware librdkafka knows `aws_iam`, bypasses the - * token.endpoint.url check, and skips the grant-type check entirely. - * All three sentinels become unnecessary at that point. The removal - * MUST be done in the same PR that bumps the librdkafka floor. - * Project memory: project_aws_iam_python_alignment.md decision #6. + * 7. Leave the marker in confdict and return — it is passed to librdkafka + * unchanged. + * + * Native librdkafka handling (this REQUIRES an AWS-IAM-aware build; the + * bundled librdkafka floor is >= v2.14.2-aws-iam.2-dev — see + * MIN_RD_KAFKA_VERSION in confluent_kafka.h): + * - `aws_iam` is a recognized value for metadata.authentication.type + * (rdkafka_conf.c), so rd_kafka_conf_set() accepts the marker. + * - librdkafka bypasses the mandatory token.endpoint.url check for + * AWS_IAM, and the grant-type finalize runs only for type==NONE — so + * no sentinel OIDC fields are required. + * - librdkafka registers its own aws_iam refresh cb only when no + * token_refresh_cb is set (rdkafka.c). Since step 6 registered our + * oauth_cb, our callback owns the token fetch and librdkafka's native + * aws_iam stub never fires; the built-in OIDC fetcher is likewise + * skipped when a refresh cb is present (rdkafka_conf.c). + * + * History: earlier versions stripped the marker and injected 3 sentinel + * OIDC fields (token.endpoint.url, client.id, client.secret) so the path + * also worked on stock librdkafka that didn't know `aws_iam`. That strip + * + sentinels was removed once the AWS-IAM-aware librdkafka became the + * bundled floor. Project memory: project_aws_iam_python_alignment.md #6. * * @returns 0 on success (no-op or autowire complete), -1 on error * (PyErr_* is set; caller goto outer_err). @@ -2724,14 +2703,13 @@ static int resolve_aws_oauthbearer_marker(PyObject *confdict) { return 0; } - /* 2. Explicit oauth_cb wins. Skip ahead to strip+sentinels — librdkafka - * still needs the sentinels because we strip the marker (so its - * AWS_IAM bypass paths don't fire). The user's oauth_cb takes over - * the actual refresh path; the sentinels are just there to keep - * librdkafka's config-finalize happy. */ + /* 2. Explicit oauth_cb wins (precedence). Leave the marker in place and + * return — the AWS-IAM-aware librdkafka accepts `aws_iam` and uses the + * user's oauth_cb (it only registers its own native aws_iam refresh cb + * when no refresh cb is set; see rdkafka.c). Nothing else to do. */ cb = PyDict_GetItemString(confdict, OAUTH_CB_KEY); if (cb && cb != Py_None) { - goto strip_and_sentinels; + return 0; } /* 3. Require sasl.oauthbearer.method = "oidc". */ @@ -2805,57 +2783,13 @@ static int resolve_aws_oauthbearer_marker(PyObject *confdict) { } Py_DECREF(callback); -strip_and_sentinels: - /* 8. Strip the marker. UNCONDITIONAL — see TODO at top of function. - * Reached from both the autowire-fires path AND the - * explicit-oauth_cb-wins path (goto from step 2). */ - if (PyDict_DelItemString(confdict, MARKER_KEY) == -1) { - return -1; - } - - /* 9. Auto-set 3 sentinel OIDC fields that librdkafka demands when - * method=oidc. Required because we stripped the marker in step 8 - * (so librdkafka's AWS_IAM bypass paths never fire). The sentinels - * are never used at runtime — the oauth_cb (autowired or user- - * supplied) takes over the token fetch entirely. See TODO at top - * of function for full reasoning. User-supplied values are - * respected (no clobber). */ - { - static const struct { - const char *key; - const char *value; - } SENTINELS[] = { - {"sasl.oauthbearer.token.endpoint.url", - "https://aws-iam-autowire.invalid/"}, - {"sasl.oauthbearer.client.id", "aws-iam-autowire"}, - {"sasl.oauthbearer.client.secret", "aws-iam-autowire"}, - }; - size_t i; - for (i = 0; - i < sizeof(SENTINELS) / sizeof(SENTINELS[0]); - i++) { - PyObject *existing; - PyObject *val; - - existing = PyDict_GetItemString(confdict, - SENTINELS[i].key); - if (existing) { - /* User already supplied a value — respect it. */ - continue; - } - val = PyUnicode_FromString(SENTINELS[i].value); - if (!val) { - return -1; - } - if (PyDict_SetItemString(confdict, SENTINELS[i].key, - val) == -1) { - Py_DECREF(val); - return -1; - } - Py_DECREF(val); - } - } - + /* 8. Leave the marker in confdict — it is passed to librdkafka + * unchanged. The AWS-IAM-aware librdkafka recognizes `aws_iam`, + * bypasses the token.endpoint.url + grant-type requirements + * (rdkafka_conf.c), and — because step 7 registered our oauth_cb + * (token_refresh_cb) — uses our callback for the token rather than its + * native aws_iam stub (which only registers when no refresh cb is set; + * rdkafka.c). No marker strip, no sentinel OIDC fields needed. */ return 0; } diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index fc17e4a4f..fe180102d 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -45,17 +45,24 @@ * build-time (just below) and runtime (see confluent_kafka.c). * Make sure to keep the MIN_RD_KAFKA_VERSION, MIN_VER_ERRSTR and #error * defines and strings in sync. + * + * Floor is v2.14.2 to match the bundled AWS-IAM-aware librdkafka + * (v2.14.2-aws-iam.2-dev), which the OAUTHBEARER aws_iam autowire path + * requires (the dispatcher passes the `aws_iam` marker straight to + * librdkafka — see resolve_aws_oauthbearer_marker in confluent_kafka.c). + * Note: the version number alone cannot distinguish an aws_iam-aware build + * from stock 2.14.2; the AWS-IAM behavior is guaranteed by the bundled wheel. */ -#define MIN_RD_KAFKA_VERSION 0x020e00ff +#define MIN_RD_KAFKA_VERSION 0x020e02ff #ifdef __APPLE__ #define MIN_VER_ERRSTR \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. " \ + "confluent-kafka-python requires librdkafka v2.14.2 or later. " \ "Install the latest version of librdkafka from Homebrew by running " \ "`brew install librdkafka` or `brew upgrade librdkafka`" #else #define MIN_VER_ERRSTR \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. " \ + "confluent-kafka-python requires librdkafka v2.14.2 or later. " \ "Install the latest version of librdkafka from the Confluent " \ "repositories, see http://docs.confluent.io/current/installation.html" #endif @@ -63,10 +70,10 @@ #if RD_KAFKA_VERSION < MIN_RD_KAFKA_VERSION #ifdef __APPLE__ #error \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. Install the latest version of librdkafka from Homebrew by running `brew install librdkafka` or `brew upgrade librdkafka`" + "confluent-kafka-python requires librdkafka v2.14.2 or later. Install the latest version of librdkafka from Homebrew by running `brew install librdkafka` or `brew upgrade librdkafka`" #else #error \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. Install the latest version of librdkafka from the Confluent repositories, see http://docs.confluent.io/current/installation.html" + "confluent-kafka-python requires librdkafka v2.14.2 or later. Install the latest version of librdkafka from the Confluent repositories, see http://docs.confluent.io/current/installation.html" #endif #endif diff --git a/tests/oauthbearer/aws/test_dispatch.py b/tests/oauthbearer/aws/test_dispatch.py index a5feb852b..f4cbce242 100644 --- a/tests/oauthbearer/aws/test_dispatch.py +++ b/tests/oauthbearer/aws/test_dispatch.py @@ -272,8 +272,11 @@ async def test_happy_path_aio_consumer_constructs_with_marker(mocked_boto3): def test_explicit_oauth_cb_wins_over_marker(): - """When user supplies their own oauth_cb, our dispatcher strips the marker - and yields. boto3 is NOT touched.""" + """When the user supplies their own oauth_cb, the dispatcher leaves the + marker in place and yields — the explicit handler wins and boto3 is NOT + touched. The AWS-IAM-aware librdkafka accepts the marker and uses the + user's oauth_cb (its native aws_iam cb only registers when no refresh cb + is set).""" sentinel_called = [] def user_oauth_cb(config_str): @@ -415,17 +418,18 @@ def test_friendly_import_error_on_admin_client_too(boto3_absent): ) -# ---- 9. Marker is stripped before native handoff ---- +# ---- 9. Marker is passed through to the AWS-IAM-aware librdkafka ---- -def test_marker_stripped_after_autowire(mocked_boto3): - """The C dispatcher must strip the marker from confdict before the - config-iteration loop sees it. Today's bundled librdkafka doesn't know - the 'aws_iam' enum value and would reject it at rd_kafka_conf_set time. - If the strip stops working, this test fails with librdkafka's - 'invalid value' error rather than Producer constructing cleanly.""" +def test_marker_passed_through_to_librdkafka(mocked_boto3): + """The dispatcher leaves the 'aws_iam' marker in confdict; it is passed to + librdkafka unchanged. The bundled AWS-IAM-aware librdkafka recognizes the + 'aws_iam' enum value and accepts it at rd_kafka_conf_set time (bypassing + the token.endpoint.url + grant-type requirements), so the Producer + constructs cleanly. This also pins the requirement that the linked + librdkafka be AWS-IAM-aware: against stock librdkafka the marker would be + rejected with an 'invalid value' error here.""" p = confluent_kafka.Producer(_minimal_aws_iam_config()) - # If we got here, the strip succeeded — librdkafka never saw the marker. assert p is not None p.flush(timeout=0.1) From 003caa24b9ce6fd3f995154320de2e8aaf035a1b Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Thu, 11 Jun 2026 07:56:52 +0530 Subject: [PATCH 20/21] Disable python github assestations --- .semaphore/publish-test-pypi.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.semaphore/publish-test-pypi.yml b/.semaphore/publish-test-pypi.yml index d482eb61c..603ef41b6 100644 --- a/.semaphore/publish-test-pypi.yml +++ b/.semaphore/publish-test-pypi.yml @@ -33,6 +33,9 @@ blocks: - name: Verify commands: - checkout + # mise refuses to install some Python versions when GitHub artifact + # attestations are missing; disable that check for the verify loop. + - export MISE_PYTHON_GITHUB_ATTESTATIONS=false - sem-version python 3.11 - export CK_VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])") - tools/test-released-wheels.sh $CK_VERSION test From f2cd9defd06248e0ebc10e7e377d1444dec0cb8c Mon Sep 17 00:00:00 2001 From: Pranav Shah Date: Thu, 11 Jun 2026 08:03:04 +0530 Subject: [PATCH 21/21] Update the version to v2.14.2.dev3 --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- src/confluent_kafka/src/confluent_kafka.h | 2 +- tests/soak/setup_all_versions.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 117408c61..c624d4aaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Confluent Python Client for Apache Kafka - CHANGELOG -## v2.14.2.dev2 +## v2.14.2.dev3 ### Enhancements diff --git a/pyproject.toml b/pyproject.toml index 605404843..64c832491 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "confluent-kafka" -version = "2.14.2.dev2" +version = "2.14.2.dev3" description = "Confluent's Python client for Apache Kafka" classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index fe180102d..9c34f92be 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,7 +38,7 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.2.dev2" +#define CFL_VERSION_STR "2.14.2.dev3" /** * Minimum required librdkafka version. This is checked both during diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 71f38dfa3..b9b322102 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.2.dev2', + '2.14.2.dev3', '2.13.2', '2.13.0', '2.12.1',