diff --git a/.semaphore/publish-test-pypi.yml b/.semaphore/publish-test-pypi.yml index d482eb61c..603ef41b6 100644 --- a/.semaphore/publish-test-pypi.yml +++ b/.semaphore/publish-test-pypi.yml @@ -33,6 +33,9 @@ blocks: - name: Verify commands: - checkout + # mise refuses to install some Python versions when GitHub artifact + # attestations are missing; disable that check for the verify loop. + - export MISE_PYTHON_GITHUB_ATTESTATIONS=false - sem-version python 3.11 - export CK_VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])") - tools/test-released-wheels.sh $CK_VERSION test diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml index 1fd44be4d..6218e33f4 100644 --- a/.semaphore/semaphore.yml +++ b/.semaphore/semaphore.yml @@ -8,7 +8,7 @@ execution_time_limit: global_job_config: env_vars: - name: LIBRDKAFKA_VERSION - value: v2.14.2 + value: v2.14.2-aws-iam.2-dev prologue: commands: - checkout diff --git a/CHANGELOG.md b/CHANGELOG.md index deffeab89..c624d4aaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ # Confluent Python Client for Apache Kafka - CHANGELOG -## v2.xx.x +## v2.14.2.dev3 + +### Enhancements + +- Add AWS IAM OAUTHBEARER authentication via optional `oauthbearer-aws` extra. Set + `sasl.oauthbearer.method=oidc` + `sasl.oauthbearer.metadata.authentication.type=aws_iam` + + `sasl.oauthbearer.config="region=... audience=..."` to autowire a refresh + callback that mints fresh JWTs via AWS STS `GetWebIdentityToken` on every + token refresh, using boto3's default credential chain. Cross-language wire + parity with .NET / JS / Go. See `examples/oauth_oidc_ccloud_aws_iam.py`. ### Fixes diff --git a/examples/docker/Dockerfile.alpine b/examples/docker/Dockerfile.alpine index 2a6156bf5..c0c7c0095 100644 --- a/examples/docker/Dockerfile.alpine +++ b/examples/docker/Dockerfile.alpine @@ -30,7 +30,7 @@ FROM alpine:3.12 COPY . 
/usr/src/confluent-kafka-python -ENV LIBRDKAFKA_VERSION="v2.14.2" +ENV LIBRDKAFKA_VERSION="v2.14.2-aws-iam.2-dev" ENV KCAT_VERSION="master" ENV CKP_VERSION="master" diff --git a/examples/oauth_oidc_ccloud_aws_iam.py b/examples/oauth_oidc_ccloud_aws_iam.py new file mode 100644 index 000000000..9e6b86ea1 --- /dev/null +++ b/examples/oauth_oidc_ccloud_aws_iam.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end example for AWS IAM OAUTHBEARER authentication. + +Creates a unique topic, produces messages for ``--run-for`` seconds, and +consumes them back — exercising the autowire path across AdminClient, +Producer, and Consumer. + +Activation is config-only: setting +``sasl.oauthbearer.metadata.authentication.type=aws_iam`` is enough. The C +extension detects the marker and wires up the OAUTHBEARER refresh callback — +no ``import confluent_kafka.oauthbearer.aws`` is needed at the call site. + +With ``--duration-seconds 60`` (the AWS-STS minimum) and the default +``--run-for 120``, the ``debug=security`` log stream shows librdkafka refresh +the token mid-run (it refreshes at ~80% of the token lifetime). + +Install: + pip install 'confluent-kafka[oauthbearer-aws]' + +Prerequisites: + 1. Runs on AWS compute (EC2 / EKS / ECS / Fargate / Lambda) with an IAM role + attached — boto3's default credential chain resolves it, no static keys. + 2. 
def common_config(args):
    """Build the SASL settings shared by Producer, Consumer, and AdminClient.

    :param args: Parsed command-line namespace; reads ``bootstrap_servers``,
        ``region``, ``audience``, ``duration_seconds`` and ``extensions``.
    :returns: A new configuration dict on every call.
    """
    conf = {
        'bootstrap.servers': args.bootstrap_servers,
        'security.protocol': 'SASL_SSL',
        'sasl.mechanisms': 'OAUTHBEARER',
        # The AWS IAM autowire keys: method=oidc is required (the AWS path
        # runs inside librdkafka's OIDC subsystem); the marker triggers
        # autowiring; the config string is the AWS wire grammar
        # (whitespace-separated key=value — region and audience required).
        'sasl.oauthbearer.method': 'oidc',
        'sasl.oauthbearer.metadata.authentication.type': 'aws_iam',
        'sasl.oauthbearer.config': f'region={args.region} '
        f'audience={args.audience} '
        f'duration_seconds={args.duration_seconds}',
        # Surfaces the SASL handshake + OAUTHBEARER refresh events on stderr.
        'debug': 'security',
    }

    # Optional SASL extensions (RFC 7628) forwarded verbatim to the broker,
    # e.g. logicalCluster=lkc-abc,identityPoolId=pool-xyz
    if args.extensions:
        conf['sasl.oauthbearer.extensions'] = args.extensions

    return conf


def consumer_config(args, group_id):
    """Return :func:`common_config` extended with the consumer-only settings.

    :param args: Parsed command-line namespace (see :func:`common_config`).
    :param group_id: Consumer group id to join.
    """
    cfg = common_config(args)
    cfg['group.id'] = group_id
    cfg['auto.offset.reset'] = 'earliest'
    # Offsets are *stored* explicitly via consumer.store_offsets() once a
    # message has been handled; committing the stored offsets is still left
    # to the client's auto-commit.
    cfg['enable.auto.offset.store'] = False
    return cfg


def create_topic(admin_conf, topic_name, num_partitions=1, replication_factor=3):
    """Create the topic (RF=3 is the Confluent Cloud default).

    Blocks on the creation future and re-raises whatever it raised after
    printing a diagnostic, so the caller never proceeds without the topic.
    """
    admin = AdminClient(admin_conf)
    futures = admin.create_topics(
        [
            NewTopic(topic_name, num_partitions=num_partitions, replication_factor=replication_factor),
        ]
    )
    for topic, future in futures.items():
        try:
            future.result()
            print(f"[admin] Topic '{topic}' created " f"({num_partitions} partition(s), RF={replication_factor})")
        except Exception as exc:
            print(f"[admin] Failed to create topic '{topic}': {exc}")
            raise


def delivery_report(err, msg):
    """Producer delivery callback: print the delivery outcome of one message."""
    if err is not None:
        print(f"[producer] Delivery failed: {err}")
        return
    print(
        f"[producer] Produced to {msg.topic()} [{msg.partition()}] "
        f"at offset {msg.offset()}: {msg.value().decode('utf-8')}"
    )


def main(args):
    """Create a topic, then produce and consume on it for ``args.run_for`` seconds.

    :param args: Parsed command-line namespace built in the ``__main__``
        section below.
    """
    # Unique topic + group per run so the example is self-contained.
    topic_name = f"aws-iam-{uuid.uuid4()}"
    group_id = f"aws-iam-consumer-{uuid.uuid4()}"

    p_conf = common_config(args)
    c_conf = consumer_config(args, group_id)
    a_conf = common_config(args)

    logging.basicConfig(level=logging.INFO)

    print("\n=== AWS IAM OAUTHBEARER end-to-end example ===")
    print(f"bootstrap.servers: {args.bootstrap_servers}")
    print(f"region: {args.region}")
    print(f"audience: {args.audience}")
    print(f"duration_seconds: {args.duration_seconds} " f"(auto-refresh at ~{int(args.duration_seconds * 0.8)}s)")
    print(f"run-for: {args.run_for}s")
    print(f"topic (generated): {topic_name}")
    print(f"group.id (generated): {group_id}\n")

    create_topic(a_conf, topic_name)

    producer = Producer(p_conf)
    consumer = Consumer(c_conf)
    consumer.subscribe([topic_name])
    serializer = StringSerializer('utf_8')

    start = time.time()
    end_at = start + args.run_for
    produced = 0
    consumed = 0

    print(
        f"[loop] Producing/consuming for {args.run_for}s — "
        f"watch the debug=security logs for token-refresh events.\n"
    )

    try:
        while time.time() < end_at:
            elapsed = time.time() - start
            msg = f"hello-from-aws-iam T+{elapsed:.1f}s"

            producer.produce(
                topic_name,
                value=serializer(msg),
                on_delivery=delivery_report,
            )
            producer.poll(0)  # serve delivery callbacks without blocking
            produced += 1

            received = consumer.poll(1.0)
            if received is None:
                pass  # poll timeout, no message yet
            elif received.error() is not None:
                print(f"[consumer] error: {received.error()}")
            else:
                consumer.store_offsets(received)
                consumed += 1
                print(
                    f"[consumer] Received from "
                    f"{received.topic()} [{received.partition()}] "
                    f"at offset {received.offset()}: "
                    f"{received.value().decode('utf-8')}"
                )

            time.sleep(args.interval)
    except KeyboardInterrupt:
        print("\n[main] Interrupted — flushing.")
    finally:
        print(f"\n[summary] Produced {produced}, consumed {consumed} " f"in {time.time() - start:.1f}s. Flushing...")
        try:
            # flush() reports how many messages are still queued when the
            # timeout expires; surface that instead of silently dropping it.
            remaining = producer.flush(timeout=10)
            if remaining:
                print(f"[summary] {remaining} message(s) still undelivered after the flush timeout.")
        finally:
            # Always leave the consumer group, even when flush() raised.
            consumer.close()
        print("[summary] Done.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='End-to-end OAUTHBEARER example via AWS IAM autowire ' '(produce + consume + admin).',
    )
    parser.add_argument('-b', dest='bootstrap_servers', required=True, help='Bootstrap broker(s) (host[:port])')
    parser.add_argument('--region', required=True, help='AWS region (e.g. us-east-1)')
    parser.add_argument(
        '--audience',
        required=True,
        help='OIDC audience claim the broker expects ' '(e.g. https://confluent.cloud/oidc)',
    )
    parser.add_argument(
        '--extensions',
        default=None,
        help='Optional sasl.oauthbearer.extensions value ' '(comma-separated key=value pairs)',
    )
    parser.add_argument(
        '--duration-seconds',
        dest='duration_seconds',
        type=int,
        default=60,
        help='STS DurationSeconds (default 60 = AWS minimum); ' 'librdkafka auto-refreshes at ~80%% of it.',
    )
    parser.add_argument(
        '--run-for', dest='run_for', type=int, default=120, help='Run duration in seconds (default 120).'
    )
    parser.add_argument('--interval', type=float, default=5.0, help='Seconds between produce calls (default 5).')

    main(parser.parse_args())
