diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index e21dfd23c..014f208ad 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -788,13 +788,23 @@ def _get_message_desc_proto( ) -> Tuple[str, descriptor_pb2.DescriptorProto]: index = msg_index[0] if isinstance(desc, descriptor_pb2.FileDescriptorProto): - msg = desc.message_type[index] + messages = desc.message_type + if index < 0 or index >= len(messages): + raise SerializationError( + "message index {} out of range, schema has {} top-level message(s)".format(index, len(messages)) + ) + msg = messages[index] path = path + "." + msg.name if path else msg.name if len(msg_index) == 1: return path, msg return self._get_message_desc_proto(path, msg, msg_index[1:]) else: - msg = desc.nested_type[index] + messages = desc.nested_type + if index < 0 or index >= len(messages): + raise SerializationError( + "message index {} out of range, message has {} nested message(s)".format(index, len(messages)) + ) + msg = messages[index] path = path + "." + msg.name if path else msg.name if len(msg_index) == 1: return path, msg diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 0528dffeb..a663288a7 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -777,13 +777,23 @@ def _get_message_desc_proto( ) -> Tuple[str, descriptor_pb2.DescriptorProto]: index = msg_index[0] if isinstance(desc, descriptor_pb2.FileDescriptorProto): - msg = desc.message_type[index] + messages = desc.message_type + if index < 0 or index >= len(messages): + raise SerializationError( + "message index {} out of range, schema has {} top-level message(s)".format(index, len(messages)) + ) + msg = messages[index] path = path + "." + msg.name if path else msg.name if len(msg_index) == 1: return path, msg return self._get_message_desc_proto(path, msg, msg_index[1:]) else: - msg = desc.nested_type[index] + messages = desc.nested_type + if index < 0 or index >= len(messages): + raise SerializationError( + "message index {} out of range, message has {} nested message(s)".format(index, len(messages)) + ) + msg = messages[index] path = path + "." + msg.name if path else msg.name if len(msg_index) == 1: return path, msg diff --git a/tests/schema_registry/_async/test_proto.py b/tests/schema_registry/_async/test_proto.py index b9ea6a6cd..3a7fe175c 100644 --- a/tests/schema_registry/_async/test_proto.py +++ b/tests/schema_registry/_async/test_proto.py @@ -20,14 +20,17 @@ from io import BytesIO import pytest +from google.protobuf import descriptor_pb2 from confluent_kafka.schema_registry.protobuf import ( + AsyncProtobufDeserializer, AsyncProtobufSerializer, _create_index_array, decimal_to_protobuf, protobuf_to_decimal, ) from confluent_kafka.schema_registry.serde import SchemaId +from confluent_kafka.serialization import SerializationError from tests.integration.schema_registry.data.proto import DependencyTestProto_pb2, metadata_proto_pb2 @@ -48,6 +51,40 @@ def test_create_index(pb2, coordinates): assert msg_idx == coordinates +def _two_message_file_proto(): + fdp = descriptor_pb2.FileDescriptorProto() + fdp.name = "test.proto" + fdp.package = "pkg" + first = fdp.message_type.add() + first.name = "First" + nested = first.nested_type.add() + nested.name = "Inner" + second = fdp.message_type.add() + second.name = "Second" + return fdp + + +def test_message_index_in_range(): + deserializer = object.__new__(AsyncProtobufDeserializer) + fdp = _two_message_file_proto() + + assert deserializer._get_message_desc_proto("", fdp, [0])[0] == "First" + assert deserializer._get_message_desc_proto("", fdp, [1])[0] == "Second" + assert deserializer._get_message_desc_proto("", fdp, [0, 0])[0] == "First.Inner" + + +@pytest.mark.parametrize("msg_index", [[-1], [2], [0, -1], [0, 5]]) +def test_message_index_out_of_range(msg_index): + # The message index array is attacker-controlled wire framing; a zigzag + # varint can decode to a negative or out-of-range value. A negative index + # would otherwise wrap around and resolve to a different message type. + deserializer = object.__new__(AsyncProtobufDeserializer) + fdp = _two_message_file_proto() + + with pytest.raises(SerializationError, match="out of range"): + deserializer._get_message_desc_proto("", fdp, msg_index) + + @pytest.mark.parametrize( "pb2", [ diff --git a/tests/schema_registry/_sync/test_proto.py b/tests/schema_registry/_sync/test_proto.py index d8817756f..c9b932987 100644 --- a/tests/schema_registry/_sync/test_proto.py +++ b/tests/schema_registry/_sync/test_proto.py @@ -20,14 +20,17 @@ from io import BytesIO import pytest +from google.protobuf import descriptor_pb2 from confluent_kafka.schema_registry.protobuf import ( + ProtobufDeserializer, ProtobufSerializer, _create_index_array, decimal_to_protobuf, protobuf_to_decimal, ) from confluent_kafka.schema_registry.serde import SchemaId +from confluent_kafka.serialization import SerializationError from tests.integration.schema_registry.data.proto import DependencyTestProto_pb2, metadata_proto_pb2 @@ -48,6 +51,40 @@ def test_create_index(pb2, coordinates): assert msg_idx == coordinates +def _two_message_file_proto(): + fdp = descriptor_pb2.FileDescriptorProto() + fdp.name = "test.proto" + fdp.package = "pkg" + first = fdp.message_type.add() + first.name = "First" + nested = first.nested_type.add() + nested.name = "Inner" + second = fdp.message_type.add() + second.name = "Second" + return fdp + + +def test_message_index_in_range(): + deserializer = object.__new__(ProtobufDeserializer) + fdp = _two_message_file_proto() + + assert deserializer._get_message_desc_proto("", fdp, [0])[0] == "First" + assert deserializer._get_message_desc_proto("", fdp, [1])[0] == "Second" + assert deserializer._get_message_desc_proto("", fdp, [0, 0])[0] == "First.Inner" + + +@pytest.mark.parametrize("msg_index", [[-1], [2], [0, -1], [0, 5]]) +def test_message_index_out_of_range(msg_index): + # The message index array is attacker-controlled wire framing; a zigzag + # varint can decode to a negative or out-of-range value. A negative index + # would otherwise wrap around and resolve to a different message type. + deserializer = object.__new__(ProtobufDeserializer) + fdp = _two_message_file_proto() + + with pytest.raises(SerializationError, match="out of range"): + deserializer._get_message_desc_proto("", fdp, msg_index) + + @pytest.mark.parametrize( "pb2", [