optional-dependencies.examples = { file = ["requirements/requirements-examples.txt"] } optional-dependencies.soaktest = { file = ["requirements/requirements-soaktest.txt"] } optional-dependencies.all = { file = [ @@ -140,7 +143,8 @@ optional-dependencies.all = { file = [ "requirements/requirements-rules.txt", "requirements/requirements-avro.txt", "requirements/requirements-json.txt", - "requirements/requirements-protobuf.txt"] } + "requirements/requirements-protobuf.txt", + "requirements/requirements-oauthbearer-aws.txt"] } [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index d2514d3a7..0acf81a06 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -7,4 +7,5 @@ -r requirements-examples.txt -r requirements-tests.txt -r requirements-docs.txt --r requirements-soaktest.txt \ No newline at end of file +-r requirements-soaktest.txt +-r requirements-oauthbearer-aws.txt \ No newline at end of file diff --git a/requirements/requirements-oauthbearer-aws.txt b/requirements/requirements-oauthbearer-aws.txt new file mode 100644 index 000000000..3c332eafc --- /dev/null +++ b/requirements/requirements-oauthbearer-aws.txt @@ -0,0 +1 @@ +boto3>=1.42.25 diff --git a/requirements/requirements-tests-install.txt b/requirements/requirements-tests-install.txt index ff5b3c348..95f6beb24 100644 --- a/requirements/requirements-tests-install.txt +++ b/requirements/requirements-tests-install.txt @@ -4,4 +4,5 @@ -r requirements-avro.txt -r requirements-protobuf.txt -r requirements-json.txt +-r requirements-oauthbearer-aws.txt tests/trivup/trivup-0.14.0.tar.gz \ No newline at end of file diff --git a/src/confluent_kafka/_util/kv_string_parser.py b/src/confluent_kafka/_util/kv_string_parser.py new file mode 100644 index 000000000..bdb8df7d9 --- /dev/null +++ b/src/confluent_kafka/_util/kv_string_parser.py @@ -0,0 +1,79 @@ +# Copyright 2026 Confluent Inc. 
def parse_kv(
    raw: str,
    separators: Iterable[str],
    context_label: Optional[str] = None,
    trim_tokens: bool = True,
) -> Iterator[Tuple[str, str]]:
    """Tokenize ``raw`` and yield each non-empty token as a ``(key, value)`` pair.

    Tokens are split on any character in ``separators``. Within each token
    the split is on the first ``=`` only, so values may themselves contain
    ``=`` (e.g. URL query strings).

    Empty tokens (consecutive separators, or whitespace-only tokens when
    ``trim_tokens`` is true) are skipped. A token with no ``=`` or with ``=``
    at position 0 (empty key) raises :class:`ValueError`.

    Argument validation happens when ``parse_kv`` is *called*; only the
    per-token :class:`ValueError` is deferred until the returned iterator
    reaches the offending token.

    :param raw: Input string to tokenize.
    :param separators: Iterable of separators. Only the first character of
        each element is used; empty elements are ignored. When no separator
        character remains, ``raw`` is treated as a single token.
    :param context_label: Optional label quoted in error messages to identify
        which config the malformed token came from. When falsy, messages use
        a generic ``key=value`` phrasing.
    :param trim_tokens: When true (default), each token is stripped of
        leading and trailing whitespace before being split on ``=``.
    :returns: An iterator of ``(key, value)`` tuples in input order.
    :raises TypeError: ``raw`` or ``separators`` is ``None``.
    :raises ValueError: A token is malformed (no ``=`` or empty key).
    """
    # Validate here rather than inside the generator: a generator body does
    # not run until the first next(), which would delay these errors to an
    # unrelated call site (or hide them if the result is never iterated).
    if raw is None:
        raise TypeError("raw must not be None")
    if separators is None:
        raise TypeError("separators must not be None")

    # Build a single-character split set from the supplied separators.
    chars = "".join(str(s)[0] for s in separators if s)
    return _iter_kv(raw, chars, context_label, trim_tokens)


def _iter_kv(
    raw: str,
    chars: str,
    context_label: Optional[str],
    trim_tokens: bool,
) -> Iterator[Tuple[str, str]]:
    """Lazily split ``raw`` on any character in ``chars`` and yield key/value pairs."""
    if chars:
        pieces = re.split("[" + re.escape(chars) + "]", raw)
    else:
        # Degenerate "no separators" case: the whole input is one token.
        pieces = [raw]

    for piece in pieces:
        token = piece.strip() if trim_tokens else piece
        if not token:
            continue

        key, equals, value = token.partition("=")
        if not equals or not key:
            what = f"'{context_label}'" if context_label else "key=value"
            raise ValueError(f"Malformed {what} entry '{token}' (expected key=value).")

        yield key, value
+ +Cloud-specific subpackages (e.g. :mod:`confluent_kafka.oauthbearer.aws`) live +under here and are gated by their own PEP 621 extras. +""" diff --git a/src/confluent_kafka/oauthbearer/aws/__init__.py b/src/confluent_kafka/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..48ecf1f87 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS IAM OAUTHBEARER autowire subpackage. + +The only publicly importable name in this subpackage is +:func:`confluent_kafka.oauthbearer.aws.aws_autowire.create_handler`, loaded by +core's C extension when the user sets +``sasl.oauthbearer.metadata.authentication.type=aws_iam``. All other modules +are private (underscore-prefixed) and not re-exported here. + +Install with:: + + pip install 'confluent-kafka[oauthbearer-aws]' +""" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py new file mode 100644 index 000000000..7a1cf69ad --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py @@ -0,0 +1,37 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Marker constants identifying the AWS IAM OAUTHBEARER autowire path. + +The config-key/value pair that activates the AWS OAUTHBEARER autowire path. + +The C dispatcher in ``src/confluent_kafka/src/confluent_kafka.c`` keeps its +own literal copies of these values for compile-time use; the drift-guard +test in ``tests/oauthbearer/aws/test_aws_iam_marker.py`` asserts the C-side +literals and these Python constants stay in lock-step. + +These strings are part of the cross-language wire contract — bumping either +is a major version change on ``confluent-kafka``. +""" + +__all__ = ["AWS_IAM_MARKER_KEY", "AWS_IAM_MARKER_VALUE"] + + +#: Config key that activates the AWS IAM autowire path when set +#: to :data:`AWS_IAM_MARKER_VALUE`. +AWS_IAM_MARKER_KEY: str = "sasl.oauthbearer.metadata.authentication.type" + +#: On-wire value of :data:`AWS_IAM_MARKER_KEY` that selects AWS IAM +#: authentication. +AWS_IAM_MARKER_VALUE: str = "aws_iam" diff --git a/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py new file mode 100644 index 000000000..31395e337 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/_aws_jwt_subject_extractor.py @@ -0,0 +1,90 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Upper bound on accepted token size, in characters.
_MAX_TOKEN_LENGTH_CHARS: int = 8192

# Maps the base64url alphabet onto the standard base64 alphabet.
_B64URL_TO_STD = str.maketrans("-_", "+/")


def extract_sub(jwt: str) -> str:
    """Return the ``sub`` claim from the JWT payload.

    The payload is only decoded, never verified.

    :param jwt: Compact-serialized JWT (three ``.``-separated segments).
    :returns: The non-empty ``sub`` string claim.
    :raises ValueError: ``jwt`` is ``None``, not a string, empty, oversized,
        has the wrong segment count, fails base64url decoding, isn't valid
        UTF-8 / JSON, isn't a JSON object, or has no ``sub`` string claim
        (or its value is empty).
    """
    if jwt is None:
        raise ValueError("JWT is null.")
    # Reject bytes / numbers up front so every bad input fails with the
    # documented ValueError instead of a TypeError from len() or split().
    if not isinstance(jwt, str):
        raise ValueError(f"JWT must be a string; got {type(jwt).__name__}.")
    if jwt == "":
        raise ValueError("JWT is empty.")
    if len(jwt) > _MAX_TOKEN_LENGTH_CHARS:
        raise ValueError(f"JWT length {len(jwt)} exceeds maximum allowed " f"({_MAX_TOKEN_LENGTH_CHARS}).")

    segments = jwt.split(".")
    if len(segments) != 3:
        raise ValueError(f"JWT must have exactly 3 '.'-separated segments; got {len(segments)}.")

    payload_bytes = _decode_base64url_segment(segments[1])
    try:
        payload_text = payload_bytes.decode("utf-8")
    except UnicodeDecodeError as exc:
        raise ValueError(f"JWT payload is not valid UTF-8: {exc}") from exc

    try:
        claims = json.loads(payload_text)
    except (json.JSONDecodeError, RecursionError) as exc:
        # RecursionError: pathologically nested payloads overflow the parser.
        raise ValueError(f"JWT payload is not valid JSON: {exc}") from exc

    if not isinstance(claims, dict):
        raise ValueError("JWT payload is not a JSON object.")

    # A missing claim and a non-string claim are reported identically.
    sub = claims.get("sub")
    if not isinstance(sub, str):
        raise ValueError("JWT payload is missing a 'sub' string claim.")
    if sub == "":
        raise ValueError("JWT 'sub' claim value is empty.")
    return sub


def _decode_base64url_segment(segment: str) -> bytes:
    """Decode one base64url JWT segment, with or without ``=`` padding.

    :raises ValueError: the segment is empty, has an impossible length, or
        contains characters outside the base64 alphabet.
    """
    if len(segment) == 0:
        raise ValueError("JWT payload segment is empty.")

    # A length of 1 mod 4 can never be produced by base64 encoding.
    padding = {0: "", 2: "==", 3: "="}.get(len(segment) % 4)
    if padding is None:
        raise ValueError("JWT payload segment has invalid base64url length.")

    standard = segment.translate(_B64URL_TO_STD) + padding
    try:
        return base64.b64decode(standard.encode("ascii"), validate=True)
    except (binascii.Error, UnicodeEncodeError, ValueError) as exc:
        raise ValueError(f"JWT payload segment is not valid base64url: {exc}") from exc
+ +The full grammar (whitespace-separated ``key=value`` pairs, no quoting): + + region= (required) + audience= (required) + duration_seconds=<60..3600> (default: 300) + signing_algorithm=ES384|RS256 (default: ES384) + sts_endpoint= (optional, FIPS / VPC) + principal_name= (optional, override JWT 'sub') + aws_debug=none|console (default: none) + tag_= (zero or more JWT custom claims, max 50) + +SASL extensions arrive separately via :data:`sasl_extensions` (parsed from +the typed ``sasl.oauthbearer.extensions`` config property). They are NOT +accepted inside this string under any ``extension_*`` prefix — that key +shape is rejected as an unknown key. +""" + +from dataclasses import dataclass +from typing import Dict, Optional + +from confluent_kafka._util.kv_string_parser import parse_kv + +__all__ = [ + "CONFIG_KEY", + "DEFAULT_SIGNING_ALGORITHM", + "ALLOWED_SIGNING_ALGORITHMS", + "MIN_DURATION_SECONDS", + "MAX_DURATION_SECONDS", + "DEFAULT_DURATION_SECONDS", + "TAG_KEY_PREFIX", + "MAX_TAGS", + "AWS_DEBUG_NONE", + "AWS_DEBUG_CONSOLE", + "ALLOWED_AWS_DEBUG_VALUES", + "AwsOAuthBearerConfig", +] + + +#: Config key carrying the AWS-path wire-grammar string. +CONFIG_KEY: str = "sasl.oauthbearer.config" + +#: Default JWT signing algorithm. +DEFAULT_SIGNING_ALGORITHM: str = "ES384" + +#: Signing algorithms accepted by AWS STS ``GetWebIdentityToken``. +ALLOWED_SIGNING_ALGORITHMS = ("ES384", "RS256") + +#: Minimum / default / maximum token lifetime AWS STS allows +MIN_DURATION_SECONDS: int = 60 +MAX_DURATION_SECONDS: int = 3600 +DEFAULT_DURATION_SECONDS: int = 300 + +#: Wire-grammar prefix for STS ``Tags`` entries (e.g. ``tag_team=platform``). +TAG_KEY_PREFIX: str = "tag_" + +#: AWS-enforced upper bound on number of tags per ``GetWebIdentityToken`` call. +MAX_TAGS: int = 50 + +#: Sentinel string values for ``aws_debug``. +AWS_DEBUG_NONE: str = "none" +AWS_DEBUG_CONSOLE: str = "console" + +#: ``aws_debug`` values accepted by the Python client: ``none`` and ``console``. 
ALLOWED_AWS_DEBUG_VALUES = (AWS_DEBUG_NONE, AWS_DEBUG_CONSOLE)


# Non-tag keys the wire grammar accepts. Any other key that does not start
# with the tag prefix is rejected as "Unknown key" by
# :meth:`AwsOAuthBearerConfig.parse`.
_RECOGNISED_KEYS = frozenset(
    (
        "region",
        "audience",
        "duration_seconds",
        "signing_algorithm",
        "sts_endpoint",
        "principal_name",
        "aws_debug",
    )
)

# Recognised keys whose value may not be the empty string (every recognised
# key except ``duration_seconds``, which fails integer parsing instead).
_NON_EMPTY_KEYS = _RECOGNISED_KEYS - {"duration_seconds"}


@dataclass(frozen=True)
class AwsOAuthBearerConfig:
    """Frozen, validated representation of the AWS ``sasl.oauthbearer.config``."""

    region: str
    audience: str
    signing_algorithm: str = DEFAULT_SIGNING_ALGORITHM
    duration_seconds: int = DEFAULT_DURATION_SECONDS
    sts_endpoint: Optional[str] = None
    principal_name: Optional[str] = None
    aws_debug: str = AWS_DEBUG_NONE
    tags: Optional[Dict[str, str]] = None
    sasl_extensions: Optional[Dict[str, str]] = None

    def __post_init__(self) -> None:
        """Validate the fully-populated instance.

        :raises ValueError: any field is empty, out of range, of the wrong
            type, or not one of its allowed values.
        """
        for required in ("region", "audience"):
            current = getattr(self, required)
            if not isinstance(current, str) or current == "":
                raise ValueError(f"{CONFIG_KEY} '{required}' must not be empty.")

        algorithm = self.signing_algorithm
        if algorithm not in ALLOWED_SIGNING_ALGORITHMS:
            raise ValueError(f"{CONFIG_KEY} 'signing_algorithm' must be 'ES384' or 'RS256'; " f"got {algorithm!r}.")

        duration = self.duration_seconds
        # bool subclasses int, so True/False would otherwise pass as 1/0.
        if isinstance(duration, bool) or not isinstance(duration, int):
            raise ValueError(f"{CONFIG_KEY} 'duration_seconds' must be an integer.")
        if duration < MIN_DURATION_SECONDS or duration > MAX_DURATION_SECONDS:
            raise ValueError(
                f"{CONFIG_KEY} 'duration_seconds' must be between "
                f"{MIN_DURATION_SECONDS} and {MAX_DURATION_SECONDS} inclusive; "
                f"got {duration}."
            )

        # Optional strings: None means "unset"; an explicit "" is an error.
        for optional in ("sts_endpoint", "principal_name"):
            if getattr(self, optional) == "":
                raise ValueError(f"{CONFIG_KEY} '{optional}' must not be empty.")

        if self.aws_debug not in ALLOWED_AWS_DEBUG_VALUES:
            raise ValueError(f"{CONFIG_KEY} 'aws_debug' must be one of: none, console. Got {self.aws_debug!r}.")

        tag_map = self.tags
        if tag_map is not None:
            if not isinstance(tag_map, dict):
                raise ValueError(f"{CONFIG_KEY} 'tags' must be a dict.")
            if len(tag_map) > MAX_TAGS:
                raise ValueError(f"{CONFIG_KEY} has {len(tag_map)} tags; AWS allows at " f"most {MAX_TAGS}.")

        extensions = self.sasl_extensions
        if not (extensions is None or isinstance(extensions, dict)):
            raise ValueError("sasl_extensions, if set, must be a dict.")

    @classmethod
    def parse(
        cls,
        raw: str,
        sasl_extensions: Optional[Dict[str, str]] = None,
    ) -> "AwsOAuthBearerConfig":
        """Build a config from the verbatim ``sasl.oauthbearer.config`` string.

        The string is a whitespace-separated list of ``key=value`` tokens.
        Accepted keys are the recognised scalar keys plus any key starting
        with the tag prefix; everything else raises :class:`ValueError`.
        When a key (or tag name) repeats, the last occurrence wins.

        :param raw: The verbatim ``sasl.oauthbearer.config`` string.
        :param sasl_extensions: Already-parsed SASL extensions; stored on the
            resulting config unchanged.
        :raises TypeError: ``raw`` is ``None``.
        :raises ValueError: grammar, range, or enum violations.
        """
        if raw is None:
            raise TypeError("raw must not be None")

        # Scalar settings seen so far; anything absent falls back to the
        # dataclass field default when the instance is constructed.
        scalars: Dict[str, object] = {}
        tags: Optional[Dict[str, str]] = None

        pairs = parse_kv(
            raw,
            separators=[" ", "\t", "\r", "\n"],
            context_label=CONFIG_KEY,
        )
        for key, value in pairs:
            if value == "" and key in _NON_EMPTY_KEYS:
                raise ValueError(f"{CONFIG_KEY} {key!r} must not be empty.")

            if key == "duration_seconds":
                try:
                    scalars[key] = int(value)
                except ValueError as exc:
                    raise ValueError(
                        f"{CONFIG_KEY} 'duration_seconds' must be an integer; " f"got {value!r}."
                    ) from exc
            elif key == "aws_debug":
                # Lower-cased so the allowed-values check is case-insensitive.
                scalars[key] = value.lower()
            elif key in _RECOGNISED_KEYS:
                scalars[key] = value
            elif key.startswith(TAG_KEY_PREFIX):
                tag_name = key[len(TAG_KEY_PREFIX) :]
                if tag_name == "":
                    raise ValueError(f"{CONFIG_KEY} tag key {key!r} has empty name.")
                if tags is None:
                    tags = {}
                tags[tag_name] = value
            else:
                raise ValueError(f"Unknown key {key!r} in {CONFIG_KEY}.")

        if "region" not in scalars:
            raise ValueError(f"'region' is required in {CONFIG_KEY}.")
        if "audience" not in scalars:
            raise ValueError(f"'audience' is required in {CONFIG_KEY}.")

        return cls(tags=tags, sasl_extensions=sasl_extensions, **scalars)
#: Config key carrying the SASL extensions list.
CONFIG_KEY: str = "sasl.oauthbearer.extensions"


def parse(raw: Optional[str]) -> Optional[Dict[str, str]]:
    """Turn the verbatim ``sasl.oauthbearer.extensions`` value into a dict.

    The value is a comma-separated list of ``key=value`` entries. When a key
    repeats, the last occurrence wins.

    :param raw: The configured string, or ``None`` when the property is unset.
    :returns: The parsed mapping, or ``None`` when ``raw`` is ``None`` / empty
        or contains no entries.
    :raises ValueError: A token is missing ``=`` or has an empty key.
    """
    if raw is None or raw == "":
        return None

    # dict() keeps the last value seen for a repeated key.
    extensions: Dict[str, str] = dict(parse_kv(raw, separators=[","], context_label=CONFIG_KEY))
    return extensions or None
import _aws_jwt_subject_extractor +from ._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AwsOAuthBearerConfig, +) + +__all__ = ["AwsStsTokenProvider"] + + +# Logger name targeted by ``aws_debug=console``. Routes botocore's HTTP / +# credential-chain / signing diagnostic logs to stderr at DEBUG level. +_BOTOCORE_LOGGER_NAME = "botocore" + +#: Minimum boto3 version required by the AWS IAM path. +#: See ``requirements/requirements-oauthbearer-aws.txt``. +MINIMUM_BOTO3_VERSION = "1.42.25" + + +def _version_tuple(version: str) -> Tuple[int, ...]: + """Parse a dotted version string into a tuple of its leading integers. + + Tolerant of pre-release suffixes (``"1.42.0rc1"`` -> ``(1, 42, 0)``) so the + comparison uses only each segment's leading digits (``0`` if there are none). + """ + parts = [] + for segment in version.split("."): + digits = "" + for ch in segment: + if not ch.isdigit(): + break + digits += ch + parts.append(int(digits) if digits else 0) + return tuple(parts) + + +def _require_boto3_version() -> None: + """Raise :class:`ImportError` if the installed boto3 predates + :data:`MINIMUM_BOTO3_VERSION`. + + """ + if _version_tuple(boto3.__version__) < _version_tuple(MINIMUM_BOTO3_VERSION): + raise ImportError( + f"The AWS IAM OAUTHBEARER path requires boto3>={MINIMUM_BOTO3_VERSION} " + f"(for the STS GetWebIdentityToken operation), but found boto3 " + f"{boto3.__version__}. Upgrade with: " + f"pip install -U 'confluent-kafka[oauthbearer-aws]'." + ) + + +class AwsStsTokenProvider: + """Mints OAUTHBEARER tokens via AWS STS ``GetWebIdentityToken``.""" + + def __init__( + self, + config: AwsOAuthBearerConfig, + sts_client: Optional[Any] = None, + ) -> None: + """Construct a provider bound to ``config``. + + :param config: Validated :class:`AwsOAuthBearerConfig` instance. + :param sts_client: Test seam — when supplied, the provider uses this + client directly instead of constructing a real boto3 STS client. + Production callers pass ``None``. 
+ :raises TypeError: ``config`` is ``None``. + :raises ImportError: the installed boto3 is older than + :data:`MINIMUM_BOTO3_VERSION` (checked only on the real-client + path, i.e. when ``sts_client`` is ``None``). + """ + if config is None: + raise TypeError("config must not be None") + self._cfg = config + + self._apply_aws_debug(config.aws_debug) + + if sts_client is not None: + self._sts = sts_client + else: + # Fail fast with a clear message if boto3 is present but too old for + # the STS GetWebIdentityToken operation. + _require_boto3_version() + # boto3.Session() with explicit region_name short-circuits the + # AWS_DEFAULT_REGION env lookup; client() still re-asserts the + # region for STS endpoint resolution. + session = boto3.Session(region_name=config.region) + client_kwargs: Dict[str, Any] = {"region_name": config.region} + if config.sts_endpoint: + client_kwargs["endpoint_url"] = config.sts_endpoint + self._sts = session.client("sts", **client_kwargs) + + @staticmethod + def _apply_aws_debug(aws_debug: str) -> None: + """Apply the ``aws_debug`` side-effect to botocore's logger. + + Process-wide effect, intentionally. When the user opts in with + ``aws_debug=console``, every boto3 client in the process gets + DEBUG-level stderr logs. ``aws_debug=none`` is a no-op so any + logging the user has configured elsewhere is preserved. + """ + if aws_debug == AWS_DEBUG_CONSOLE: + boto3.set_stream_logger(_BOTOCORE_LOGGER_NAME, logging.DEBUG) + # AWS_DEBUG_NONE → no-op. Other values are rejected by config validation. + + def token( + self, + oauthbearer_config: str = "", + ) -> Tuple[str, float, str, Dict[str, str]]: + """Mint a fresh JWT and return the ``oauth_cb`` 4-tuple. + + :param oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + string librdkafka passes back on every refresh. 
Accepted for + interface completeness but unused — the AWS path's fields are + sourced from the bound :class:`AwsOAuthBearerConfig` at + construction time, not re-parsed per refresh. + + :returns: 4-tuple ``(token, expiry_epoch_seconds, principal, + extensions)`` matching the C ``oauth_cb`` contract. + + :raises botocore.exceptions.ClientError: STS-side error + (``AccessDenied``, ``OutboundWebIdentityFederationDisabled``, + ...). The C ``oauth_cb`` wrapper converts raised exceptions + into ``rd_kafka_oauthbearer_set_token_failure``. + :raises ValueError: STS returned a malformed JWT or missing + ``Expiration``. + """ + request_kwargs: Dict[str, Any] = { + "Audience": [self._cfg.audience], + "SigningAlgorithm": self._cfg.signing_algorithm, + "DurationSeconds": self._cfg.duration_seconds, + } + if self._cfg.tags: + request_kwargs["Tags"] = [{"Key": k, "Value": v} for k, v in self._cfg.tags.items()] + + response = self._sts.get_web_identity_token(**request_kwargs) + + jwt = response.get("WebIdentityToken") + if not isinstance(jwt, str) or not jwt: + raise ValueError("STS response missing WebIdentityToken; cannot mint OAUTHBEARER token.") + + expiration = response.get("Expiration") + if expiration is None: + raise ValueError("STS response missing Expiration; cannot compute token lifetime.") + # boto3 normalises the timestamp to a tz-aware UTC datetime; + # .timestamp() returns epoch seconds as a float. + expiry_epoch_seconds = expiration.timestamp() + + principal = ( + self._cfg.principal_name + if self._cfg.principal_name is not None + else _aws_jwt_subject_extractor.extract_sub(jwt) + ) + + # Always return a dict for the extensions slot — the C oauth_cb + # wrapper's PyArg_ParseTuple uses "O!" with PyDict_Type for that slot, + # which would reject None. Empty dict is the Pythonic equivalent of + # .NET's null-Extensions case. 
+ extensions = dict(self._cfg.sasl_extensions) if self._cfg.sasl_extensions else {} + + return jwt, expiry_epoch_seconds, principal, extensions diff --git a/src/confluent_kafka/oauthbearer/aws/aws_autowire.py b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py new file mode 100644 index 000000000..c33046196 --- /dev/null +++ b/src/confluent_kafka/oauthbearer/aws/aws_autowire.py @@ -0,0 +1,93 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public entry-point for AWS IAM OAUTHBEARER autowire. + +This is the **only publicly importable name** in the optional subpackage. +End users do not call :func:`create_handler` directly — the C dispatcher in +``src/confluent_kafka/src/confluent_kafka.c`` reaches it via:: + + PyImport_ImportModule("confluent_kafka.oauthbearer.aws.aws_autowire") + +and resolves :func:`create_handler` by name. The marker-key check is +performed in core; :func:`create_handler` is invoked only when the C +dispatcher has decided to autowire the AWS path. + +User-facing contract — four config keys:: + + "sasl.oauthbearer.method": "oidc" + "sasl.oauthbearer.metadata.authentication.type": "aws_iam" + "sasl.oauthbearer.config": "region=... audience=..." + "sasl.oauthbearer.extensions": "key=val,..." 
# optional + +Frozen cross-module contract of :func:`create_handler`: + +* arity: 2 positional parameters +* names: ``sasl_oauthbearer_config``, ``sasl_oauthbearer_extensions`` +* types: ``str``, ``Optional[str]`` +* return: :data:`OAuthBearerCallback` + +Bumping any of these is a breaking change requiring a major version +increment on the ``confluent-kafka`` distribution. Test-guarded by +``tests/oauthbearer/aws/test_contract.py``. +""" + +from typing import Callable, Dict, Optional, Tuple + +from . import _aws_sasl_extensions_parser +from ._aws_iam_marker import AWS_IAM_MARKER_KEY, AWS_IAM_MARKER_VALUE +from ._aws_oauthbearer_config import CONFIG_KEY, AwsOAuthBearerConfig +from ._aws_sts_token_provider import AwsStsTokenProvider + +__all__ = ["create_handler", "OAuthBearerCallback"] + +OAuthBearerCallback = Callable[[str], Tuple[str, float, str, Dict[str, str]]] + + +def create_handler( + sasl_oauthbearer_config: str, + sasl_oauthbearer_extensions: Optional[str], +) -> OAuthBearerCallback: + """Build an OAUTHBEARER refresh callback from the two OAUTHBEARER config strings. + + :param sasl_oauthbearer_config: The verbatim ``sasl.oauthbearer.config`` + value (whitespace-separated ``key=value`` pairs). Must be non-empty. + :param sasl_oauthbearer_extensions: The verbatim + ``sasl.oauthbearer.extensions`` value (comma-separated ``key=value`` + pairs, RFC 7628 §3.1). May be ``None`` or empty when the user has + no extensions configured. + + :returns: A callable matching :data:`OAuthBearerCallback`. + + :raises ValueError: ``sasl_oauthbearer_config`` is ``None`` or empty; + the wire-grammar parse fails (unknown key, malformed token, missing + required field, range/enum violation, etc.). + :raises ImportError: the installed boto3 predates the minimum required by + the AWS IAM path (boto3 is present but too old for STS + ``GetWebIdentityToken``). + :raises RuntimeError: AWS SDK reachability or initialisation failure + (e.g. unknown region, malformed ``sts_endpoint``). 
+ """ + if not sasl_oauthbearer_config: + raise ValueError( + f"'{AWS_IAM_MARKER_KEY}={AWS_IAM_MARKER_VALUE}' is set but " + f"'{CONFIG_KEY}' is missing or empty. The AWS IAM autowire path " + f"requires region and audience to be supplied via " + f"{CONFIG_KEY} (e.g. \"region=us-east-1 audience=https://...\")." + ) + + sasl_extensions = _aws_sasl_extensions_parser.parse(sasl_oauthbearer_extensions) + config = AwsOAuthBearerConfig.parse(sasl_oauthbearer_config, sasl_extensions) + provider = AwsStsTokenProvider(config) + return provider.token diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 0a2c58495..9ae916019 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -2585,6 +2585,215 @@ static void common_conf_set_software(rd_kafka_conf_t *conf) { * Returns a conf object on success or NULL on failure in which case * an exception has been raised. */ +/** + * @brief Detect the aws_iam OAUTHBEARER autowire marker in the user's + * config dict and replace it with an oauth_cb callable sourced from + * the optional confluent_kafka.oauthbearer.aws subpackage. + * + * User contract (all three keys required when marker is set): + * sasl.oauthbearer.method = "oidc" + * sasl.oauthbearer.metadata.authentication.type = "aws_iam" + * sasl.oauthbearer.config = "region=... audience=..." + * + * Plus optional: + * sasl.oauthbearer.extensions = "key=val,key=val" + * + * IMPORTANT: These literals MUST match + * src/confluent_kafka/oauthbearer/aws/_aws_iam_marker.py. + * tests/oauthbearer/aws/test_aws_iam_marker.py guards against drift. + * + * Flow: + * 1. Skip if the marker key is absent or its value is not "aws_iam". + * (Other values like "azure_imds" flow through to librdkafka unchanged.) + * 2. If oauth_cb is already set, the explicit handler wins (precedence + * rule) — return immediately, leaving the marker in place. + * 3. Require sasl.oauthbearer.method == "oidc". 
The AWS IAM path runs as + * a high-level-client refresh callback inside librdkafka's OIDC + * subsystem (parallel to azure_imds); without method=oidc the + * configuration is rejected by design. + * 4. Require non-empty sasl.oauthbearer.config (carries region, audience, + * duration_seconds, etc.). + * 5. Lazy-import confluent_kafka.oauthbearer.aws.aws_autowire. If the + * optional 'oauthbearer-aws' extra isn't installed, the import chain + * bottoms out in a missing boto3; rewrite that ModuleNotFoundError + * into a friendly install hint while preserving the original via + * __cause__ (Python exception-chaining). + * 6. Call aws_autowire.create_handler(config_str, extensions_str_or_None). + * The returned Python callable replaces oauth_cb in confdict; the + * existing config-iteration loop below picks it up and registers the + * librdkafka refresh callback. + * 7. Leave the marker in confdict and return — it is passed to librdkafka + * unchanged. + * + * Native librdkafka handling (this REQUIRES an AWS-IAM-aware build; the + * bundled librdkafka floor is >= v2.14.2-aws-iam.2-dev — see + * MIN_RD_KAFKA_VERSION in confluent_kafka.h): + * - `aws_iam` is a recognized value for metadata.authentication.type + * (rdkafka_conf.c), so rd_kafka_conf_set() accepts the marker. + * - librdkafka bypasses the mandatory token.endpoint.url check for + * AWS_IAM, and the grant-type finalize runs only for type==NONE — so + * no sentinel OIDC fields are required. + * - librdkafka registers its own aws_iam refresh cb only when no + * token_refresh_cb is set (rdkafka.c). Since step 6 registered our + * oauth_cb, our callback owns the token fetch and librdkafka's native + * aws_iam stub never fires; the built-in OIDC fetcher is likewise + * skipped when a refresh cb is present (rdkafka_conf.c). 
+ * + * History: earlier versions stripped the marker and injected 3 sentinel + * OIDC fields (token.endpoint.url, client.id, client.secret) so the path + * also worked on stock librdkafka that didn't know `aws_iam`. That strip + * + sentinels was removed once the AWS-IAM-aware librdkafka became the + * bundled floor. Project memory: project_aws_iam_python_alignment.md #6. + * + * @returns 0 on success (no-op or autowire complete), -1 on error + * (PyErr_* is set; caller goto outer_err). + */ +static int resolve_aws_oauthbearer_marker(PyObject *confdict) { + static const char MARKER_KEY[] = + "sasl.oauthbearer.metadata.authentication.type"; + static const char MARKER_VALUE[] = "aws_iam"; + static const char METHOD_KEY[] = "sasl.oauthbearer.method"; + static const char METHOD_OIDC_VALUE[] = "oidc"; + static const char CONFIG_KEY[] = "sasl.oauthbearer.config"; + static const char EXTENSIONS_KEY[] = "sasl.oauthbearer.extensions"; + static const char OAUTH_CB_KEY[] = "oauth_cb"; + static const char AUTOWIRE_MODULE[] = + "confluent_kafka.oauthbearer.aws.aws_autowire"; + static const char CREATE_HANDLER[] = "create_handler"; + static const char METHOD_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' requires " + "'sasl.oauthbearer.method=oidc'. The AWS IAM path runs as a " + "high-level-client refresh callback inside librdkafka's OIDC " + "subsystem (parallel to azure_imds); without method=oidc the " + "configuration is rejected by design."; + static const char CONFIG_REQUIREMENT_ERR[] = + "'sasl.oauthbearer.metadata.authentication.type=aws_iam' is set " + "but 'sasl.oauthbearer.config' is missing or empty. The AWS IAM " + "autowire path requires region and audience to be supplied via " + "sasl.oauthbearer.config " + "(e.g. \"region=us-east-1 audience=https://...\")."; + static const char FRIENDLY_IMPORT_ERR[] = + "Config 'sasl.oauthbearer.metadata.authentication.type=aws_iam' " + "requires the optional 'oauthbearer-aws' extra. 
Install with:\n" + " pip install 'confluent-kafka[oauthbearer-aws]'"; + + PyObject *marker; + PyObject *cb; + PyObject *method; + PyObject *cfg_str; + PyObject *ext_str; + PyObject *mod; + PyObject *func; + PyObject *callback; + const char *marker_c; + const char *method_c; + + /* 1. Marker present and equals "aws_iam"? */ + marker = PyDict_GetItemString(confdict, MARKER_KEY); + if (!marker || !PyUnicode_Check(marker)) { + return 0; + } + marker_c = PyUnicode_AsUTF8(marker); + if (!marker_c) { + /* Not encodable as UTF-8 (e.g. lone surrogate) — treat as not-our-value. */ + PyErr_Clear(); + return 0; + } + if (strcmp(marker_c, MARKER_VALUE) != 0) { + return 0; + } + + /* 2. Explicit oauth_cb wins (precedence). Leave the marker in place and + * return — the AWS-IAM-aware librdkafka accepts `aws_iam` and uses the + * user's oauth_cb (it only registers its own native aws_iam refresh cb + * when no refresh cb is set; see rdkafka.c). Nothing else to do. */ + cb = PyDict_GetItemString(confdict, OAUTH_CB_KEY); + if (cb && cb != Py_None) { + return 0; + } + + /* 3. Require sasl.oauthbearer.method = "oidc". */ + method = PyDict_GetItemString(confdict, METHOD_KEY); + if (!method || !PyUnicode_Check(method)) { + PyErr_SetString(PyExc_ValueError, METHOD_REQUIREMENT_ERR); + return -1; + } + method_c = PyUnicode_AsUTF8(method); + if (!method_c || strcmp(method_c, METHOD_OIDC_VALUE) != 0) { + if (!method_c) { + PyErr_Clear(); + } + PyErr_SetString(PyExc_ValueError, METHOD_REQUIREMENT_ERR); + return -1; + } + + /* 4. Require non-empty sasl.oauthbearer.config. */ + cfg_str = PyDict_GetItemString(confdict, CONFIG_KEY); + if (!cfg_str || !PyUnicode_Check(cfg_str) || + PyUnicode_GET_LENGTH(cfg_str) == 0) { + PyErr_SetString(PyExc_ValueError, CONFIG_REQUIREMENT_ERR); + return -1; + } + + /* 5. sasl.oauthbearer.extensions is optional. */ + ext_str = PyDict_GetItemString(confdict, EXTENSIONS_KEY); + + /* 6. Lazy import; rewrite any import failure into a friendly ImportError. 
*/ + mod = PyImport_ImportModule(AUTOWIRE_MODULE); + if (!mod) { + PyObject *cause_type = NULL; + PyObject *cause_value = NULL; + PyObject *cause_tb = NULL; + PyObject *new_exc; + + PyErr_Fetch(&cause_type, &cause_value, &cause_tb); + PyErr_NormalizeException(&cause_type, &cause_value, &cause_tb); + + new_exc = PyObject_CallFunction( + PyExc_ImportError, "s", FRIENDLY_IMPORT_ERR); + if (new_exc) { + if (cause_value) { + Py_INCREF(cause_value); + PyException_SetCause(new_exc, cause_value); + } + PyErr_SetObject(PyExc_ImportError, new_exc); + Py_DECREF(new_exc); + } + Py_XDECREF(cause_type); + Py_XDECREF(cause_value); + Py_XDECREF(cause_tb); + return -1; + } + + /* 7. create_handler(cfg_str, ext_str or None) -> Callable */ + func = PyObject_GetAttrString(mod, CREATE_HANDLER); + Py_DECREF(mod); + if (!func) { + return -1; + } + callback = PyObject_CallFunction( + func, "OO", cfg_str, ext_str ? ext_str : Py_None); + Py_DECREF(func); + if (!callback) { + return -1; + } + if (PyDict_SetItemString(confdict, OAUTH_CB_KEY, callback) == -1) { + Py_DECREF(callback); + return -1; + } + Py_DECREF(callback); + + /* 8. Leave the marker in confdict — it is passed to librdkafka + * unchanged. The AWS-IAM-aware librdkafka recognizes `aws_iam`, + * bypasses the token.endpoint.url + grant-type requirements + * (rdkafka_conf.c), and — because step 7 registered our oauth_cb + * (token_refresh_cb) — uses our callback for the token rather than its + * native aws_iam stub (which only registers when no refresh cb is set; + * rdkafka.c). No marker strip, no sentinel OIDC fields needed. 
*/ + return 0; +} + + rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, Handle *h, PyObject *args, @@ -2703,6 +2912,15 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, PyDict_DelItemString(confdict, "default.topic.config"); } + /* AWS IAM OAUTHBEARER autowire: when the user sets the marker + * sasl.oauthbearer.metadata.authentication.type=aws_iam, register an + * oauth_cb sourced from the optional oauthbearer-aws extra; the marker + * itself stays in confdict. No-op when the marker is absent. See + * resolve_aws_oauthbearer_marker above for the full flow. */ + if (resolve_aws_oauthbearer_marker(confdict) == -1) { + goto outer_err; + } + /* Convert config dict to config key-value pairs. */ while (PyDict_Next(confdict, &pos, &ko, &vo)) { PyObject *ks; diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 0a4123b60..9c34f92be 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -38,24 +38,31 @@ /** * @brief confluent-kafka-python version, must match that of pyproject.toml. */ -#define CFL_VERSION_STR "2.14.2" +#define CFL_VERSION_STR "2.14.2.dev3" /** * Minimum required librdkafka version. This is checked both during * build-time (just below) and runtime (see confluent_kafka.c). * Make sure to keep the MIN_RD_KAFKA_VERSION, MIN_VER_ERRSTR and #error * defines and strings in sync. + * + * Floor is v2.14.2 to match the bundled AWS-IAM-aware librdkafka + * (v2.14.2-aws-iam.2-dev), which the OAUTHBEARER aws_iam autowire path + * requires (the dispatcher passes the `aws_iam` marker straight to + * librdkafka — see resolve_aws_oauthbearer_marker in confluent_kafka.c). + * Note: the version number alone cannot distinguish an aws_iam-aware build + * from stock 2.14.2; the AWS-IAM behavior is guaranteed by the bundled wheel. 
*/ -#define MIN_RD_KAFKA_VERSION 0x020e00ff +#define MIN_RD_KAFKA_VERSION 0x020e02ff #ifdef __APPLE__ #define MIN_VER_ERRSTR \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. " \ + "confluent-kafka-python requires librdkafka v2.14.2 or later. " \ "Install the latest version of librdkafka from Homebrew by running " \ "`brew install librdkafka` or `brew upgrade librdkafka`" #else #define MIN_VER_ERRSTR \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. " \ + "confluent-kafka-python requires librdkafka v2.14.2 or later. " \ "Install the latest version of librdkafka from the Confluent " \ "repositories, see http://docs.confluent.io/current/installation.html" #endif @@ -63,10 +70,10 @@ #if RD_KAFKA_VERSION < MIN_RD_KAFKA_VERSION #ifdef __APPLE__ #error \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. Install the latest version of librdkafka from Homebrew by running `brew install librdkafka` or `brew upgrade librdkafka`" + "confluent-kafka-python requires librdkafka v2.14.2 or later. Install the latest version of librdkafka from Homebrew by running `brew install librdkafka` or `brew upgrade librdkafka`" #else #error \ - "confluent-kafka-python requires librdkafka v2.14.0 or later. Install the latest version of librdkafka from the Confluent repositories, see http://docs.confluent.io/current/installation.html" + "confluent-kafka-python requires librdkafka v2.14.2 or later. Install the latest version of librdkafka from the Confluent repositories, see http://docs.confluent.io/current/installation.html" #endif #endif diff --git a/tests/_util/__init__.py b/tests/_util/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/_util/test_kv_string_parser.py b/tests/_util/test_kv_string_parser.py new file mode 100644 index 000000000..48700cc17 --- /dev/null +++ b/tests/_util/test_kv_string_parser.py @@ -0,0 +1,130 @@ +# Copyright 2026 Confluent Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka._util.kv_string_parser.""" + +import pytest + +from confluent_kafka._util.kv_string_parser import parse_kv + +# ---- Null guards ---- + + +def test_null_raw_raises(): + with pytest.raises(TypeError): + list(parse_kv(None, [","])) + + +def test_null_separators_raises(): + with pytest.raises(TypeError): + list(parse_kv("a=1", None)) + + +# ---- Empty / whitespace-only input ---- + + +def test_empty_raw_yields_nothing(): + assert list(parse_kv("", [","])) == [] + + +def test_whitespace_only_raw_yields_nothing(): + # After split + trim, no real tokens remain. 
+ assert list(parse_kv(" ", [","])) == [] + + +# ---- Single separator ---- + + +def test_single_separator_comma_basic_case(): + pairs = list(parse_kv("a=1,b=2", [","])) + assert pairs == [("a", "1"), ("b", "2")] + + +# ---- Multiple separators (whitespace-tolerant case) ---- + + +def test_multiple_separators_all_recognized(): + pairs = list(parse_kv("a=1\tb=2\nc=3 d=4", [" ", "\t", "\r", "\n"])) + assert [k for k, _ in pairs] == ["a", "b", "c", "d"] + + +# ---- Empty tokens skipped ---- + + +def test_consecutive_separators_skip_empty_tokens(): + pairs = list(parse_kv("a=1,,b=2,", [","])) + assert len(pairs) == 2 + assert pairs == [("a", "1"), ("b", "2")] + + +# ---- Trim toggle ---- + + +def test_trim_tokens_default_true_strips_leading_and_trailing(): + pairs = list(parse_kv(" a=1 , b=2 ", [","])) + assert pairs == [("a", "1"), ("b", "2")] + + +def test_trim_tokens_false_preserves_whitespace_inside_tokens(): + pairs = list(parse_kv(" a=1 ,b=2 ", [","], trim_tokens=False)) + assert pairs == [(" a", "1 "), ("b", "2 ")] + + +# ---- Malformed tokens ---- + + +def test_token_without_equals_raises(): + with pytest.raises(ValueError, match="Malformed.*bad"): + list(parse_kv("a=1,bad,c=3", [","])) + + +def test_token_starts_with_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + list(parse_kv("=value", [","])) + + +# ---- Value-content semantics ---- + + +def test_value_contains_equals_kept_verbatim(): + # Split is on FIRST '=' only — the rest stays in the value. + pairs = list(parse_kv("key=val=ue", [","])) + assert pairs == [("key", "val=ue")] + + +def test_empty_value_allowed(): + # RFC 7628 SASL extensions allow empty values; AWS allows empty tag values. 
+ pairs = list(parse_kv("key=", [","])) + assert pairs == [("key", "")] + + +# ---- Duplicate keys preserved (caller's responsibility to dedupe) ---- + + +def test_duplicate_keys_both_yielded(): + pairs = list(parse_kv("a=1,a=2", [","])) + assert pairs == [("a", "1"), ("a", "2")] + + +# ---- context_label weaves into error messages ---- + + +def test_context_label_present_appears_in_error_message(): + with pytest.raises(ValueError, match="my.config.key"): + list(parse_kv("bad", [","], context_label="my.config.key")) + + +def test_context_label_none_uses_generic_error_phrase(): + with pytest.raises(ValueError, match="key=value"): + list(parse_kv("bad", [","])) diff --git a/tests/oauthbearer/__init__.py b/tests/oauthbearer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/oauthbearer/aws/__init__.py b/tests/oauthbearer/aws/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/oauthbearer/aws/test_aws_autowire.py b/tests/oauthbearer/aws/test_aws_autowire.py new file mode 100644 index 000000000..6e82a5dbe --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_autowire.py @@ -0,0 +1,270 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws.aws_autowire.create_handler. + +End-to-end orchestration tests. The frozen-signature contract guard lives +in test_contract.py. 
+""" + +import pytest + +pytest.importorskip("boto3") + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 + +# ---- Input validation (defensive checks for direct callers) ---- + + +def test_create_handler_null_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler(None, None) + + +def test_create_handler_empty_config_raises(): + with pytest.raises(ValueError, match="missing or empty"): + create_handler("", None) + + +def test_create_handler_error_message_names_marker_and_config_key(): + with pytest.raises(ValueError) as exc_info: + create_handler(None, None) + # User should see both: the marker that's set AND the config key that's missing. + assert "sasl.oauthbearer.metadata.authentication.type" in str(exc_info.value) + assert "aws_iam" in str(exc_info.value) + assert "sasl.oauthbearer.config" in str(exc_info.value) + + +# ---- Parse delegation: errors from AwsOAuthBearerConfig.parse propagate ---- + + +def test_create_handler_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + create_handler("audience=https://a", None) + + +def test_create_handler_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + create_handler("region=us-east-1", None) + + +def test_create_handler_invalid_signing_algorithm_raises(): + with pytest.raises(ValueError, match="signing_algorithm"): + create_handler( + "region=us-east-1 audience=https://a signing_algorithm=HS256", + None, + ) + + +def test_create_handler_invalid_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds"): + create_handler( + "region=us-east-1 audience=https://a duration_seconds=10", + None, + ) + + +def test_create_handler_unknown_key_raises(): + with pytest.raises(ValueError, match="not_a_key"): + create_handler( + "region=us-east-1 audience=https://a not_a_key=foo", + None, + ) + + +def test_create_handler_invalid_extensions_grammar_raises(): + with 
pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + create_handler( + "region=us-east-1 audience=https://a", + "noEqualsHere", + ) + + +def test_create_handler_aws_debug_invalid_value_raises(): + """An unsupported aws_debug value is rejected — the client accepts none/console only.""" + with pytest.raises(ValueError, match="aws_debug.*none, console"): + create_handler( + "region=us-east-1 audience=https://a aws_debug=log4net", + None, + ) + + +# ---- Success cases — handler returned, no throw, no STS call yet ---- + + +def test_create_handler_marker_only_minimum_config_returns_handler(): + handler = create_handler( + "region=us-east-1 audience=https://a", + None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_all_optional_fields_returns_handler(): + handler = create_handler( + "region=us-east-1 audience=https://a " + "duration_seconds=900 signing_algorithm=RS256 " + "sts_endpoint=https://sts.us-east-1.amazonaws.com " + "principal_name=my-principal " + "aws_debug=none " + "tag_team=platform tag_environment=prod", + None, + ) + assert handler is not None + assert callable(handler) + + +def test_create_handler_tag_config_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a tag_team=platform tag_environment=prod", + None, + ) + assert callable(handler) + + +def test_create_handler_principal_name_override_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a principal_name=explicit-principal", + None, + ) + assert callable(handler) + + +# ---- Extensions argument handling ---- + + +def test_create_handler_null_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1 audience=https://a", + None, + ) + assert callable(handler) + + +def test_create_handler_empty_extensions_treats_as_absent(): + handler = create_handler( + "region=us-east-1 audience=https://a", + "", + ) + assert callable(handler) + + +def 
test_create_handler_single_extension_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc", + ) + assert callable(handler) + + +def test_create_handler_multiple_extensions_handler_ready(): + handler = create_handler( + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + assert callable(handler) + + +# ---- Returned callable invokes correctly with a real STS round-trip stubbed at the boto3 layer ---- +# We can't easily stub the boto3 client AwsStsTokenProvider creates internally without +# patching boto3.Session. End-to-end with-injected-client tests live in +# test_aws_sts_token_provider.py; here we cover the *closure* level (no STS call yet). + + +def test_create_handler_does_not_call_sts_at_construction(): + """create_handler must NOT make an STS call — credential resolution and + network I/O are deferred to the first invocation of the returned callable. + Verify by patching boto3.Session to detect any client.get_web_identity_token call. + """ + from unittest.mock import MagicMock, patch + + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_session.client.return_value = mock_client + create_handler("region=us-east-1 audience=https://a", None) + # Session/client construction OK; get_web_identity_token must not have been called. 
+ mock_client.get_web_identity_token.assert_not_called() + + +def test_create_handler_returned_callable_when_invoked_calls_sts(): + """When the returned callable is invoked, it triggers exactly one STS call.""" + import datetime + from unittest.mock import MagicMock, patch + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler("region=us-east-1 audience=https://a", None) + result = handler("ignored-passthrough-string") + + mock_client.get_web_identity_token.assert_called_once() + assert isinstance(result, tuple) + assert len(result) == 4 + token, expiry, principal, extensions = result + assert token == canned_response["WebIdentityToken"] + assert isinstance(expiry, float) + assert principal.startswith("arn:") + assert extensions == {} + + +def test_create_handler_returned_callable_round_trips_extensions(): + """Extensions configured via the typed property flow through to the + 4-tuple's extensions slot.""" + import datetime + from unittest.mock import MagicMock, patch + + canned_response = { + "WebIdentityToken": _canned_jwt(), + "Expiration": datetime.datetime(2099, 4, 21, 6, 6, 47, tzinfo=datetime.timezone.utc), + } + with patch("boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + mock_client = MagicMock() + mock_client.get_web_identity_token.return_value = canned_response + mock_session.client.return_value = mock_client + + handler = create_handler( + "region=us-east-1 audience=https://a", + "logicalCluster=lkc-abc,identityPoolId=pool-x", + ) + _, _, _, extensions = handler("") + + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + 
} + + +# ---- Test helpers ---- + + +def _base64url(data: bytes) -> str: + import base64 + + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = "arn:aws:iam::123:role/R") -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" diff --git a/tests/oauthbearer/aws/test_aws_iam_marker.py b/tests/oauthbearer/aws/test_aws_iam_marker.py new file mode 100644 index 000000000..079247619 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_iam_marker.py @@ -0,0 +1,48 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Drift-guard tests for the AWS IAM marker constants.""" + +from confluent_kafka.oauthbearer.aws._aws_iam_marker import ( + AWS_IAM_MARKER_KEY, + AWS_IAM_MARKER_VALUE, +) + + +def test_marker_key_is_locked_value(): + assert AWS_IAM_MARKER_KEY == "sasl.oauthbearer.metadata.authentication.type" + + +def test_marker_value_is_locked_value(): + assert AWS_IAM_MARKER_VALUE == "aws_iam" + + +def test_c_dispatcher_recognises_python_authoritative_marker(): + import pytest + + from confluent_kafka import Producer + + with pytest.raises(ValueError, match="method=oidc"): + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + AWS_IAM_MARKER_KEY: AWS_IAM_MARKER_VALUE, + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + # Deliberately omitting sasl.oauthbearer.method to trigger the + # dispatcher's precondition check. If the dispatcher's C-side + # marker literals differ from AWS_IAM_MARKER_KEY/VALUE, the + # dispatcher won't fire and this ValueError won't raise. + } + ) diff --git a/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py new file mode 100644 index 000000000..724fb3584 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_jwt_subject_extractor.py @@ -0,0 +1,238 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor.""" + +import base64 + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_jwt_subject_extractor import extract_sub + +# ---- Test helpers ---- + + +def _base64url_encode(data: bytes) -> str: + """Base64url-encodes a byte array: standard base64, trim '=' padding, + swap '+' → '-' and '/' → '_'. + """ + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _make_jwt(payload_json: str, alg: str = "ES384") -> str: + """Build a 3-segment JWT with the given payload JSON. + + Header and signature segments are placeholders — ``extract_sub`` only + decodes the payload (``parts[1]``). + """ + header = _base64url_encode(f'{{"alg":"{alg}","typ":"JWT"}}'.encode("utf-8")) + payload = _base64url_encode(payload_json.encode("utf-8")) + return f"{header}.{payload}.sig" + + +# ---- Real STS payloads (from .NET test suite — signature stripped, only payload matters) ---- + + +_EXPECTED_ROLE_ARN = "arn:aws:iam::708975691912:role/prashah-iam-sts-test-role" + +_REAL_ES384_JWT = ( + "eyJraWQiOiJFQzM4NF8wIiwidHlwIjoiSldUIiwiYWxnIjoiRVMzODQifQ" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + 
"GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjFkM2QzZmMyLTBl" + "NzktNDI2OS05NjcxLTJmODQ4NDYxOWZiNyJ9" + ".sig" +) + +_REAL_RS256_JWT = ( + "eyJraWQiOiJSU0FfMCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0" + ".eyJhdWQiOiJodHRwczovL2FwaS5leGFtcGxlLmNvbSIsInN1YiI6ImFybjphd3M6aWFtOjo3MDg5N" + "zU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtdGVzdC1yb2xlIiwiaHR0cHM6Ly9zdHMuYW1hem9u" + "YXdzLmNvbS8iOnsiZWMyX2luc3RhbmNlX3NvdXJjZV92cGMiOiJ2cGMtYWQ4NzMzYzQiLCJlYzJfcm9" + "sZV9kZWxpdmVyeSI6IjIuMCIsIm9yZ19pZCI6Im8tMHgzdDh1bW9seiIsImF3c19hY2NvdW50IjoiNz" + "A4OTc1NjkxOTEyIiwib3VfcGF0aCI6WyJvLTB4M3Q4dW1vbHovci16YzVqL291LXpjNWotNXJwd3pqb" + "HIvIl0sIm9yaWdpbmFsX3Nlc3Npb25fZXhwIjoiMjAyNi0wNS0xNFQyMzoxNjozMloiLCJzb3VyY2Vf" + "cmVnaW9uIjoiZXUtbm9ydGgtMSIsImVjMl9zb3VyY2VfaW5zdGFuY2VfYXJuIjoiYXJuOmF3czplYzI" + "6ZXUtbm9ydGgtMTo3MDg5NzU2OTE5MTI6aW5zdGFuY2UvaS0wOTc5MGY5OTY4YzExYTFjNyIsInByaW" + "5jaXBhbF9pZCI6ImFybjphd3M6aWFtOjo3MDg5NzU2OTE5MTI6cm9sZS9wcmFzaGFoLWlhbS1zdHMtd" + "GVzdC1yb2xlIiwicHJpbmNpcGFsX3RhZ3MiOnsiZGl2dnlfb3duZXIiOiJwcmFzaGFoQGNvbmZsdWVu" + "dC5pbyIsImRpdnZ5X2xhc3RfbW9kaWZpZWRfYnkiOiJwcmFzaGFoQGNvbmZsdWVudC5pbyJ9LCJlYzJ" + "faW5zdGFuY2Vfc291cmNlX3ByaXZhdGVfaXB2NCI6IjE3Mi4zMS4zLjEwOSJ9LCJpc3MiOiJodHRwcz" + "ovL2ExZWJjNzA1LWNkNGQtNDJiNC05M2I1LTk2ZTkzYWNmYjQzMS50b2tlbnMuc3RzLmdsb2JhbC5hc" + "GkuYXdzIiwiZXhwIjoxNzc4Nzc4NzY2LCJpYXQiOjE3Nzg3Nzg0NjYsImp0aSI6IjU5OTliZDc1LTBm" + "NTctNDMzMS1iOGExLWJkYzYwMjgxN2UyZiJ9" + ".sig" +) + + +# ---- Happy paths ---- + + +def test_extract_sub_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:iam::123456789012:role/MyRole","iat":1}') + assert extract_sub(jwt) == 
"arn:aws:iam::123456789012:role/MyRole" + + +def test_extract_sub_assumed_role_arn_returned(): + jwt = _make_jwt('{"sub":"arn:aws:sts::123456789012:assumed-role/MyRole/session-name"}') + assert extract_sub(jwt) == "arn:aws:sts::123456789012:assumed-role/MyRole/session-name" + + +def test_extract_sub_other_claims_ignored(): + jwt = _make_jwt('{"iss":"https://x","sub":"arn:aws:iam::1:role/R",' '"aud":"a","exp":1,"iat":0,"jti":"j"}') + assert extract_sub(jwt) == "arn:aws:iam::1:role/R" + + +@pytest.mark.parametrize("jwt", [_REAL_ES384_JWT, _REAL_RS256_JWT]) +def test_extract_sub_real_sts_jwt_returns_expected_arn(jwt): + assert extract_sub(jwt) == _EXPECTED_ROLE_ARN + + +# ---- Base64url padding branches (% 4 ∈ {0, 2, 3}) ---- + + +def test_extract_sub_unpadded_base64url_works(): + unpadded = _make_jwt('{"sub":"a"}') + assert extract_sub(unpadded) == "a" + + +def test_extract_sub_padded_base64url_also_works(): + """If user pre-pads with '=' (non-canonical but tolerated), still decodes.""" + header = _base64url_encode(b'{"alg":"none"}') + # Encode WITHOUT stripping padding. + payload = base64.b64encode(b'{"sub":"abc"}').decode("ascii").replace("+", "-").replace("/", "_") + jwt_3seg = f"{header}.{payload}.sig" + assert extract_sub(jwt_3seg) == "abc" + + +def test_extract_sub_url_safe_chars_handled(): + bytes_ = b'{"sub":"x"}' + normal = base64.b64encode(bytes_).decode("ascii") + url_safe = normal.replace("+", "-").replace("/", "_").rstrip("=") + jwt = "aGVhZGVy." 
+ url_safe + ".c2ln" + assert extract_sub(jwt) == "x" + + +@pytest.mark.parametrize( + "payload_json,expected_sub", + [ + ('{"sub":"a"}', "a"), # → encoded length % 4 = 3 (one '=' padded) + ('{"sub":"ab"}', "ab"), # → encoded length % 4 = 0 (no padding) + ('{"sub":"abc"}', "abc"), # → encoded length % 4 = 2 (two '==' padded) + ], +) +def test_extract_sub_padding_branches_all_hit_decodes_correctly(payload_json, expected_sub): + """Explicit per-branch coverage of the % 4 padding switch.""" + assert extract_sub(_make_jwt(payload_json)) == expected_sub + + +# ---- Top-level shape errors ---- + + +def test_extract_sub_null_raises(): + with pytest.raises(ValueError, match="null"): + extract_sub(None) + + +def test_extract_sub_empty_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("") + + +def test_extract_sub_one_segment_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("onlyonepart") + + +def test_extract_sub_two_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b") + + +def test_extract_sub_four_segments_raises(): + with pytest.raises(ValueError, match="3"): + extract_sub("a.b.c.d") + + +def test_extract_sub_oversized_input_raises(): + oversized = "a" * 8193 + with pytest.raises(ValueError, match="exceeds maximum"): + extract_sub(oversized) + + +def test_extract_sub_at_ceiling_reaches_parser(): + """8192-char input passes the size gate and fails downstream on segment count.""" + at_ceiling = "a" * 8192 + with pytest.raises(ValueError, match="3"): + extract_sub(at_ceiling) + + +def test_extract_sub_empty_payload_segment_raises(): + with pytest.raises(ValueError, match="empty"): + extract_sub("header..sig") + + +# ---- Payload-content errors ---- + + +def test_extract_sub_malformed_base64_in_payload_raises(): + with pytest.raises(ValueError, match="base64url"): + extract_sub("aGVhZGVy.not!base64.c2ln") + + +def test_extract_sub_malformed_json_in_payload_raises(): + bad_payload = 
_base64url_encode(b"not json") + with pytest.raises(ValueError, match="not valid JSON"): + extract_sub(f"aGVhZGVy.{bad_payload}.c2ln") + + +def test_extract_sub_payload_is_json_array_raises(): + array_payload = _base64url_encode(b'["not","an","object"]') + with pytest.raises(ValueError, match="not a JSON object"): + extract_sub(f"aGVhZGVy.{array_payload}.c2ln") + + +def test_extract_sub_missing_sub_claim_raises(): + jwt = _make_jwt('{"iss":"https://x","aud":"a"}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_number_raises(): + jwt = _make_jwt('{"sub":12345}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_null_raises(): + jwt = _make_jwt('{"sub":null}') + with pytest.raises(ValueError, match="'sub'"): + extract_sub(jwt) + + +def test_extract_sub_sub_claim_is_empty_string_raises(): + jwt = _make_jwt('{"sub":""}') + with pytest.raises(ValueError, match="empty"): + extract_sub(jwt) diff --git a/tests/oauthbearer/aws/test_aws_oauthbearer_config.py b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py new file mode 100644 index 000000000..afaaffbc1 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_oauthbearer_config.py @@ -0,0 +1,353 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for confluent_kafka.oauthbearer.aws._aws_oauthbearer_config.""" + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import ( + AWS_DEBUG_CONSOLE, + AWS_DEBUG_NONE, + AwsOAuthBearerConfig, +) + +# ---- Required fields ---- + + +def test_parse_minimal_required_populates_region_and_audience(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_missing_region_raises(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("audience=https://a") + + +def test_parse_missing_audience_raises(): + with pytest.raises(ValueError, match="audience.*required"): + AwsOAuthBearerConfig.parse("region=us-east-1") + + +def test_parse_null_input_raises(): + with pytest.raises(TypeError): + AwsOAuthBearerConfig.parse(None) + + +def test_parse_empty_input_raises_for_missing_region(): + with pytest.raises(ValueError, match="region.*required"): + AwsOAuthBearerConfig.parse("") + + +def test_parse_empty_value_on_required_key_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig.parse("region= audience=https://a") + + +@pytest.mark.parametrize("key", ["signing_algorithm", "sts_endpoint", "principal_name"]) +def test_parse_empty_value_on_optional_key_raises(key): + with pytest.raises(ValueError, match=key): + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a {key}=") + + +# ---- Defaults applied during parse ---- + + +def test_parse_no_signing_algorithm_defaults_to_es384(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.signing_algorithm == "ES384" + + +def test_parse_no_duration_defaults_to_300_seconds(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.duration_seconds == 300 + + +def test_parse_optional_fields_default_to_none(): + cfg = 
AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.sts_endpoint is None + assert cfg.principal_name is None + assert cfg.sasl_extensions is None + assert cfg.tags is None + + +# ---- duration_seconds ---- + + +@pytest.mark.parametrize("seconds", [60, 300, 3600]) +def test_parse_duration_in_range_accepted(seconds): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a duration_seconds={seconds}") + assert cfg.duration_seconds == seconds + + +@pytest.mark.parametrize("seconds", [0, 59, 3601, -10]) +def test_parse_duration_out_of_range_raises(seconds): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a duration_seconds={seconds}") + + +def test_parse_duration_not_integer_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a duration_seconds=abc") + + +# ---- signing_algorithm ---- + + +@pytest.mark.parametrize("alg", ["ES384", "RS256"]) +def test_parse_allowed_signing_algorithm_accepted(alg): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a signing_algorithm={alg}") + assert cfg.signing_algorithm == alg + + +@pytest.mark.parametrize("alg", ["HS256", "es384", "RS512"]) +def test_parse_disallowed_signing_algorithm_raises(alg): + with pytest.raises(ValueError, match="signing_algorithm.*ES384.*RS256"): + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a signing_algorithm={alg}") + + +# ---- aws_debug (none/console only) ---- + + +def test_parse_no_aws_debug_defaults_to_none(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + assert cfg.aws_debug == AWS_DEBUG_NONE + + +@pytest.mark.parametrize( + "value,expected", + [ + ("none", AWS_DEBUG_NONE), + ("console", AWS_DEBUG_CONSOLE), + ], +) +def test_parse_aws_debug_values_accepted(value, expected): + cfg = 
AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") + assert cfg.aws_debug == expected + + +@pytest.mark.parametrize( + "value,expected", + [ + ("Console", AWS_DEBUG_CONSOLE), + ("CONSOLE", AWS_DEBUG_CONSOLE), + ("NONE", AWS_DEBUG_NONE), + ], +) +def test_parse_aws_debug_case_insensitive_accepted(value, expected): + cfg = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") + assert cfg.aws_debug == expected + + +# Every unsupported aws_debug value is rejected with the same uniform error. +@pytest.mark.parametrize( + "value", + ["verbose", "etw", "debug", "true", "foo", "log4net", "systemdiagnostics", "Log4Net", "SystemDiagnostics"], +) +def test_parse_aws_debug_invalid_value_raises(value): + with pytest.raises(ValueError, match="aws_debug.*none, console"): + AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a aws_debug={value}") + + +def test_parse_aws_debug_empty_value_raises(): + with pytest.raises(ValueError, match="aws_debug.*must not be empty"): + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a aws_debug=") + + +# ---- sts_endpoint, principal_name ---- + + +def test_parse_sts_endpoint_stored_verbatim(): + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a " "sts_endpoint=https://sts-fips.us-east-1.amazonaws.com" + ) + assert cfg.sts_endpoint == "https://sts-fips.us-east-1.amazonaws.com" + + +def test_parse_principal_name_stored_verbatim(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a principal_name=my-principal") + assert cfg.principal_name == "my-principal" + + +# ---- sasl_extensions argument (typed property pass-through) ---- + + +def test_parse_extension_prefix_in_config_rejected_as_unknown_key(): + """sasl extensions are accepted via typed sasl.oauthbearer.extensions + property only; embedded extension_* keys are rejected.""" + with pytest.raises(ValueError, match="extension_logicalCluster"): + 
AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a extension_logicalCluster=lkc-abc") + + +def test_parse_sasl_extensions_arg_stored_on_config(): + ext = {"logicalCluster": "lkc-abc", "identityPoolId": "pool-xyz"} + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a", ext) + assert cfg.sasl_extensions is ext + + +def test_parse_sasl_extensions_arg_null_keeps_sasl_extensions_null(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a", None) + assert cfg.sasl_extensions is None + + +# ---- tag_ ---- + + +def test_parse_single_tag_collected_into_tags(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=platform") + assert cfg.tags == {"team": "platform"} + + +def test_parse_multiple_tags_all_collected(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a " "tag_team=platform " "tag_environment=prod") + assert cfg.tags == {"team": "platform", "environment": "prod"} + + +def test_parse_empty_tag_name_raises(): + with pytest.raises(ValueError, match="empty name"): + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_=value") + + +def test_parse_empty_tag_value_accepted(): + # AWS allows tag values of 0 chars; mirror that. 
+ cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=") + assert cfg.tags == {"team": ""} + + +def test_parse_duplicate_tag_name_last_wins(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=infra tag_team=platform") + assert cfg.tags == {"team": "platform"} + + +def test_parse_exactly_max_tags_accepted(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(50)) + cfg = AwsOAuthBearerConfig.parse(" ".join(parts)) + assert len(cfg.tags) == 50 + + +def test_parse_over_max_tags_raises(): + parts = ["region=us-east-1", "audience=https://a"] + parts.extend(f"tag_k{i}=v{i}" for i in range(51)) + with pytest.raises(ValueError, match="50"): + AwsOAuthBearerConfig.parse(" ".join(parts)) + + +# ---- Unknown keys ---- + + +def test_parse_unknown_key_raises(): + with pytest.raises(ValueError, match="Unknown key.*not_a_key"): + AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a not_a_key=foo") + + +# ---- Whitespace / ordering ---- + + +def test_parse_tabs_and_multiple_spaces_tolerated(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1\taudience=https://a duration_seconds=600") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + + +def test_parse_leading_and_trailing_whitespace_tolerated(): + cfg = AwsOAuthBearerConfig.parse(" region=us-east-1 audience=https://a ") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + + +def test_parse_order_invariant(): + cfg = AwsOAuthBearerConfig.parse("duration_seconds=600 audience=https://a region=us-east-1 signing_algorithm=RS256") + assert cfg.region == "us-east-1" + assert cfg.audience == "https://a" + assert cfg.duration_seconds == 600 + assert cfg.signing_algorithm == "RS256" + + +def test_parse_duplicate_key_last_wins(): + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a region=us-west-2") + assert cfg.region == 
"us-west-2" + + +# ---- Malformed entries ---- + + +def test_parse_no_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("region us-east-1 audience=https://a") + + +def test_parse_leading_equals_raises(): + with pytest.raises(ValueError, match="Malformed"): + AwsOAuthBearerConfig.parse("=value audience=https://a region=us-east-1") + + +# ---- Integration ---- + + +def test_parse_all_fields_together_all_populated_correctly(): + sasl_extensions = {"logicalCluster": "lkc-abc"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 " + "audience=https://confluent.cloud/oidc " + "duration_seconds=1800 " + "signing_algorithm=RS256 " + "sts_endpoint=https://sts.us-east-1.amazonaws.com " + "principal_name=test-principal " + "aws_debug=console " + "tag_team=platform", + sasl_extensions, + ) + + assert cfg.region == "us-east-1" + assert cfg.audience == "https://confluent.cloud/oidc" + assert cfg.duration_seconds == 1800 + assert cfg.signing_algorithm == "RS256" + assert cfg.sts_endpoint == "https://sts.us-east-1.amazonaws.com" + assert cfg.principal_name == "test-principal" + assert cfg.aws_debug == AWS_DEBUG_CONSOLE + assert cfg.sasl_extensions == {"logicalCluster": "lkc-abc"} + assert cfg.tags == {"team": "platform"} + + +# ---- Direct construction (bypassing parse) still validated ---- + + +def test_direct_construction_with_empty_region_raises(): + with pytest.raises(ValueError, match="region.*must not be empty"): + AwsOAuthBearerConfig(region="", audience="https://a") + + +def test_direct_construction_with_out_of_range_duration_raises(): + with pytest.raises(ValueError, match="duration_seconds.*must be between"): + AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=10) + + +def test_direct_construction_with_bool_duration_rejected(): + """bool is-a int in Python; reject explicitly so True/False doesn't pass.""" + with pytest.raises(ValueError, match="duration_seconds.*must be an integer"): + 
AwsOAuthBearerConfig(region="us-east-1", audience="https://a", duration_seconds=True) + + +# ---- Surface invariants ---- + + +def test_aws_oauth_bearer_config_not_exposed_via_subpackage_init(): + """Public surface stays minimal — AwsOAuthBearerConfig is private to the + autowire layer (only `create_handler` is publicly importable).""" + import confluent_kafka.oauthbearer.aws as aws_pkg + + assert not hasattr(aws_pkg, "AwsOAuthBearerConfig") diff --git a/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py new file mode 100644 index 000000000..40f72043e --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_sasl_extensions_parser.py @@ -0,0 +1,86 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser.""" + +import pytest + +from confluent_kafka.oauthbearer.aws._aws_sasl_extensions_parser import parse + +# ---- Null / empty input → None ---- + + +def test_null_raw_returns_none(): + assert parse(None) is None + + +def test_empty_string_returns_none(): + assert parse("") is None + + +# ---- Happy paths ---- + + +def test_single_entry_returns_one_item(): + result = parse("logicalCluster=lkc-abc") + assert result == {"logicalCluster": "lkc-abc"} + + +def test_multiple_entries_returns_all(): + result = parse("logicalCluster=lkc-abc,identityPoolId=pool-x") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_whitespace_around_commas_trimmed_and_parsed(): + # Each comma-delimited entry is trimmed before parsing. + result = parse(" logicalCluster=lkc-abc , identityPoolId=pool-x ") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_entries_tolerated(): + result = parse("logicalCluster=lkc-abc,,identityPoolId=pool-x,") + assert result == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-x", + } + + +def test_empty_value_accepted(): + # RFC 7628 SASL extensions allow empty values; mirror that. 
+ result = parse("logicalCluster=") + assert result == {"logicalCluster": ""} + + +def test_duplicate_key_last_wins(): + result = parse("k=a,k=b") + assert result == {"k": "b"} + + +# ---- Malformed input → throws ---- + + +def test_missing_equals_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("noEqualsHere") + + +def test_empty_key_raises(): + with pytest.raises(ValueError, match="sasl.oauthbearer.extensions"): + parse("=value") diff --git a/tests/oauthbearer/aws/test_aws_sts_token_provider.py b/tests/oauthbearer/aws/test_aws_sts_token_provider.py new file mode 100644 index 000000000..9953b0778 --- /dev/null +++ b/tests/oauthbearer/aws/test_aws_sts_token_provider.py @@ -0,0 +1,469 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for confluent_kafka.oauthbearer.aws._aws_sts_token_provider. + +- Python uses a tuple return (not a typed record) so the "no extensions → + null" assertion becomes "no extensions → empty dict" (the C oauth_cb + contract requires a dict, not None, at that slot). +- Python returns ``expiry_epoch_seconds`` (float seconds) not + ``LifetimeMs`` (long milliseconds) — the C wrapper multiplies by 1000. +""" + +import base64 +import datetime +import logging +from pathlib import Path +from typing import Any, Dict, Optional +from unittest.mock import patch + +import pytest + +# botocore is a transitive of boto3; available in opt-in venv. 
+pytest.importorskip("boto3") +pytest.importorskip("botocore") + +from botocore.exceptions import ClientError # noqa: E402 + +from confluent_kafka.oauthbearer.aws._aws_oauthbearer_config import AwsOAuthBearerConfig # noqa: E402 +from confluent_kafka.oauthbearer.aws._aws_sts_token_provider import ( # noqa: E402 + MINIMUM_BOTO3_VERSION, + AwsStsTokenProvider, + _require_boto3_version, + _version_tuple, +) + +# ---- Test helpers ---- + + +_ROLE_ARN = "arn:aws:iam::123:role/R" + + +def _base64url(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").rstrip("=").replace("+", "-").replace("/", "_") + + +def _canned_jwt(sub: str = _ROLE_ARN) -> str: + header = _base64url(b'{"alg":"ES384","typ":"JWT"}') + payload = _base64url(f'{{"sub":"{sub}"}}'.encode("utf-8")) + return f"{header}.{payload}.sig" + + +_CANNED_JWT = _canned_jwt() +_CANNED_EXPIRY = datetime.datetime( + 2099, + 4, + 21, + 6, + 6, + 47, + 641_000, + tzinfo=datetime.timezone.utc, +) + + +class FakeStsClient: + """Test double mirroring .NET's FakeStsClient. + + Records the last request kwargs and returns a canned response (or raises + a canned exception). Avoids the ceremony of ``botocore.stub.Stubber`` for + the simple cases here. 
def test_ctor_no_aws_debug_does_not_mutate_botocore_logger():
    """Without ``aws_debug`` in the config, construction must not enable boto3 logging."""
    cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a")
    # Patch boto3.set_stream_logger and assert it is never invoked: unless
    # aws_debug=console is requested, construction must leave the boto3 /
    # botocore logging setup alone. No logger level is captured here; the
    # mock's call record is the whole check.
    with patch("boto3.set_stream_logger") as mock_setter:
        AwsStsTokenProvider(cfg, sts_client=FakeStsClient())
        mock_setter.assert_not_called()
def test_token_tags_passthrough():
    """Every ``tag_<name>=<value>`` config entry shows up as a ``Tags`` item in the STS request."""
    sts = FakeStsClient()
    parsed = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a tag_team=platform tag_environment=prod")
    AwsStsTokenProvider(parsed, sts_client=sts).token()

    sent_tags = sts.last_request["Tags"]
    expected_tags = (
        {"Key": "team", "Value": "platform"},
        {"Key": "environment", "Value": "prod"},
    )
    # Exactly the two configured tags, in no particular order.
    assert len(sent_tags) == 2
    for expected in expected_tags:
        assert expected in sent_tags
"lkc-123", "identityPoolId": "pool-x"} + cfg = AwsOAuthBearerConfig.parse( + "region=us-east-1 audience=https://a", + sasl_extensions, + ) + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {"logicalCluster": "lkc-123", "identityPoolId": "pool-x"} + + +def test_token_no_extensions_configured_returns_empty_dict(): + """Python deviation from .NET: empty dict, not None, because the C + oauth_cb contract uses PyArg_ParseTuple "O!" with PyDict_Type at that + slot (rejects None). Empty dict is the Pythonic equivalent of the + .NET null-extensions case.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, _, _, extensions = provider.token() + assert extensions == {} + assert isinstance(extensions, dict) + + +def test_token_expiry_is_epoch_seconds_float(): + """The C wrapper expects expiry as epoch seconds (float). It multiplies + by 1000 internally to get milliseconds for librdkafka.""" + fake = FakeStsClient() + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + _, expiry, _, _ = provider.token() + assert isinstance(expiry, float) + assert expiry == _CANNED_EXPIRY.timestamp() + + +# ---- token(): error propagation ---- + + +def test_token_missing_expiration_raises(): + """STS response without Expiration → ValueError surfaced as + rd_kafka_oauthbearer_set_token_failure by the C wrapper.""" + fake = FakeStsClient(responder=lambda kwargs: {"WebIdentityToken": _CANNED_JWT}) + cfg = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a") + provider = AwsStsTokenProvider(cfg, sts_client=fake) + with pytest.raises(ValueError, match="Expiration"): + provider.token() + + +def test_token_missing_token_value_raises(): + fake = FakeStsClient(responder=lambda kwargs: {"Expiration": _CANNED_EXPIRY}) + cfg = 
def test_token_sts_access_denied_propagates():
    """An ``AccessDenied`` ClientError raised by STS surfaces from ``token()`` with its code intact."""
    denied = ClientError(
        error_response={
            "Error": {
                "Code": "AccessDenied",
                "Message": "User is not authorized to perform: sts:GetWebIdentityToken",
            }
        },
        operation_name="GetWebIdentityToken",
    )
    parsed = AwsOAuthBearerConfig.parse("region=us-east-1 audience=https://a")
    provider = AwsStsTokenProvider(parsed, sts_client=FakeStsClient(raises=denied))

    with pytest.raises(ClientError) as caught:
        provider.token()

    error_code = caught.value.response["Error"]["Code"]
    assert error_code == "AccessDenied"
def test_ctor_sts_endpoint_plumbed_to_boto3_client():
    """A configured ``sts_endpoint`` is forwarded to ``boto3.Session().client('sts', ...)``
    as the ``endpoint_url`` keyword argument."""
    fips_endpoint = "https://sts-fips.us-east-1.amazonaws.com"
    parsed = AwsOAuthBearerConfig.parse(f"region=us-east-1 audience=https://a sts_endpoint={fips_endpoint}")
    with patch("boto3.Session") as session_factory:
        # No sts_client is injected, so the provider builds one through the
        # patched Session; inspect how that client was requested.
        AwsStsTokenProvider(parsed)
        session_factory.return_value.client.assert_called_once_with(
            "sts",
            region_name="us-east-1",
            endpoint_url=fips_endpoint,
        )
def test_minimum_boto3_version_matches_requirements_file():
    """Guard against MINIMUM_BOTO3_VERSION drifting from the pinned floor in
    requirements/requirements-oauthbearer-aws.txt.

    The floor is read from the first requirement line for the ``boto3``
    distribution that carries a ``>=`` specifier. Only the version itself is
    captured, so an upper bound, environment marker or trailing comment on the
    same line does not leak into the comparison.
    """
    import re  # function-local: this test is the only user of re in this block

    req = Path(__file__).resolve().parents[3] / "requirements" / "requirements-oauthbearer-aws.txt"
    # Anchor on the exact distribution name (optionally with extras) so that
    # look-alikes such as "boto3-stubs>=..." are not mistaken for the pin.
    floor_pattern = re.compile(r"^boto3(?:\[[^\]]*\])?\s*>=\s*([0-9][0-9A-Za-z.]*)")
    floor = None
    for line in req.read_text(encoding="utf-8").splitlines():
        match = floor_pattern.match(line.strip())
        if match:
            floor = match.group(1)
            break
    # Report a missing pin as such rather than as a "boto3>=None" mismatch.
    assert floor is not None, f"no 'boto3>=' requirement found in {req}"
    assert floor == MINIMUM_BOTO3_VERSION, (
        f"requirements-oauthbearer-aws.txt pins boto3>={floor} but "
        f"MINIMUM_BOTO3_VERSION is {MINIMUM_BOTO3_VERSION}; keep them in sync."
    )
def test_create_handler_no_default_values():
    """Neither parameter may carry a default — the C dispatcher always
    passes two arguments (extensions may be the empty string or None, but
    is always supplied)."""
    params = inspect.signature(create_handler).parameters
    for required in ("sasl_oauthbearer_config", "sasl_oauthbearer_extensions"):
        assert params[required].default is inspect.Parameter.empty
def test_create_handler_docstring_present():
    """Guard the documentation of the frozen cross-module contract.

    Two docstrings are checked: ``create_handler`` must have one at all, and
    the ``aws_autowire`` module docstring must mention that the contract is
    "frozen". Together they block a future contributor from accidentally
    stripping that documentation.
    """
    assert create_handler.__doc__ is not None
    # The "frozen" wording is looked up on the *module* docstring, so it needs
    # its own None guard: otherwise a removed module docstring surfaces as an
    # AttributeError on None.lower() instead of a readable assertion failure.
    module_doc = aws_autowire.__doc__
    assert module_doc is not None, "aws_autowire module docstring is missing"
    assert "frozen" in module_doc.lower()
+ assert "sasl.oauthbearer.config" in doc + assert "sasl.oauthbearer.extensions" in doc diff --git a/tests/oauthbearer/aws/test_dispatch.py b/tests/oauthbearer/aws/test_dispatch.py new file mode 100644 index 000000000..f4cbce242 --- /dev/null +++ b/tests/oauthbearer/aws/test_dispatch.py @@ -0,0 +1,465 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the C-extension dispatcher resolve_aws_oauthbearer_marker(). + +The dispatcher fires inside common_conf_setup() (src/confluent_kafka.c) on +every Producer / Consumer / AdminClient construction (and transitively +AIOProducer / AIOConsumer, which wrap the sync clients). + +These tests exercise the C dispatcher indirectly via client construction: +no direct path to the helper exists from Python — that's the design. 
def _base64url(data: bytes) -> str:
    """Encode *data* as unpadded base64url text, the form used for JWT segments.

    ``base64.urlsafe_b64encode`` already emits ``-`` and ``_`` in place of
    ``+`` and ``/``, so only the trailing ``=`` padding has to be removed.

    Args:
        data: Raw bytes to encode.

    Returns:
        The base64url encoding of *data* as ASCII text, without padding.
    """
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def test_marker_absent_does_not_import_aws_modules(monkeypatch):
    """If the marker is absent the dispatcher must NOT touch the optional
    subpackage — boto3 import must not happen.

    Args:
        monkeypatch: pytest fixture used to evict the cached AWS submodules
            from ``sys.modules`` for the duration of this test only.
    """
    # Clear any prior imports of the AWS subpackage so a fresh import is
    # detectable. monkeypatch.delitem (rather than a bare ``del``) restores the
    # evicted module objects on teardown, so code that imported these modules
    # earlier keeps seeing the same objects in sys.modules after this test.
    for name in list(sys.modules):
        if name.startswith("confluent_kafka.oauthbearer.aws"):
            monkeypatch.delitem(sys.modules, name, raising=False)
    confluent_kafka.Producer({"bootstrap.servers": "localhost:9092"})
    for name in (
        "confluent_kafka.oauthbearer.aws.aws_autowire",
        "confluent_kafka.oauthbearer.aws._aws_sts_token_provider",
    ):
        assert name not in sys.modules, (
            f"Dispatcher imported {name} when no marker was set — this means the no-op short-circuit broke."
        )
def test_marker_with_method_oidc_uppercase_raises():
    """Upper-case ``OIDC`` is rejected: the method value is matched strictly,
    and librdkafka's method values are lowercase canonical."""
    uppercase_method_conf = {
        "bootstrap.servers": "broker.invalid:9092",
        "sasl.mechanisms": "OAUTHBEARER",
        "sasl.oauthbearer.method": "OIDC",
        "sasl.oauthbearer.metadata.authentication.type": "aws_iam",
        "sasl.oauthbearer.config": "region=us-east-1 audience=https://a",
    }
    with pytest.raises(ValueError, match="method=oidc"):
        confluent_kafka.Producer(uppercase_method_conf)
def test_marker_with_method_oidc_but_empty_config_raises():
    """An empty ``sasl.oauthbearer.config`` is reported the same way as a missing one."""
    empty_config_conf = {
        "bootstrap.servers": "broker.invalid:9092",
        "sasl.mechanisms": "OAUTHBEARER",
        "sasl.oauthbearer.method": "oidc",
        "sasl.oauthbearer.metadata.authentication.type": "aws_iam",
        "sasl.oauthbearer.config": "",
    }
    with pytest.raises(ValueError, match="sasl.oauthbearer.config.*missing or empty"):
        confluent_kafka.Producer(empty_config_conf)
def test_explicit_oauth_cb_wins_over_marker(monkeypatch):
    """When the user supplies their own oauth_cb, the dispatcher leaves the
    marker in place and yields — the explicit handler wins and boto3 is NOT
    touched. The AWS-IAM-aware librdkafka accepts the marker and uses the
    user's oauth_cb (its native aws_iam cb only registers when no refresh cb
    is set).

    Args:
        monkeypatch: pytest fixture used to evict the cached AWS submodules
            from ``sys.modules`` for the duration of this test only.
    """
    sentinel_called = []

    def user_oauth_cb(config_str):
        # Records each invocation; this test does not assert on the record.
        sentinel_called.append(config_str)
        return ("user-token", 9999999999.0, "user-principal", {})

    # Evict the AWS submodules so we can verify aws_autowire was NOT imported.
    # monkeypatch.delitem (rather than a bare ``del``) restores the evicted
    # module objects on teardown, so later tests see an unchanged sys.modules.
    for name in list(sys.modules):
        if name.startswith("confluent_kafka.oauthbearer.aws"):
            monkeypatch.delitem(sys.modules, name, raising=False)

    p = confluent_kafka.Producer(
        {
            "bootstrap.servers": "broker.invalid:9092",
            "security.protocol": "SASL_SSL",
            "sasl.mechanisms": "OAUTHBEARER",
            "sasl.oauthbearer.method": "oidc",
            "sasl.oauthbearer.metadata.authentication.type": "aws_iam",
            "sasl.oauthbearer.config": "region=us-east-1 audience=https://a",
            "oauth_cb": user_oauth_cb,
        }
    )
    assert p is not None
    p.flush(timeout=0.1)
    # The autowire module must not have been imported.
    assert "confluent_kafka.oauthbearer.aws.aws_autowire" not in sys.modules
@pytest.fixture
def boto3_absent(monkeypatch):
    """Simulate an environment without the optional extra installed.

    ``import boto3`` is made to fail and the cached ``aws.*`` submodules are
    dropped from ``sys.modules``, so the C dispatcher's
    PyImport_ImportModule re-executes the module body and hits the
    boto3=None gate."""
    # A None entry forces any subsequent `import boto3` to raise ImportError.
    monkeypatch.setitem(sys.modules, "boto3", None)
    # Snapshot the names first, then evict each cached aws submodule so its
    # top-level statements (including `import boto3`) run again on import.
    cached_names = tuple(sys.modules)
    for module_name in cached_names:
        if not module_name.startswith("confluent_kafka.oauthbearer.aws"):
            continue
        monkeypatch.delitem(sys.modules, module_name, raising=False)
    yield
    # Nothing to undo by hand: monkeypatch reverts on teardown.
+ assert exc_info.value.__cause__ is not None + + +def test_friendly_import_error_on_consumer_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + confluent_kafka.Consumer( + { + "bootstrap.servers": "broker.invalid:9092", + "group.id": "g", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) + + +def test_friendly_import_error_on_admin_client_too(boto3_absent): + with pytest.raises(ImportError, match="oauthbearer-aws"): + AdminClient( + { + "bootstrap.servers": "broker.invalid:9092", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": "region=us-east-1 audience=https://a", + } + ) + + +# ---- 9. Marker is passed through to the AWS-IAM-aware librdkafka ---- + + +def test_marker_passed_through_to_librdkafka(mocked_boto3): + """The dispatcher leaves the 'aws_iam' marker in confdict; it is passed to + librdkafka unchanged. The bundled AWS-IAM-aware librdkafka recognizes the + 'aws_iam' enum value and accepts it at rd_kafka_conf_set time (bypassing + the token.endpoint.url + grant-type requirements), so the Producer + constructs cleanly. This also pins the requirement that the linked + librdkafka be AWS-IAM-aware: against stock librdkafka the marker would be + rejected with an 'invalid value' error here.""" + p = confluent_kafka.Producer(_minimal_aws_iam_config()) + assert p is not None + p.flush(timeout=0.1) + + +# ---- 10. Extensions plumbing through the dispatcher ---- + + +def test_marker_with_extensions_plumbs_through(mocked_boto3): + """The dispatcher reads sasl.oauthbearer.extensions and passes it as the + second arg to create_handler. 
Only checks that the Producer constructs + and flushes cleanly; the returned 4-tuple is not inspected here.""" + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "logicalCluster=lkc-123", + } + ) + ) + assert p is not None + p.flush(timeout=0.1) + + +def test_marker_with_empty_extensions_treated_as_absent(mocked_boto3): + """Empty extensions string → autowire treats as None (no extensions).""" + p = confluent_kafka.Producer( + _minimal_aws_iam_config( + { + "sasl.oauthbearer.extensions": "", + } + ) + ) + assert p is not None + p.flush(timeout=0.1) diff --git a/tests/oauthbearer/aws/test_real_sts.py b/tests/oauthbearer/aws/test_real_sts.py new file mode 100644 index 000000000..465713bb6 --- /dev/null +++ b/tests/oauthbearer/aws/test_real_sts.py @@ -0,0 +1,387 @@ +# Copyright 2026 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Phase 6 — real-AWS integration tests for the autowire path. + +Gated on ``RUN_AWS_STS_REAL=1``. Default ``pytest tests/oauthbearer/aws`` +runs every other Phase 1-5 case but skips this file's tests so CI doesn't +hit the AWS API. + +Why this file lives under ``tests/oauthbearer/aws/`` rather than +``tests/integration/``: ``tests/integration/conftest.py`` eagerly imports +``trivup.clusters.KafkaCluster``, which is a heavyweight broker-only +dependency. The STS provider has no broker dependency — the env-var gate +is sufficient isolation. Same gate-placement call as the 29-April POC.
+ +Cross-language invariant: on the shared EC2 test box +(``ktrue-iam-sts-test-role`` in ``eu-north-1``, account ``708975691912``, +audience ``https://api.example.com``) Go / .NET / JS / librdkafka all +mint a **1256-byte JWT** with principal +``arn:aws:iam::708975691912:role/ktrue-iam-sts-test-role``. The Python +port preserves that invariant. + +Run instructions:: + + # On EC2 with role attached, audience-trust-policy enabled: + export RUN_AWS_STS_REAL=1 + export AWS_STS_TEST_REGION=eu-north-1 # AWS_REGION also accepted (fallback) + export AWS_STS_TEST_AUDIENCE=https://api.example.com + /tmp/ckp-optin/bin/pytest -v -s tests/oauthbearer/aws/test_real_sts.py + +The ``-s`` flag preserves the diagnostic ``print(...)`` lines that capture +the JWT length, principal, and timing — important for cross-language +parity evidence in PR descriptions. +""" + +import base64 +import json +import os +import re +import time + +import pytest + +# boto3 is loaded transitively via the autowire path. Skip the whole file +# in opt-out venvs so collection doesn't blow up. +pytest.importorskip("boto3") + +# Gate the entire module behind RUN_AWS_STS_REAL=1 so CI doesn't accidentally +# hit STS. Individual @pytest.mark.skipif decorators would work too but the +# module-level pytestmark is cleaner and less error-prone. +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_AWS_STS_REAL") != "1", + reason="Set RUN_AWS_STS_REAL=1 and provide AWS credentials to run.", +) + +from confluent_kafka.oauthbearer.aws.aws_autowire import create_handler # noqa: E402 + +# ============================================================================= +# Configuration sourced from env vars so the test can be re-pointed at a +# different role / audience without code changes. 
+# ============================================================================= + +_REGION = os.environ.get("AWS_STS_TEST_REGION") or os.environ.get("AWS_REGION", "eu-north-1") +_AUDIENCE = os.environ.get("AWS_STS_TEST_AUDIENCE", "https://api.example.com") +_DURATION = os.environ.get("DURATION_SECONDS", "300") + +# Cross-language invariant: 1256 bytes on the shared test role + audience. +# Allow a tolerance band for variation across audiences (URL length affects +# payload size). Catches order-of-magnitude bugs without flaking on tiny +# audience-string changes. +_JWT_LENGTH_MIN = int(os.environ.get("JWT_LENGTH_MIN", "1100")) +_JWT_LENGTH_MAX = int(os.environ.get("JWT_LENGTH_MAX", "1500")) + +_JWT_PATTERN = re.compile(r"^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+$") +_ARN_PATTERN = re.compile(r"^arn:aws:(?:iam|sts)::\d+:[^/]+/.+$") + + +def _config_string(**parts) -> str: + """Render a sasl.oauthbearer.config wire string from key=value parts.""" + return " ".join(f"{k}={v}" for k, v in parts.items()) + + +def _default_config(**extra) -> str: + """Build a sasl.oauthbearer.config string with the defaults plus extras.""" + parts = {"region": _REGION, "audience": _AUDIENCE, "duration_seconds": _DURATION} + parts.update(extra) + return _config_string(**parts) + + +def _decode_jwt_payload(jwt: str) -> dict: + """Decode the JWT payload (middle segment) into a Python dict. + + Helper for tests that need to inspect specific claims (tags, audience, + algorithm) end-to-end. 
+ """ + payload_b64url = jwt.split(".")[1] + # base64url → base64 with padding + padding = "=" * (-len(payload_b64url) % 4) + standard_b64 = (payload_b64url + padding).replace("-", "+").replace("_", "/") + return json.loads(base64.b64decode(standard_b64).decode("utf-8")) + + +# ============================================================================= +# create_handler() — the cross-module entry-point contract +# ============================================================================= + + +def test_create_handler_mints_valid_jwt(): + handler = create_handler(_default_config(), None) + token, expiry, principal, extensions = handler("") + + # 3-segment base64url JWT. + assert _JWT_PATTERN.match(token), f"not a 3-segment JWT: {token[:60]}..." + + # Length within cross-language tolerance. + assert _JWT_LENGTH_MIN <= len(token) <= _JWT_LENGTH_MAX, ( + f"JWT length {len(token)} outside [{_JWT_LENGTH_MIN}, " + f"{_JWT_LENGTH_MAX}] (Go/.NET/JS/librdkafka observe 1256 on the shared " + f"test box). Set JWT_LENGTH_MIN/MAX env vars if running on a different " + f"role/audience." + ) + + # Expiry within the requested window. + now = time.time() + assert expiry > now, "expiry must be in the future" + assert expiry < now + 10 * 60, "expiry must be within 10 min for DurationSeconds <= 300" + + # Principal is a bare role ARN (AWS STS GetWebIdentityToken convention). + assert _ARN_PATTERN.match(principal), f"unexpected principal: {principal!r}" + + # Default config has no extensions. + assert extensions == {} + + # Diagnostic for the run log so EC2 evidence is captured in CI / PR. + print(f"\n[autowire-real] jwt_length={len(token)} " f"principal={principal} expires_in={int(expiry - now)}s") + + +def test_create_handler_jwt_length_matches_cross_language(): + """Cross-language byte-exact JWT length parity check. + + Different roles produce different JWT lengths because AWS STS bakes + role-attached ``principal_tags`` into the payload (e.g. 
provisioning- + system tags like ``divvy_owner`` add ~70 bytes). The 1256-byte length + seen on Go/.NET/JS/librdkafka in earlier validation was specific to + one tag-less role. + + This test only runs when ``EXPECTED_TOKEN_LEN`` is explicitly set — + its purpose is to confirm byte-for-byte parity against a baseline you've + measured by running any sibling client (Go / .NET / JS) against the + same role + audience. + """ + expected_str = os.environ.get("EXPECTED_TOKEN_LEN") + if not expected_str: + pytest.skip( + "EXPECTED_TOKEN_LEN not set — cross-language byte-exact parity " + "check only runs when you supply a known-good baseline. Run a " + "sibling client (Go/.NET/JS) against the same role+audience, " + "note the JWT length, and set EXPECTED_TOKEN_LEN to that value." + ) + + expected = int(expected_str) + handler = create_handler(_default_config(), None) + token, *_ = handler("") + assert len(token) == expected, f"expected {expected} bytes (cross-language match), got {len(token)}" + + +def test_create_handler_principal_matches_role_arn(): + """JWT 'sub' claim is the bare role ARN. AWS STS GetWebIdentityToken + convention (no session prefix on bare role tokens).""" + handler = create_handler(_default_config(), None) + _, _, principal, _ = handler("") + assert principal.startswith("arn:aws:iam::"), f"not an IAM role ARN: {principal!r}" + assert ":role/" in principal, f"not a role ARN: {principal!r}" + + +def test_create_handler_repeats_mint_distinct_tokens(): + """Provider does not cache; STS mints a fresh JWT each invocation.""" + handler = create_handler(_default_config(), None) + _, exp1, _, _ = handler("") + time.sleep(1.0) + _, exp2, _, _ = handler("") + assert exp2 > exp1, "second token's expiry must be later than the first" + + +def test_create_handler_honours_duration_seconds(): + """Custom duration_seconds=60 produces an expiry ~1 min out, not the 300s + default. 
+ + The shared test role's IAM policy caps DurationSeconds at 300 (per .NET + PR validation §2g and the JS hybrid plan). So we go *below* the default + rather than above to prove the parameter is honored end-to-end. Same + value JS uses for the same reason. + """ + handler = create_handler(_default_config(duration_seconds="60"), None) + _, expiry, _, _ = handler("") + seconds_remaining = expiry - time.time() + assert 50 <= seconds_remaining <= 80, f"duration_seconds=60 → expected ~60s window, got {seconds_remaining:.0f}s" + + +def test_create_handler_honours_signing_algorithm_rs256(): + """signing_algorithm=RS256 produces a JWT with alg=RS256 in the header. + + Cross-language parity check: the .NET test asserts the same header.alg + round-trip. Catches any SDK-side algorithm-shaping surprises. + """ + handler = create_handler(_default_config(signing_algorithm="RS256"), None) + token, *_ = handler("") + # Decode the JWT header (first segment). + header_b64url = token.split(".")[0] + padding = "=" * (-len(header_b64url) % 4) + header_json = json.loads( + base64.b64decode((header_b64url + padding).replace("-", "+").replace("_", "/")).decode("utf-8") + ) + assert header_json.get("alg") == "RS256", f"header.alg expected RS256, got {header_json.get('alg')!r}" + + +def test_create_handler_honours_principal_name_override(): + """Setting principal_name=... in the wire grammar overrides the JWT 'sub' + extraction at the autowire layer.""" + handler = create_handler( + _default_config(principal_name="custom-principal"), + None, + ) + _, _, principal, _ = handler("") + assert principal == "custom-principal", f"principal_name override not honoured: got {principal!r}" + + +def test_create_handler_round_trips_sasl_extensions(): + """The typed sasl.oauthbearer.extensions property is parsed separately + (comma-separated) and surfaces in the returned extensions dict. + + Note: extensions are forwarded to the broker — not to STS — so they + don't appear in the JWT. 
This test asserts the dispatcher-to-tuple + round-trip, not the JWT payload.""" + handler = create_handler( + _default_config(), + "logicalCluster=lkc-abc,identityPoolId=pool-xyz", + ) + _, _, _, extensions = handler("") + assert extensions == { + "logicalCluster": "lkc-abc", + "identityPoolId": "pool-xyz", + } + + +def test_create_handler_tag_claims_flow_to_sts(): + """``tag_KEY=VALUE`` in the wire grammar flows into the STS Tags + parameter without breaking the request. + + The primary evidence is **token minting succeeding** — if our Tags + parameter caused STS to reject the request (e.g. malformed shape or + over the 50-tag cap), we'd see a ClientError exception. Token returned + cleanly means dispatcher → provider → boto3 → AWS STS handoff worked. + + Whether AWS STS surfaces our custom tags in the JWT payload's + ``principal_tags`` claim is a SEPARATE question depending on the role's + IAM trust policy — specifically whether ``sts:TagSession`` is allowed. + Without that permission, STS silently drops user tags from the JWT. + The diagnostic prints below show what STS actually embedded so users + can spot the case where tags are silently dropped. + + Set ``ASSERT_TAGS_IN_JWT=1`` to require that our tags appear in the + JWT — only meaningful on roles with ``sts:TagSession`` allowed. + """ + handler = create_handler( + _default_config(tag_team="platform", tag_environment="prod"), + None, + ) + # If our Tags param breaks the STS call, this raises. + token, *_ = handler("") + payload = _decode_jwt_payload(token) + + # AWS STS surfaces tags either under "principal_tags" at the JWT root + # OR nested under "https://sts.amazonaws.com/" → "principal_tags".
+ found_tags = {} + if "principal_tags" in payload: + found_tags.update(payload["principal_tags"]) + for nested_value in payload.values(): + if isinstance(nested_value, dict) and "principal_tags" in nested_value: + found_tags.update(nested_value["principal_tags"]) + + our_tags_present = "team" in found_tags and "environment" in found_tags + print(f"\n[autowire-real] JWT payload keys: {list(payload.keys())}") + print(f"[autowire-real] tags surfaced in JWT: {found_tags}") + print( + f"[autowire-real] our injected tags appear in JWT: {our_tags_present}" + + ("" if our_tags_present else " (role likely missing sts:TagSession permission)") + ) + + # Token minted successfully — Tags param didn't cause STS rejection. + assert len(token) > 100, "Token should be a valid JWT" + + # Strict assertion only when the user explicitly opts in via env var. + if os.environ.get("ASSERT_TAGS_IN_JWT") == "1": + assert our_tags_present, ( + f"ASSERT_TAGS_IN_JWT=1 set but our tags aren't in the JWT. " + f"Role likely missing sts:TagSession permission. " + f"Found tags: {found_tags}" + ) + + +def test_create_handler_jwt_audience_matches_request(): + """The JWT payload's 'aud' claim matches the audience we asked AWS for. + + Defensive — catches any audience-shaping surprises (AWS should not munge + the audience string). Mirrors the .NET integration test's assertion. 
+ """ + handler = create_handler(_default_config(), None) + token, *_ = handler("") + payload = _decode_jwt_payload(token) + assert payload.get("aud") == _AUDIENCE, f"JWT aud {payload.get('aud')!r} doesn't match requested {_AUDIENCE!r}" + + +# ============================================================================= +# Full dispatcher flow — Producer with marker +# ============================================================================= + + +def test_producer_with_marker_succeeds_against_real_sts(): + """End-to-end: construct a Producer with the marker set; the C dispatcher + in common_conf_setup() invokes create_handler() which builds a real + provider; librdkafka's background thread invokes the autowired callback; + STS mints a real token; rd_kafka_oauthbearer_set_token marks the token + as set; wait_for_oauth_token_set returns successfully and the Producer + constructor returns. + + The bootstrap.servers points at a non-resolvable host so we don't waste + a real broker connection — the OAUTHBEARER refresh path doesn't need + one. It fires immediately on background-thread startup. + """ + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + } + ) + elapsed = time.time() - t0 + + # Producer construction returns once the autowire callback has set a + # real token (typically <2s on EC2). wait_for_oauth_token_set timeout + # is 10s — anything over that means the autowire chain didn't fire. 
+ assert elapsed < 5.0, ( + f"Producer construction took {elapsed:.1f}s — autowire path slow " + f"or broken (10s timeout = OAuth refresh never set the token)" + ) + print(f"\n[autowire-real] Producer constructed via dispatcher in {elapsed:.2f}s") + + +def test_producer_with_marker_and_extensions_succeeds(): + """Same end-to-end path with sasl.oauthbearer.extensions set on the + typed property. Proves the C dispatcher reads the extensions string + and passes it as the second arg to create_handler.""" + from confluent_kafka import Producer + + t0 = time.time() + Producer( + { + "bootstrap.servers": "broker.invalid:9092", + "security.protocol": "SASL_SSL", + "sasl.mechanisms": "OAUTHBEARER", + "sasl.oauthbearer.method": "oidc", + "sasl.oauthbearer.metadata.authentication.type": "aws_iam", + "sasl.oauthbearer.config": _default_config(), + "sasl.oauthbearer.extensions": "logicalCluster=lkc-abc", + } + ) + elapsed = time.time() - t0 + assert elapsed < 5.0, f"Producer with extensions took {elapsed:.1f}s — autowire path " f"slow or broken" + print(f"\n[autowire-real] Producer (with extensions) constructed via " f"dispatcher in {elapsed:.2f}s") diff --git a/tests/soak/setup_all_versions.py b/tests/soak/setup_all_versions.py index 61c4dcc11..b9b322102 100755 --- a/tests/soak/setup_all_versions.py +++ b/tests/soak/setup_all_versions.py @@ -5,7 +5,7 @@ PYTHON_SOAK_TEST_BRANCH = 'master' LIBRDKAFKA_VERSIONS = [ - '2.14.2', + '2.14.2-aws-iam.2-dev', '2.13.2', '2.13.0', '2.12.1', @@ -22,7 +22,7 @@ ] PYTHON_VERSIONS = [ - '2.14.2', + '2.14.2.dev3', '2.13.2', '2.13.0', '2.12.1', diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index 919d3d26a..e18c918d7 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -71,7 +71,7 @@ if [[ $OS_NAME == linux && $ARCH == x64 ]]; then python3 -c " import importlib.metadata, sys extras = set(importlib.metadata.metadata('confluent-kafka').get_all('Provides-Extra') or 
[]) -required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules'} +required = {'schema-registry', 'schemaregistry', 'avro', 'json', 'protobuf', 'rules', 'oauthbearer-aws'} missing = required - extras if missing: print(f'Failing: package does not provide extras: {missing}', file=sys.stderr)