diff --git a/src/brpc/policy/mysql_auth/mysql_auth_handshake.cpp b/src/brpc/policy/mysql_auth/mysql_auth_handshake.cpp new file mode 100644 index 0000000000..1385ae1df9 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_handshake.cpp @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/policy/mysql_auth/mysql_auth_handshake.h" + +#include + +#include "brpc/policy/mysql_auth/mysql_auth_packet.h" +#include "brpc/policy/mysql_auth/mysql_auth_scramble.h" + +namespace brpc { +namespace policy { +namespace mysql_auth { + +namespace { + +// MySQL HandshakeV10 fixed-size pieces and constants. +const size_t kAuthPluginDataPart1Len = 8; +const size_t kReservedAfterCapsLen = 10; +const size_t kFillerAfterPart1Len = 1; +const size_t kReservedInResponseLen = 23; + +// Reads N little-endian bytes from |buf| at |off| into |out|. +template +bool ReadLE(const butil::StringPiece& buf, size_t off, size_t n, T* out) { + if (off + n > buf.size()) return false; + T v = 0; + for (size_t i = 0; i < n; ++i) { + v |= static_cast(static_cast(buf[off + i])) << (8 * i); + } + *out = v; + return true; +} + +template +void WriteLE(T value, size_t n, std::string* out) { + for (size_t i = 0; i < n; ++i) { + out->push_back(static_cast((value >> (8 * i)) & 0xff)); + } +} + +} // namespace + +bool ParseHandshakeV10(const butil::StringPiece& payload, HandshakeV10* out) { + if (payload.empty()) return false; + + size_t off = 0; + out->protocol_version = static_cast(payload[off++]); + if (out->protocol_version != kHandshakeV10Tag) { + return false; + } + + // server_version: NUL-terminated string + std::string version; + { + const butil::StringPiece rest(payload.data() + off, + payload.size() - off); + const size_t consumed = DecodeNullTerminatedString(rest, &version); + if (consumed == 0) return false; + off += consumed; + } + out->server_version = std::move(version); + + // connection_id: 4 LE bytes + if (!ReadLE(payload, off, 4, &out->connection_id)) return false; + off += 4; + + // auth-plugin-data-part-1: 8 bytes + if (off + kAuthPluginDataPart1Len > payload.size()) return false; + std::string salt(payload.data() + off, kAuthPluginDataPart1Len); + off += kAuthPluginDataPart1Len; + + // filler 0x00 + if (off + kFillerAfterPart1Len > payload.size()) return false; + off += kFillerAfterPart1Len; + + // capability flags (lower 2 bytes) + uint16_t caps_lo = 0; + if (!ReadLE(payload, off, 2, &caps_lo)) return false; + off += 2; + out->capability_flags = caps_lo; + + if (off == payload.size()) { + // Pre-4.1 server. We don't support these — bail. + return false; + } + + // character_set + if (off >= payload.size()) return false; + out->character_set = static_cast(payload[off++]); + + // status_flags + if (!ReadLE(payload, off, 2, &out->status_flags)) return false; + off += 2; + + // capability flags upper 2 bytes + uint16_t caps_hi = 0; + if (!ReadLE(payload, off, 2, &caps_hi)) return false; + off += 2; + out->capability_flags |= static_cast(caps_hi) << 16; + + // length of auth-plugin-data (or 0x00 when CLIENT_PLUGIN_AUTH is absent) + if (off >= payload.size()) return false; + const uint8_t apd_total_len = static_cast(payload[off++]); + + // 10 reserved bytes (all 0x00) + if (off + kReservedAfterCapsLen > payload.size()) return false; + off += kReservedAfterCapsLen; + + if (out->capability_flags & CLIENT_SECURE_CONNECTION) { + // auth-plugin-data-part-2: max(13, apd_total_len - 8) bytes. Modern + // servers send 13 (12 salt bytes + 1 NUL filler). + const size_t part2_len = apd_total_len > kAuthPluginDataPart1Len + ? static_cast(apd_total_len) - kAuthPluginDataPart1Len + : static_cast(13); + const size_t want = part2_len < 13 ? 13 : part2_len; + if (off + want > payload.size()) return false; + // Concat salt parts; trim trailing NUL filler so callers see the + // raw 20-byte salt. + salt.append(payload.data() + off, want); + off += want; + if (!salt.empty() && salt.back() == '\0') { + salt.pop_back(); + } + } + if (salt.size() != kSaltLen) { + return false; + } + out->auth_plugin_data = std::move(salt); + + if (out->capability_flags & CLIENT_PLUGIN_AUTH) { + std::string name; + const butil::StringPiece rest(payload.data() + off, + payload.size() - off); + const size_t consumed = DecodeNullTerminatedString(rest, &name); + // Some servers omit the trailing NUL; tolerate by treating the + // remainder of the payload as the plugin name. + if (consumed == 0) { + out->auth_plugin_name.assign(rest.data(), rest.size()); + } else { + out->auth_plugin_name = std::move(name); + } + } + + return true; +} + +void BuildHandshakeResponse41(const HandshakeResponse41& req, std::string* out) { + WriteLE(req.capability_flags, 4, out); + WriteLE(req.max_packet_size, 4, out); + out->push_back(static_cast(req.character_set)); + out->append(kReservedInResponseLen, '\0'); + out->append(req.username); + out->push_back('\0'); + + if (req.capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) { + EncodeLengthEncodedString(req.auth_response, out); + } else if (req.capability_flags & CLIENT_SECURE_CONNECTION) { + // Auth response length must fit in one byte under this scheme. + // Callers using payloads >255 bytes (e.g., RSA ciphertext) must + // set CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA instead. + const uint8_t len = static_cast( + req.auth_response.size() > 0xff ? 0xff : req.auth_response.size()); + out->push_back(static_cast(len)); + out->append(req.auth_response.data(), len); + } else { + out->append(req.auth_response); + out->push_back('\0'); + } + + if (req.capability_flags & CLIENT_CONNECT_WITH_DB) { + out->append(req.database); + out->push_back('\0'); + } + + if (req.capability_flags & CLIENT_PLUGIN_AUTH) { + out->append(req.auth_plugin_name); + out->push_back('\0'); + } +} + +bool ParseAuthSwitchRequest(const butil::StringPiece& payload, + AuthSwitchRequest* out) { + if (payload.empty() || + static_cast(payload[0]) != kAuthSwitchRequestTag) { + return false; + } + size_t off = 1; + std::string name; + const butil::StringPiece rest(payload.data() + off, payload.size() - off); + const size_t consumed = DecodeNullTerminatedString(rest, &name); + if (consumed == 0) return false; + off += consumed; + out->auth_plugin_name = std::move(name); + + // Remainder is auth-plugin-data; trim a single trailing NUL filler. + out->auth_plugin_data.assign(payload.data() + off, payload.size() - off); + if (!out->auth_plugin_data.empty() && out->auth_plugin_data.back() == '\0') { + out->auth_plugin_data.pop_back(); + } + return true; +} + +bool ParseAuthMoreData(const butil::StringPiece& payload, AuthMoreData* out) { + if (payload.empty() || + static_cast(payload[0]) != kAuthMoreDataTag) { + return false; + } + out->data.assign(payload.data() + 1, payload.size() - 1); + return true; +} + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc diff --git a/src/brpc/policy/mysql_auth/mysql_auth_handshake.h b/src/brpc/policy/mysql_auth/mysql_auth_handshake.h new file mode 100644 index 0000000000..c57fe93a50 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_handshake.h @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Codec for the four MySQL connection-phase packets the client touches +// during authentication. All functions operate on raw packet payloads +// (without the 4-byte packet header); the caller is responsible for +// framing. Specifications: +// HandshakeV10: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// HandshakeResponse41: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html +// AuthSwitchRequest / AuthMoreData: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_more_data.html + +#ifndef BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H +#define BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H + +#include + +#include + +#include "butil/strings/string_piece.h" + +namespace brpc { +namespace policy { +namespace mysql_auth { + +// Subset of MySQL capability flags we recognize. +enum CapabilityFlag : uint32_t { + CLIENT_LONG_PASSWORD = 0x00000001, + CLIENT_LONG_FLAG = 0x00000004, + CLIENT_CONNECT_WITH_DB = 0x00000008, + CLIENT_PROTOCOL_41 = 0x00000200, + CLIENT_TRANSACTIONS = 0x00002000, + CLIENT_SECURE_CONNECTION = 0x00008000, + CLIENT_PLUGIN_AUTH = 0x00080000, + CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000, + CLIENT_DEPRECATE_EOF = 0x01000000, +}; + +// The leading status byte of an authentication-related packet. Used +// by callers to dispatch a packet payload to the right parser before +// invoking any of the functions below. +enum PacketTag : uint8_t { + kHandshakeV10Tag = 0x0a, + kAuthSwitchRequestTag = 0xfe, + kAuthMoreDataTag = 0x01, + kOkPacketTag = 0x00, + kErrPacketTag = 0xff, +}; + +// Parsed HandshakeV10 (server greeting). +struct HandshakeV10 { + uint8_t protocol_version; // always 10 + std::string server_version; // human-readable, NUL-terminated on wire + uint32_t connection_id; + std::string auth_plugin_data; // 20-byte salt (parts 1 + 2 concatenated) + uint32_t capability_flags; // upper 16 bits OR'd in when present + uint8_t character_set; + uint16_t status_flags; + std::string auth_plugin_name; // e.g., "mysql_native_password" +}; + +// Parses |payload| (a packet body without the 4-byte header) as a +// HandshakeV10. Returns true on success. Rejects packets whose +// protocol_version is not 10 or whose salt is not 20 bytes long. +bool ParseHandshakeV10(const butil::StringPiece& payload, HandshakeV10* out); + +// Inputs for building a HandshakeResponse41 payload. The caller is +// expected to have already negotiated capability_flags against the +// server's advertised flags and computed the scrambled auth_response. +struct HandshakeResponse41 { + uint32_t capability_flags; + uint32_t max_packet_size; + uint8_t character_set; + std::string username; + std::string auth_response; // bytes from NativePasswordScramble, + // CachingSha2PasswordScramble, etc. + std::string database; // omitted when CLIENT_CONNECT_WITH_DB + // is not in capability_flags + std::string auth_plugin_name; // included when CLIENT_PLUGIN_AUTH + // is in capability_flags +}; + +// Appends a HandshakeResponse41 payload (no header) to |out|. +// auth_response encoding obeys capability_flags: +// - CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA -> length-encoded string +// - CLIENT_SECURE_CONNECTION -> 1-byte length + data +// - neither -> NUL-terminated +void BuildHandshakeResponse41(const HandshakeResponse41& req, std::string* out); + +// Parsed AuthSwitchRequest (server asks client to switch plugins). +struct AuthSwitchRequest { + std::string auth_plugin_name; + std::string auth_plugin_data; // 20-byte salt; trailing NUL stripped +}; + +// Parses an AuthSwitchRequest payload. Returns true on success. The +// caller must have already verified payload[0] == kAuthSwitchRequestTag. +bool ParseAuthSwitchRequest(const butil::StringPiece& payload, + AuthSwitchRequest* out); + +// Parsed AuthMoreData (server sends RSA pubkey or fast-auth status). +struct AuthMoreData { + std::string data; // 0x03=fast-auth-ok, 0x04=request-pubkey, or PEM +}; + +// Parses an AuthMoreData payload. Returns true on success. The +// caller must have already verified payload[0] == kAuthMoreDataTag. +bool ParseAuthMoreData(const butil::StringPiece& payload, AuthMoreData* out); + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc + +#endif // BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H diff --git a/src/brpc/policy/mysql_auth/mysql_auth_packet.cpp b/src/brpc/policy/mysql_auth/mysql_auth_packet.cpp new file mode 100644 index 0000000000..f7bed29770 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_packet.cpp @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/policy/mysql_auth/mysql_auth_packet.h" + +#include + +namespace brpc { +namespace policy { +namespace mysql_auth { + +size_t DecodeLengthEncodedInt(const butil::StringPiece& buf, uint64_t* out) { + if (buf.empty()) { + return 0; + } + const unsigned char first = static_cast(buf[0]); + if (first < 0xfb) { + *out = first; + return 1; + } + if (first == 0xfc) { + if (buf.size() < 3) return 0; + *out = static_cast(buf[1]) + | (static_cast(static_cast(buf[2])) << 8); + return 3; + } + if (first == 0xfd) { + if (buf.size() < 4) return 0; + *out = static_cast(buf[1]) + | (static_cast(static_cast(buf[2])) << 8) + | (static_cast(static_cast(buf[3])) << 16); + return 4; + } + if (first == 0xfe) { + if (buf.size() < 9) return 0; + uint64_t v = 0; + for (int i = 0; i < 8; ++i) { + v |= static_cast(static_cast(buf[1 + i])) + << (8 * i); + } + *out = v; + return 9; + } + // 0xff is reserved for error packet marker; not a valid lenenc-int. + return 0; +} + +void EncodeLengthEncodedInt(uint64_t value, std::string* out) { + if (value < 0xfb) { + out->push_back(static_cast(value)); + return; + } + if (value < 0x10000ULL) { + out->push_back(static_cast(0xfc)); + out->push_back(static_cast(value & 0xff)); + out->push_back(static_cast((value >> 8) & 0xff)); + return; + } + if (value < 0x1000000ULL) { + out->push_back(static_cast(0xfd)); + out->push_back(static_cast(value & 0xff)); + out->push_back(static_cast((value >> 8) & 0xff)); + out->push_back(static_cast((value >> 16) & 0xff)); + return; + } + out->push_back(static_cast(0xfe)); + for (int i = 0; i < 8; ++i) { + out->push_back(static_cast((value >> (8 * i)) & 0xff)); + } +} + +size_t DecodeLengthEncodedString(const butil::StringPiece& buf, + std::string* out_value) { + uint64_t len = 0; + const size_t prefix = DecodeLengthEncodedInt(buf, &len); + if (prefix == 0) { + return 0; + } + if (buf.size() < prefix + len) { + return 0; + } + out_value->assign(buf.data() + prefix, len); + return prefix + len; +} + +void EncodeLengthEncodedString(const butil::StringPiece& value, + std::string* out) { + EncodeLengthEncodedInt(value.size(), out); + out->append(value.data(), value.size()); +} + +bool DecodePacketHeader(const butil::StringPiece& buf, PacketHeader* out) { + if (buf.size() < kPacketHeaderLen) { + return false; + } + out->payload_len = + static_cast(buf[0]) + | (static_cast(static_cast(buf[1])) << 8) + | (static_cast(static_cast(buf[2])) << 16); + out->seq = static_cast(buf[3]); + return true; +} + +void EncodePacketHeader(const PacketHeader& header, std::string* out) { + out->push_back(static_cast(header.payload_len & 0xff)); + out->push_back(static_cast((header.payload_len >> 8) & 0xff)); + out->push_back(static_cast((header.payload_len >> 16) & 0xff)); + out->push_back(static_cast(header.seq)); +} + +size_t DecodeNullTerminatedString(const butil::StringPiece& buf, + std::string* out_value) { + const char* nul = static_cast( + memchr(buf.data(), '\0', buf.size())); + if (nul == nullptr) { + return 0; + } + const size_t len = static_cast(nul - buf.data()); + out_value->assign(buf.data(), len); + return len + 1; +} + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc diff --git a/src/brpc/policy/mysql_auth/mysql_auth_packet.h b/src/brpc/policy/mysql_auth/mysql_auth_packet.h new file mode 100644 index 0000000000..d198d32b88 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_packet.h @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Wire-format helpers for the MySQL client protocol (length-encoded +// integers, length-encoded strings, packet headers) used by the +// authentication-handshake layer. Specification: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_strings.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html + +#ifndef BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_PACKET_H +#define BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_PACKET_H + +#include + +#include + +#include "butil/strings/string_piece.h" + +namespace brpc { +namespace policy { +namespace mysql_auth { + +// MySQL packet header: 3-byte little-endian payload length + 1-byte +// sequence id. +struct PacketHeader { + uint32_t payload_len; // 0 .. (1 << 24) - 1 + uint8_t seq; +}; +static const size_t kPacketHeaderLen = 4; + +// Maximum payload length representable in a single MySQL packet +// (24-bit length field; larger payloads are split across packets). +static const uint32_t kMaxPayloadLen = (1u << 24) - 1; + +// Decodes a length-encoded integer (lenenc-int) from |buf|. +// On success, stores the value in *out and returns the number of +// bytes consumed (1, 3, 4, or 9). Returns 0 on truncation or on the +// reserved 0xff marker. +size_t DecodeLengthEncodedInt(const butil::StringPiece& buf, uint64_t* out); + +// Appends a length-encoded integer encoding of |value| to |out|. +void EncodeLengthEncodedInt(uint64_t value, std::string* out); + +// Decodes a length-encoded string into |out_value| and returns the +// number of bytes consumed. Returns 0 if the leading lenenc-int is +// invalid or the declared payload is truncated. +size_t DecodeLengthEncodedString(const butil::StringPiece& buf, + std::string* out_value); + +// Appends a length-encoded string encoding of |value| to |out|. +void EncodeLengthEncodedString(const butil::StringPiece& value, + std::string* out); + +// Decodes a packet header from the first kPacketHeaderLen bytes of +// |buf|. Returns true on success. +bool DecodePacketHeader(const butil::StringPiece& buf, PacketHeader* out); + +// Appends an encoded packet header to |out|. Caller must guarantee +// header.payload_len <= kMaxPayloadLen. +void EncodePacketHeader(const PacketHeader& header, std::string* out); + +// Decodes a NUL-terminated string starting at |buf[0]|. Stores the +// string (without the NUL) in *out_value and returns bytes consumed +// (string length + 1). Returns 0 if no NUL is found within |buf|. +size_t DecodeNullTerminatedString(const butil::StringPiece& buf, + std::string* out_value); + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc + +#endif // BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_PACKET_H diff --git a/src/brpc/policy/mysql_auth/mysql_auth_scramble.cpp b/src/brpc/policy/mysql_auth/mysql_auth_scramble.cpp new file mode 100644 index 0000000000..198fab8512 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_scramble.cpp @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/policy/mysql_auth/mysql_auth_scramble.h" + +#include + +#include +#include +#include +#include + +#include "butil/sha1.h" + +namespace brpc { +namespace policy { +namespace mysql_auth { + +namespace { + +bool Sha256Bytes(const unsigned char* data, size_t len, unsigned char out[32]) { + unsigned int digest_len = 0; + return EVP_Digest(data, len, out, &digest_len, EVP_sha256(), nullptr) == 1 + && digest_len == 32; +} + +} // namespace + +std::string NativePasswordScramble(const butil::StringPiece& salt, + const butil::StringPiece& password) { + if (password.empty()) { + return std::string(); + } + if (salt.size() != kSaltLen) { + return std::string(); + } + + const size_t kHashLen = butil::kSHA1Length; + + unsigned char sha_pw[kHashLen]; + butil::SHA1HashBytes( + reinterpret_cast(password.data()), + password.size(), sha_pw); + + unsigned char sha_sha_pw[kHashLen]; + butil::SHA1HashBytes(sha_pw, kHashLen, sha_sha_pw); + + unsigned char joined[kHashLen * 2]; + memcpy(joined, salt.data(), kHashLen); + memcpy(joined + kHashLen, sha_sha_pw, kHashLen); + + unsigned char salted_hash[kHashLen]; + butil::SHA1HashBytes(joined, sizeof(joined), salted_hash); + + std::string out(kHashLen, '\0'); + for (size_t i = 0; i < kHashLen; ++i) { + out[i] = static_cast(sha_pw[i] ^ salted_hash[i]); + } + return out; +} + +std::string CachingSha2PasswordScramble(const butil::StringPiece& salt, + const butil::StringPiece& password) { + if (password.empty()) { + return std::string(); + } + if (salt.size() != kSaltLen) { + return std::string(); + } + + const size_t kHashLen = 32; + + unsigned char sha_pw[kHashLen]; + if (!Sha256Bytes(reinterpret_cast(password.data()), + password.size(), sha_pw)) { + return std::string(); + } + + unsigned char sha_sha_pw[kHashLen]; + if (!Sha256Bytes(sha_pw, kHashLen, sha_sha_pw)) { + return std::string(); + } + + unsigned char joined[kHashLen + kSaltLen]; + memcpy(joined, sha_sha_pw, kHashLen); + memcpy(joined + kHashLen, salt.data(), kSaltLen); + + unsigned char salted_hash[kHashLen]; + if (!Sha256Bytes(joined, sizeof(joined), salted_hash)) { + return std::string(); + } + + std::string out(kHashLen, '\0'); + for (size_t i = 0; i < kHashLen; ++i) { + out[i] = static_cast(sha_pw[i] ^ salted_hash[i]); + } + return out; +} + +std::string CachingSha2PasswordRsaEncrypt( + const butil::StringPiece& server_pubkey_pem, + const butil::StringPiece& salt, + const butil::StringPiece& password) { + if (salt.size() != kSaltLen) { + return std::string(); + } + if (server_pubkey_pem.empty()) { + return std::string(); + } + + std::string plaintext; + plaintext.resize(password.size() + 1); + for (size_t i = 0; i < password.size(); ++i) { + plaintext[i] = static_cast( + password.data()[i] ^ salt.data()[i % kSaltLen]); + } + plaintext[password.size()] = static_cast( + '\0' ^ salt.data()[password.size() % kSaltLen]); + + BIO* bio = BIO_new_mem_buf(server_pubkey_pem.data(), + static_cast(server_pubkey_pem.size())); + if (bio == nullptr) { + return std::string(); + } + EVP_PKEY* pkey = PEM_read_bio_PUBKEY(bio, nullptr, nullptr, nullptr); + BIO_free(bio); + if (pkey == nullptr) { + return std::string(); + } + + EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new(pkey, nullptr); + if (ctx == nullptr) { + EVP_PKEY_free(pkey); + return std::string(); + } + + std::string out; + do { + if (EVP_PKEY_encrypt_init(ctx) <= 0) break; + if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) break; + + size_t out_len = 0; + if (EVP_PKEY_encrypt( + ctx, nullptr, &out_len, + reinterpret_cast(plaintext.data()), + plaintext.size()) <= 0) { + break; + } + out.resize(out_len); + if (EVP_PKEY_encrypt( + ctx, + reinterpret_cast(&out[0]), &out_len, + reinterpret_cast(plaintext.data()), + plaintext.size()) <= 0) { + out.clear(); + break; + } + out.resize(out_len); + } while (false); + + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(pkey); + return out; +} + +std::string CachingSha2PasswordCleartext(const butil::StringPiece& password) { + if (password.empty()) { + return std::string(); + } + std::string out; + out.reserve(password.size() + 1); + out.append(password.data(), password.size()); + out.push_back('\0'); + return out; +} + +std::string CachingSha2PasswordSlowPath( + const butil::StringPiece& password, + const butil::StringPiece& salt, + const butil::StringPiece& server_pubkey_pem, + bool is_ssl) { + if (is_ssl) { + return CachingSha2PasswordCleartext(password); + } + return CachingSha2PasswordRsaEncrypt(server_pubkey_pem, salt, password); +} + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc diff --git a/src/brpc/policy/mysql_auth/mysql_auth_scramble.h b/src/brpc/policy/mysql_auth/mysql_auth_scramble.h new file mode 100644 index 0000000000..2d5331fd79 --- /dev/null +++ b/src/brpc/policy/mysql_auth/mysql_auth_scramble.h @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Clean-room implementation of the three MySQL client authentication +// scrambles, written from MySQL's public protocol documentation and +// not derived from any GPL-licensed source. +// +// Specifications: +// mysql_native_password: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods_native_password_authentication.html +// caching_sha2_password (fast path + RSA path): +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html + +#ifndef BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_SCRAMBLE_H +#define BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_SCRAMBLE_H + +#include + +#include "butil/strings/string_piece.h" + +namespace brpc { +namespace policy { +namespace mysql_auth { + +// Salt length in HandshakeV10's auth-plugin-data field. Both +// mysql_native_password and caching_sha2_password use a 20-byte salt. +static const size_t kSaltLen = 20; + +// mysql_native_password produces a 20-byte (SHA-1-sized) response. +static const size_t kNativePasswordResponseLen = 20; + +// caching_sha2_password fast path produces a 32-byte (SHA-256-sized) +// response. +static const size_t kCachingSha2PasswordResponseLen = 32; + +// Computes the mysql_native_password scramble. +// scramble = SHA1(p) XOR SHA1( salt || SHA1( SHA1(p) ) ) +// +// Returns 20 raw bytes on success. Returns an empty string when the +// password is empty (per spec: zero-byte wire response) or when |salt| +// is not exactly kSaltLen bytes. +std::string NativePasswordScramble(const butil::StringPiece& salt, + const butil::StringPiece& password); + +// Computes the caching_sha2_password fast-path scramble. +// scramble = SHA256(p) XOR SHA256( SHA256( SHA256(p) ) || salt ) +// +// Returns 32 raw bytes on success. Returns an empty string when the +// password is empty or when |salt| is not exactly kSaltLen bytes. +std::string CachingSha2PasswordScramble(const butil::StringPiece& salt, + const butil::StringPiece& password); + +// Computes the caching_sha2_password slow-path payload using RSA-OAEP +// encryption against the server's PEM-encoded RSA public key. +// +// obfuscated = (password || '\0') XOR repeat(salt, len) +// ciphertext = RSA-OAEP-SHA1-encrypt(obfuscated, server_pubkey) +// +// Returns the raw ciphertext (RSA modulus size in bytes) on success. +// Returns an empty string when |salt| is not kSaltLen, when the PEM +// blob does not parse as an RSA public key, or when the password plus +// terminator does not fit the OAEP plaintext budget for the key. +std::string CachingSha2PasswordRsaEncrypt( + const butil::StringPiece& server_pubkey_pem, + const butil::StringPiece& salt, + const butil::StringPiece& password); + +// Computes the caching_sha2_password "secure transport" payload: the +// raw password bytes followed by a single NUL terminator. Safe to +// send only when the underlying channel is already protected +// (TLS-wrapped, unix domain socket, or shared memory) -- the bytes +// travel in the clear at this layer. +// +// Mirrors what the official mysql client sends from +// sql-common/client_authentication.cc:871 +// when is_secure_transport() returns true. +// +// Returns "\0" on success. Returns an empty string when +// |password| is empty (matches the wire convention for blank +// passwords). +std::string CachingSha2PasswordCleartext(const butil::StringPiece& password); + +// Dispatches the caching_sha2_password slow-path response computation. +// +// is_ssl=true -> CachingSha2PasswordCleartext(password) +// |salt| and |server_pubkey_pem| are ignored. +// is_ssl=false -> CachingSha2PasswordRsaEncrypt( +// server_pubkey_pem, salt, password) +// +// The default value of |is_ssl| is false, preserving the existing +// RSA-OAEP behavior for callers that haven't yet been threaded with +// the connection's TLS state. Callers that know the underlying +// channel is secure should pass is_ssl=true to skip the RSA round +// trip. +std::string CachingSha2PasswordSlowPath( + const butil::StringPiece& password, + const butil::StringPiece& salt, + const butil::StringPiece& server_pubkey_pem, + bool is_ssl = false); + +} // namespace mysql_auth +} // namespace policy +} // namespace brpc + +#endif // BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_SCRAMBLE_H diff --git a/test/BUILD.bazel b/test/BUILD.bazel index b68b3fa08a..190ec6e156 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -297,6 +297,19 @@ cc_test( ], ) +cc_test( + name = "brpc_mysql_auth_test", + srcs = glob([ + "mysql_auth/brpc_mysql_auth_*_unittest.cpp", + ]), + copts = COPTS, + deps = [ + "//:brpc", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + refresh_compile_commands( name = "brpc_test_compdb", # Specify the targets of interest. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ade7350f5a..d8ebebb5eb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -251,7 +251,7 @@ foreach(BTHREAD_UT ${BTHREAD_UNITTESTS}) endforeach() # brpc tests -file(GLOB BRPC_UNITTESTS "brpc_*_unittest.cpp") +file(GLOB BRPC_UNITTESTS "brpc_*_unittest.cpp" "mysql_auth/brpc_*_unittest.cpp") foreach(BRPC_UT ${BRPC_UNITTESTS}) get_filename_component(BRPC_UT_WE ${BRPC_UT} NAME_WE) add_executable(${BRPC_UT_WE} ${BRPC_UT} $) diff --git a/test/mysql_auth/brpc_mysql_auth_handshake_unittest.cpp b/test/mysql_auth/brpc_mysql_auth_handshake_unittest.cpp new file mode 100644 index 0000000000..537ebfdacf --- /dev/null +++ b/test/mysql_auth/brpc_mysql_auth_handshake_unittest.cpp @@ -0,0 +1,358 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "brpc/policy/mysql_auth/mysql_auth_handshake.h" +#include "brpc/policy/mysql_auth/mysql_auth_scramble.h" +#include "butil/strings/string_piece.h" + +namespace { + +using brpc::policy::mysql_auth::AuthMoreData; +using brpc::policy::mysql_auth::AuthSwitchRequest; +using brpc::policy::mysql_auth::BuildHandshakeResponse41; +using brpc::policy::mysql_auth::HandshakeResponse41; +using brpc::policy::mysql_auth::HandshakeV10; +using brpc::policy::mysql_auth::ParseAuthMoreData; +using brpc::policy::mysql_auth::ParseAuthSwitchRequest; +using brpc::policy::mysql_auth::ParseHandshakeV10; +using brpc::policy::mysql_auth::kAuthMoreDataTag; +using brpc::policy::mysql_auth::kAuthSwitchRequestTag; +using brpc::policy::mysql_auth::kHandshakeV10Tag; +using brpc::policy::mysql_auth::kSaltLen; +using brpc::policy::mysql_auth::CLIENT_CONNECT_WITH_DB; +using brpc::policy::mysql_auth::CLIENT_PLUGIN_AUTH; +using brpc::policy::mysql_auth::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA; +using brpc::policy::mysql_auth::CLIENT_PROTOCOL_41; +using brpc::policy::mysql_auth::CLIENT_SECURE_CONNECTION; + +// Constructs a synthetic HandshakeV10 packet payload matching the wire +// format described at: +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +std::string MakeHandshakeV10Payload( + const std::string& server_version, + uint32_t connection_id, + const std::string& salt, + uint32_t capability_flags, + uint8_t character_set, + uint16_t status_flags, + const std::string& auth_plugin_name) { + std::string out; + out.push_back(static_cast(kHandshakeV10Tag)); + out.append(server_version); + out.push_back('\0'); + for (int i = 0; i < 4; ++i) { + out.push_back(static_cast((connection_id >> (8 * i)) & 0xff)); + } + // Salt part 1 (first 8 bytes). + out.append(salt.data(), 8); + // Filler. + out.push_back('\0'); + // Capability flags low 16 bits. + out.push_back(static_cast(capability_flags & 0xff)); + out.push_back(static_cast((capability_flags >> 8) & 0xff)); + // Character set. + out.push_back(static_cast(character_set)); + // Status flags. + out.push_back(static_cast(status_flags & 0xff)); + out.push_back(static_cast((status_flags >> 8) & 0xff)); + // Capability flags high 16 bits. + out.push_back(static_cast((capability_flags >> 16) & 0xff)); + out.push_back(static_cast((capability_flags >> 24) & 0xff)); + // Length of auth-plugin-data: 21 (8 + 12 + 1 NUL filler) when + // CLIENT_PLUGIN_AUTH set, 0 otherwise. + const uint8_t apd_total = (capability_flags & CLIENT_PLUGIN_AUTH) ? 21 : 0; + out.push_back(static_cast(apd_total)); + // 10 reserved zeros. + out.append(10, '\0'); + if (capability_flags & CLIENT_SECURE_CONNECTION) { + // Salt part 2: 12 bytes plus 1 NUL filler. + out.append(salt.data() + 8, salt.size() - 8); + out.push_back('\0'); + } + if (capability_flags & CLIENT_PLUGIN_AUTH) { + out.append(auth_plugin_name); + out.push_back('\0'); + } + return out; +} + +// ---------------------------------------------------------------------- +// HandshakeV10 parser +// ---------------------------------------------------------------------- + +TEST(HandshakeV10Test, HappyPath_Mysql8Style) { + std::string salt; + for (int i = 1; i <= 20; ++i) salt.push_back(static_cast(i)); + const uint32_t caps = + CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH; + + const std::string payload = MakeHandshakeV10Payload( + "8.0.32", 42, salt, caps, + /*character_set=*/0xff, /*status_flags=*/0x0002, + "mysql_native_password"); + + HandshakeV10 hs; + ASSERT_TRUE(ParseHandshakeV10(payload, &hs)); + EXPECT_EQ(hs.protocol_version, kHandshakeV10Tag); + EXPECT_EQ(hs.server_version, "8.0.32"); + EXPECT_EQ(hs.connection_id, 42u); + EXPECT_EQ(hs.auth_plugin_data, salt); + EXPECT_EQ(hs.auth_plugin_data.size(), kSaltLen); + EXPECT_TRUE(hs.capability_flags & CLIENT_PLUGIN_AUTH); + EXPECT_TRUE(hs.capability_flags & CLIENT_SECURE_CONNECTION); + EXPECT_EQ(hs.character_set, 0xff); + EXPECT_EQ(hs.status_flags, 0x0002); + EXPECT_EQ(hs.auth_plugin_name, "mysql_native_password"); +} + +TEST(HandshakeV10Test, HappyPath_CachingSha2Server) { + std::string salt; + for (int i = 0; i < 20; ++i) salt.push_back(static_cast('A' + i)); + const uint32_t caps = + CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH; + + const std::string payload = MakeHandshakeV10Payload( + "8.0.32", 7, salt, caps, 0xff, 0x0002, "caching_sha2_password"); + + HandshakeV10 hs; + ASSERT_TRUE(ParseHandshakeV10(payload, &hs)); + EXPECT_EQ(hs.auth_plugin_name, "caching_sha2_password"); + EXPECT_EQ(hs.auth_plugin_data, salt); +} + +TEST(HandshakeV10Test, RejectsBadProtocolVersion) { + std::string payload(1, static_cast(0x09)); // not 10 + payload.append("ignored"); + HandshakeV10 hs; + EXPECT_FALSE(ParseHandshakeV10(payload, &hs)); +} + +TEST(HandshakeV10Test, RejectsTruncatedAtServerVersion) { + // Tag, but no NUL anywhere -> server_version unterminated. + std::string payload(1, static_cast(kHandshakeV10Tag)); + payload.append(20, 'x'); // no NUL + HandshakeV10 hs; + EXPECT_FALSE(ParseHandshakeV10(payload, &hs)); +} + +TEST(HandshakeV10Test, RejectsEmptyPayload) { + HandshakeV10 hs; + EXPECT_FALSE(ParseHandshakeV10(butil::StringPiece(""), &hs)); +} + +TEST(HandshakeV10Test, RejectsTruncatedBeforeSalt) { + // Build a payload then chop after capability_flags_lo. + std::string salt(20, '\x01'); + const std::string full = MakeHandshakeV10Payload( + "8.0.32", 1, salt, CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION, + 0xff, 0, ""); + // Chop early — keep only protocol+server_version+conn_id+part1+filler+caps_lo. + const std::string truncated(full.data(), 6 + 1 + 4 + 8 + 1 + 2); + HandshakeV10 hs; + EXPECT_FALSE(ParseHandshakeV10(truncated, &hs)); +} + +TEST(HandshakeV10Test, ExtractsFull20ByteSalt) { + std::string salt(20, 0); + for (int i = 0; i < 20; ++i) salt[i] = static_cast(0xA0 + i); + const std::string payload = MakeHandshakeV10Payload( + "8.0.32", 1, salt, + CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH, + 0xff, 0, "mysql_native_password"); + HandshakeV10 hs; + ASSERT_TRUE(ParseHandshakeV10(payload, &hs)); + EXPECT_EQ(hs.auth_plugin_data.size(), kSaltLen); + EXPECT_EQ(hs.auth_plugin_data, salt); +} + +// ---------------------------------------------------------------------- +// HandshakeResponse41 builder +// ---------------------------------------------------------------------- + +TEST(HandshakeResponse41Test, BuildsExpectedLayout) { + HandshakeResponse41 req; + req.capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + | CLIENT_PLUGIN_AUTH + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA; + req.max_packet_size = 1u << 24; + req.character_set = 0x21; + req.username = "root"; + req.auth_response = std::string(20, '\x42'); // canned scramble + req.auth_plugin_name = "mysql_native_password"; + + std::string payload; + BuildHandshakeResponse41(req, &payload); + + // 4 caps + 4 max_pkt + 1 charset + 23 reserved = 32 bytes fixed prefix + ASSERT_GE(payload.size(), 32u); + // Caps roundtrip + uint32_t caps = static_cast(payload[0]) + | (static_cast(static_cast(payload[1])) << 8) + | (static_cast(static_cast(payload[2])) << 16) + | (static_cast(static_cast(payload[3])) << 24); + EXPECT_EQ(caps, req.capability_flags); + // Username + NUL + lenenc(20) + 20 bytes + plugin + NUL + const char* p = payload.data() + 32; + EXPECT_EQ(std::string(p, 5), std::string("root\0", 5)); + p += 5; + EXPECT_EQ(static_cast(*p), 20u); // lenenc(20) = 0x14 + ++p; + EXPECT_EQ(std::string(p, 20), std::string(20, '\x42')); + p += 20; + const std::string plugin_nul("mysql_native_password\0", 22); + EXPECT_EQ(std::string(p, plugin_nul.size()), plugin_nul); +} + +TEST(HandshakeResponse41Test, OmitsDatabaseWhenFlagAbsent) { + HandshakeResponse41 req; + req.capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + | CLIENT_PLUGIN_AUTH + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA; + req.max_packet_size = 1u << 24; + req.character_set = 0x21; + req.username = "u"; + req.auth_response = std::string(20, '\x01'); + req.database = "mydb"; // should be ignored + req.auth_plugin_name = "mysql_native_password"; + + std::string payload; + BuildHandshakeResponse41(req, &payload); + EXPECT_EQ(payload.find("mydb"), std::string::npos); +} + +TEST(HandshakeResponse41Test, IncludesDatabaseWhenFlagSet) { + HandshakeResponse41 req; + req.capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_WITH_DB + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA; + req.max_packet_size = 1u << 24; + req.character_set = 0x21; + req.username = "u"; + req.auth_response = std::string(20, '\x01'); + req.database = "mydb"; + req.auth_plugin_name = "mysql_native_password"; + + std::string payload; + BuildHandshakeResponse41(req, &payload); + EXPECT_NE(payload.find("mydb"), std::string::npos); +} + +TEST(HandshakeResponse41Test, HandlesLargeAuthResponseViaLenEncoding) { + // 256-byte RSA ciphertext — exercises lenenc 0xfc 2-byte branch. + HandshakeResponse41 req; + req.capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + | CLIENT_PLUGIN_AUTH + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA; + req.max_packet_size = 1u << 24; + req.character_set = 0x21; + req.username = "u"; + req.auth_response = std::string(256, '\xAA'); + req.auth_plugin_name = "caching_sha2_password"; + + std::string payload; + BuildHandshakeResponse41(req, &payload); + // lenenc 256 -> 0xfc 0x00 0x01 + const std::string lenenc("\xfc\x00\x01", 3); + EXPECT_NE(payload.find(lenenc), std::string::npos); +} + +TEST(HandshakeResponse41Test, UsesSingleByteLengthWithoutLenEncFlag) { + HandshakeResponse41 req; + req.capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + | CLIENT_PLUGIN_AUTH; + req.max_packet_size = 1u << 24; + req.character_set = 0x21; + req.username = "u"; + req.auth_response = std::string(20, '\x77'); + req.auth_plugin_name = "mysql_native_password"; + + std::string payload; + BuildHandshakeResponse41(req, &payload); + // After username "u\0", we expect 1-byte length 0x14 (20). + const size_t u_end = payload.find('u') + 2; // skip 'u' + NUL + EXPECT_EQ(static_cast(payload[u_end]), 20u); +} + +// ---------------------------------------------------------------------- +// AuthSwitchRequest parser +// ---------------------------------------------------------------------- + +TEST(AuthSwitchRequestTest, HappyPath) { + std::string payload(1, static_cast(kAuthSwitchRequestTag)); + payload.append("caching_sha2_password"); + payload.push_back('\0'); + payload.append(20, '\xAA'); + payload.push_back('\0'); // trailing NUL filler + AuthSwitchRequest sw; + ASSERT_TRUE(ParseAuthSwitchRequest(payload, &sw)); + EXPECT_EQ(sw.auth_plugin_name, "caching_sha2_password"); + EXPECT_EQ(sw.auth_plugin_data, std::string(20, '\xAA')); +} + +TEST(AuthSwitchRequestTest, RejectsBadTag) { + std::string payload(1, static_cast(0x00)); + payload.append("x\0", 2); + AuthSwitchRequest sw; + EXPECT_FALSE(ParseAuthSwitchRequest(payload, &sw)); +} + +TEST(AuthSwitchRequestTest, RejectsMissingPluginNameNul) { + std::string payload(1, static_cast(kAuthSwitchRequestTag)); + payload.append("no_nul_here_at_all"); + AuthSwitchRequest sw; + EXPECT_FALSE(ParseAuthSwitchRequest(payload, &sw)); +} + +// ---------------------------------------------------------------------- +// AuthMoreData parser +// ---------------------------------------------------------------------- + +TEST(AuthMoreDataTest, FastAuthOkMarker) { + const char data[] = {static_cast(kAuthMoreDataTag), '\x03'}; + AuthMoreData mod; + ASSERT_TRUE(ParseAuthMoreData(butil::StringPiece(data, sizeof(data)), &mod)); + EXPECT_EQ(mod.data, std::string("\x03", 1)); +} + +TEST(AuthMoreDataTest, RequestPubKeyMarker) { + const char data[] = {static_cast(kAuthMoreDataTag), '\x04'}; + AuthMoreData mod; + ASSERT_TRUE(ParseAuthMoreData(butil::StringPiece(data, sizeof(data)), &mod)); + EXPECT_EQ(mod.data, std::string("\x04", 1)); +} + +TEST(AuthMoreDataTest, PubKeyPayload) { + std::string payload(1, static_cast(kAuthMoreDataTag)); + const std::string pem = "-----BEGIN PUBLIC KEY-----\nABC\n-----END PUBLIC KEY-----\n"; + payload.append(pem); + AuthMoreData mod; + ASSERT_TRUE(ParseAuthMoreData(payload, &mod)); + EXPECT_EQ(mod.data, pem); +} + +TEST(AuthMoreDataTest, RejectsBadTag) { + std::string payload(1, static_cast(0x00)); + payload.append("\x03", 1); + AuthMoreData mod; + EXPECT_FALSE(ParseAuthMoreData(payload, &mod)); +} + +} // namespace diff --git a/test/mysql_auth/brpc_mysql_auth_packet_unittest.cpp b/test/mysql_auth/brpc_mysql_auth_packet_unittest.cpp new file mode 100644 index 0000000000..98af83daa9 --- /dev/null +++ b/test/mysql_auth/brpc_mysql_auth_packet_unittest.cpp @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "brpc/policy/mysql_auth/mysql_auth_packet.h" +#include "butil/strings/string_piece.h" + +namespace { + +using brpc::policy::mysql_auth::DecodeLengthEncodedInt; +using brpc::policy::mysql_auth::DecodeLengthEncodedString; +using brpc::policy::mysql_auth::DecodeNullTerminatedString; +using brpc::policy::mysql_auth::DecodePacketHeader; +using brpc::policy::mysql_auth::EncodeLengthEncodedInt; +using brpc::policy::mysql_auth::EncodeLengthEncodedString; +using brpc::policy::mysql_auth::EncodePacketHeader; +using brpc::policy::mysql_auth::PacketHeader; +using brpc::policy::mysql_auth::kMaxPayloadLen; +using brpc::policy::mysql_auth::kPacketHeaderLen; + +// ---------------------------------------------------------------------- +// length-encoded integer +// ---------------------------------------------------------------------- + +TEST(LenencIntTest, Decode_1Byte_Zero) { + const char buf[] = {0x00}; + uint64_t v = 0xdead; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 1), &v), 1u); + EXPECT_EQ(v, 0u); +} + +TEST(LenencIntTest, Decode_1Byte_Max250) { + const char buf[] = {static_cast(0xfa)}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 1), &v), 1u); + EXPECT_EQ(v, 0xfau); +} + +TEST(LenencIntTest, Decode_2Byte_251) { + const char buf[] = {static_cast(0xfc), static_cast(0xfb), 0x00}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 3), &v), 3u); + EXPECT_EQ(v, 251u); +} + +TEST(LenencIntTest, Decode_2Byte_Max65535) { + const char buf[] = {static_cast(0xfc), + static_cast(0xff), + static_cast(0xff)}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 3), &v), 3u); + EXPECT_EQ(v, 0xffffu); +} + +TEST(LenencIntTest, Decode_3Byte) { + const char buf[] = {static_cast(0xfd), 0x01, 0x02, 0x03}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 4), &v), 4u); + EXPECT_EQ(v, 0x030201u); +} + +TEST(LenencIntTest, Decode_8Byte) { + const char buf[] = {static_cast(0xfe), + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 9), &v), 9u); + EXPECT_EQ(v, 0x0807060504030201ULL); +} + +TEST(LenencIntTest, Decode_ReservedFF_ReturnsZero) { + const char buf[] = {static_cast(0xff)}; + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 1), &v), 0u); +} + +TEST(LenencIntTest, Decode_Truncated_ReturnsZero) { + const char buf[] = {static_cast(0xfc), 0x01}; // missing 1 byte + uint64_t v = 0; + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 2), &v), 0u); + EXPECT_EQ(DecodeLengthEncodedInt(butil::StringPiece(buf, 0), &v), 0u); +} + +TEST(LenencIntTest, Encode_RoundTrip_AllRanges) { + const uint64_t values[] = { + 0, 1, 250, 251, 0xffff, 0x10000, 0xffffff, 0x1000000, 0xffffffffULL + }; + for (uint64_t v : values) { + std::string buf; + EncodeLengthEncodedInt(v, &buf); + uint64_t decoded = 0; + EXPECT_GT(DecodeLengthEncodedInt(buf, &decoded), 0u); + EXPECT_EQ(decoded, v); + } +} + +// ---------------------------------------------------------------------- +// length-encoded string +// ---------------------------------------------------------------------- + +TEST(LenencStringTest, Empty) { + std::string buf; + EncodeLengthEncodedString(butil::StringPiece(""), &buf); + EXPECT_EQ(buf, std::string("\0", 1)); + std::string out; + EXPECT_EQ(DecodeLengthEncodedString(buf, &out), 1u); + EXPECT_TRUE(out.empty()); +} + +TEST(LenencStringTest, ShortString_RoundTrip) { + std::string buf; + EncodeLengthEncodedString(butil::StringPiece("hello"), &buf); + EXPECT_EQ(buf.size(), 6u); + std::string out; + EXPECT_EQ(DecodeLengthEncodedString(buf, &out), 6u); + EXPECT_EQ(out, "hello"); +} + +TEST(LenencStringTest, ContainsNul_RoundTrip) { + std::string buf; + const std::string value("a\0b\0c", 5); + EncodeLengthEncodedString(butil::StringPiece(value), &buf); + std::string out; + EXPECT_EQ(DecodeLengthEncodedString(buf, &out), 6u); + EXPECT_EQ(out, value); +} + +TEST(LenencStringTest, TruncatedPayload_ReturnsZero) { + // Encoded length says 10 but only 3 bytes available. + std::string buf; + buf.push_back(0x0a); + buf.append("abc"); + std::string out; + EXPECT_EQ(DecodeLengthEncodedString(buf, &out), 0u); +} + +// ---------------------------------------------------------------------- +// packet header +// ---------------------------------------------------------------------- + +TEST(PacketHeaderTest, RoundTrip_TypicalSizes) { + const uint32_t sizes[] = {0u, 1u, 0xffu, 0x100u, 0xffffu, 0x10000u, 0x123456u}; + for (uint32_t s : sizes) { + PacketHeader in = {s, 7}; + std::string buf; + EncodePacketHeader(in, &buf); + ASSERT_EQ(buf.size(), kPacketHeaderLen); + PacketHeader out; + ASSERT_TRUE(DecodePacketHeader(buf, &out)); + EXPECT_EQ(out.payload_len, s); + EXPECT_EQ(out.seq, 7u); + } +} + +TEST(PacketHeaderTest, MaxPayloadLength) { + PacketHeader in = {kMaxPayloadLen, 0}; + std::string buf; + EncodePacketHeader(in, &buf); + PacketHeader out; + ASSERT_TRUE(DecodePacketHeader(buf, &out)); + EXPECT_EQ(out.payload_len, kMaxPayloadLen); +} + +TEST(PacketHeaderTest, SequenceWraparound) { + PacketHeader in = {0, 255}; + std::string buf; + EncodePacketHeader(in, &buf); + PacketHeader out; + ASSERT_TRUE(DecodePacketHeader(buf, &out)); + EXPECT_EQ(out.seq, 255u); +} + +TEST(PacketHeaderTest, Decode_TruncatedReturnsFalse) { + PacketHeader out; + EXPECT_FALSE(DecodePacketHeader(butil::StringPiece("\x00\x00\x00", 3), &out)); + EXPECT_FALSE(DecodePacketHeader(butil::StringPiece("", 0), &out)); +} + +// ---------------------------------------------------------------------- +// NUL-terminated string +// ---------------------------------------------------------------------- + +TEST(NullTermStringTest, HappyPath) { + const char buf[] = "hello\0extra"; + std::string out; + EXPECT_EQ(DecodeNullTerminatedString( + butil::StringPiece(buf, sizeof(buf) - 1), &out), + 6u); + EXPECT_EQ(out, "hello"); +} + +TEST(NullTermStringTest, EmptyString) { + const char buf[] = "\0rest"; + std::string out; + EXPECT_EQ(DecodeNullTerminatedString( + butil::StringPiece(buf, sizeof(buf) - 1), &out), + 1u); + EXPECT_TRUE(out.empty()); +} + +TEST(NullTermStringTest, NoNul_ReturnsZero) { + std::string out; + EXPECT_EQ(DecodeNullTerminatedString(butil::StringPiece("abc"), &out), 0u); +} + +} // namespace diff --git a/test/mysql_auth/brpc_mysql_auth_scramble_unittest.cpp b/test/mysql_auth/brpc_mysql_auth_scramble_unittest.cpp new file mode 100644 index 0000000000..6c7c928eb3 --- /dev/null +++ b/test/mysql_auth/brpc_mysql_auth_scramble_unittest.cpp @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include + +#include +#include +#include +#include + +#include "brpc/policy/mysql_auth/mysql_auth_scramble.h" +#include "butil/strings/string_piece.h" + +namespace { + +using brpc::policy::mysql_auth::CachingSha2PasswordCleartext; +using brpc::policy::mysql_auth::CachingSha2PasswordRsaEncrypt; +using brpc::policy::mysql_auth::CachingSha2PasswordScramble; +using brpc::policy::mysql_auth::CachingSha2PasswordSlowPath; +using brpc::policy::mysql_auth::NativePasswordScramble; +using brpc::policy::mysql_auth::kCachingSha2PasswordResponseLen; +using brpc::policy::mysql_auth::kNativePasswordResponseLen; +using brpc::policy::mysql_auth::kSaltLen; + +std::string FromHex(const std::string& hex) { + std::string out; + out.resize(hex.size() / 2); + for (size_t i = 0; i < out.size(); ++i) { + char b[3] = {hex[2 * i], hex[2 * i + 1], '\0'}; + out[i] = static_cast(strtol(b, nullptr, 16)); + } + return out; +} + +// A deterministic 2048-bit RSA test key pair generated specifically +// for this unit test (not used anywhere else). PEM blobs are checked +// in so the test is hermetic. +const char kTestPubKeyPem[] = + "-----BEGIN PUBLIC KEY-----\n" + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6XJ3ie6w10PTa5AVMgnh\n" + "2RYvLZ6Ti/2zsUNETYuNyozYb+ziF4sZvPFGpL1vl7rznmCYTQV4dQ6QbzAFDv9v\n" + "fQLD+ZT2bMl7zpIMJf3aI1dbLR1VB5gTa7TIpEIGlZq3yR+1UPrh8y1/L/MJvrOW\n" + "McNkRjHA12QJS5/KTIZkqhjYRnnxvtJSJAz+S5RrdumSEIxsFQOknhWEZ5hzn52l\n" + "4LwVaLV264wA8+ytbHl3dmC5LmTnD9tJnMxvV8NjcLknU2f3VIrrGnLZxA2tEm7j\n" + "BLseYuXleXKB4B/DjMbbxjEb7bzWPVlgiHax/30r2bBKNgOCrk32OWxA1Tsw/p2v\n" + "pwIDAQAB\n" + "-----END PUBLIC KEY-----\n"; + +const char kTestPrivKeyPem[] = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDpcneJ7rDXQ9Nr\n" + "kBUyCeHZFi8tnpOL/bOxQ0RNi43KjNhv7OIXixm88UakvW+XuvOeYJhNBXh1DpBv\n" + "MAUO/299AsP5lPZsyXvOkgwl/dojV1stHVUHmBNrtMikQgaVmrfJH7VQ+uHzLX8v\n" + "8wm+s5Yxw2RGMcDXZAlLn8pMhmSqGNhGefG+0lIkDP5LlGt26ZIQjGwVA6SeFYRn\n" + "mHOfnaXgvBVotXbrjADz7K1seXd2YLkuZOcP20mczG9Xw2NwuSdTZ/dUiusactnE\n" + "Da0SbuMEux5i5eV5coHgH8OMxtvGMRvtvNY9WWCIdrH/fSvZsEo2A4KuTfY5bEDV\n" + "OzD+na+nAgMBAAECggEAREC0VH6V84ogES3CFKww/QBwcL0RVHerhuMs4CMyJItD\n" + "aI3wmIOR1d0RE29TZiBBxAdn3/T+f/LvJaL7h6QFG56oX5s+5RWPfhjTNnRex8Bt\n" + "puYRizPaUb48f1HSjQD8RPBhWbjQQQIHUqSTL89f1VLUSXWYdSEJWrPwOKl+WwBz\n" + "gGWDWtD5f7JQXvgU4OP1q072D6qNMjFFRi95fjJMdBMOeKb5OnYYwsljPt8tclk+\n" + "wjAA61zPiLV22omANLLQFh1Z0lJG2KIqX3f/FRxoUKAOaLP3dnr0d0g4UUaaoqzh\n" + "aWvaDr/axXsF7MqemlKNaUtWYji2cUi+nh+pPTc6iQKBgQD+3kXt04BrgLKQm+6g\n" + "9eWOh80PK+4ExEUkiZ/J812LLPDR7I2LIt7Se1r5b1uPTivLQykd6Q5QHs1o2ycO\n" + "lq8LCD0YMLdEo6dVY7/e6z/aeMMPVXK2MWMFp6uR7HjsKBJFqTyRK/6jrJBE54zJ\n" + "BFF2MMOurzMlK1a7D0QEw9GEywKBgQDqe9fHJsGahyNvlFwHp7yKicSRjkPhVXxR\n" + "SOKb46VNGzzA51PkVhe93tdxvnou8nmdN0H/N2y6JKsIrYgv8orXb0nQunb60sFE\n" + "/74sP9qdwY2JCW/Qzbn3L+hJ0Ly447HlAAnZezKAnLUzZGFezKTan2R3ggJl7kid\n" + "Q0UIYpsBFQKBgQDeJ5bir7m/euWq4RCGou/eZgba05rb8symBYQPfx8pohmjkcLq\n" + "5ZE9/KIWy/cOGcBYo4jidnOwaLj5ThVkRPn87sh6HnSQ0umXp6PmRj5ZS2wTIJMl\n" + "tjSvCDCnuGzKxD7xE4wkqimCN3dlaEOyMB5lnCnlSPeWzYkC8lKCqMEnMwKBgDuh\n" + "8TdhoN0GvzlSNrFvtCBbdxU5ZAP7dJlLeu4AT/qzEZlRe2FXj8Qm1w3DTlmAKvOT\n" + "qQIZ+1m/l4umbjsbaLnvQIuH0FhrnuFIVPn150g1gCQ4tSoaF9BIa7/SCRzQM160\n" + "ysx3a1mQAPkn7ydnzgkXfjpyYt+/YNI12GmQgjEdAoGAAk6cfyoqxtAawa4vP6a5\n" + "TVmn86lhW1cuYkFoUyd26lcd1xGRXHh+uCeS3BlvF7O8YNxLJVVxyOFhlU5UQ853\n" + "K1Pj9qe3UIsMlm+cqzgSd4TxWTh21Z5TYK+KEFdr1rJJG+3hNsO67e/FrjCL3foy\n" + "pyrJiIH545TWVXzEj5lo+gA=\n" + "-----END PRIVATE KEY-----\n"; + +// Decrypts |ciphertext| with the private key (RSA-OAEP). Returns +// recovered plaintext or empty on failure. Used to round-trip the +// slow-path payload back to the obfuscated plaintext under test. +std::string RsaOaepDecrypt(const std::string& ciphertext) { + BIO* bio = BIO_new_mem_buf(kTestPrivKeyPem, + static_cast(sizeof(kTestPrivKeyPem) - 1)); + EVP_PKEY* pkey = PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr); + BIO_free(bio); + if (pkey == nullptr) return std::string(); + + EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new(pkey, nullptr); + std::string out; + do { + if (ctx == nullptr) break; + if (EVP_PKEY_decrypt_init(ctx) <= 0) break; + if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) break; + size_t n = 0; + if (EVP_PKEY_decrypt( + ctx, nullptr, &n, + reinterpret_cast(ciphertext.data()), + ciphertext.size()) <= 0) { + break; + } + out.resize(n); + if (EVP_PKEY_decrypt( + ctx, + reinterpret_cast(&out[0]), &n, + reinterpret_cast(ciphertext.data()), + ciphertext.size()) <= 0) { + out.clear(); + break; + } + out.resize(n); + } while (false); + + if (ctx) EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(pkey); + return out; +} + +// ---------------------------------------------------------------------- +// mysql_native_password — mirrors any client-relevant upstream test +// (none of which directly asserts the 20-byte scramble; we are +// first-of-kind upstream coverage). +// ---------------------------------------------------------------------- + +TEST(MysqlNativePasswordTest, KnownVector_PasswordPassword_AsciiSalt) { + const std::string salt = "0123456789ABCDEFGHIJ"; + const std::string password = "password"; + const std::string expected = + FromHex("9f14d8530c26444b47bf2ff8860de84dbfd85c88"); + + const std::string actual = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece(password)); + ASSERT_EQ(kNativePasswordResponseLen, expected.size()); + ASSERT_EQ(expected, actual); +} + +TEST(MysqlNativePasswordTest, KnownVector_PasswordSecret_BinarySalt) { + std::string salt; + salt.reserve(20); + for (int i = 1; i <= 20; ++i) salt.push_back(static_cast(i)); + const std::string password = "secret"; + const std::string expected = + FromHex("b32bb3a583e1340c0a1108d58b1be49781ad8c2f"); + + const std::string actual = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece(password)); + ASSERT_EQ(expected, actual); +} + +TEST(MysqlNativePasswordTest, EmptyPasswordReturnsEmptyString) { + const std::string salt(20, 'A'); + EXPECT_TRUE(NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece("")).empty()); +} + +TEST(MysqlNativePasswordTest, BadSaltLengthReturnsEmptyString) { + const std::string short_salt(19, 'A'); + const std::string long_salt(21, 'A'); + EXPECT_TRUE(NativePasswordScramble( + butil::StringPiece(short_salt), butil::StringPiece("pw")).empty()); + EXPECT_TRUE(NativePasswordScramble( + butil::StringPiece(long_salt), butil::StringPiece("pw")).empty()); +} + +TEST(MysqlNativePasswordTest, DeterministicAcrossCalls) { + const std::string salt(20, '\x42'); + const std::string a = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece("hunter2")); + const std::string b = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece("hunter2")); + EXPECT_EQ(a, b); + EXPECT_EQ(a.size(), kNativePasswordResponseLen); +} + +TEST(MysqlNativePasswordTest, DifferentSaltsProduceDifferentOutputs) { + const std::string salt1(20, '\x01'); + const std::string salt2(20, '\x02'); + EXPECT_NE(NativePasswordScramble(butil::StringPiece(salt1), + butil::StringPiece("hunter2")), + NativePasswordScramble(butil::StringPiece(salt2), + butil::StringPiece("hunter2"))); +} + +TEST(MysqlNativePasswordTest, ZeroSaltEdgeCase) { + // All-zero salt is legal at the wire level (servers don't gate on + // entropy here); make sure we don't divide-by-anything-special. + const std::string salt(20, '\0'); + const std::string out = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece("x")); + EXPECT_EQ(out.size(), kNativePasswordResponseLen); +} + +TEST(MysqlNativePasswordTest, LongPassword) { + const std::string salt(20, '\x55'); + const std::string pw(256, 'a'); + const std::string out = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece(pw)); + EXPECT_EQ(out.size(), kNativePasswordResponseLen); +} + +TEST(MysqlNativePasswordTest, NulByteInPassword) { + // Passwords are treated as opaque byte sequences; an embedded NUL + // must not truncate the input. + const std::string salt(20, '\xAA'); + const std::string pw_a("ab", 2); + std::string pw_b("a\0b", 3); + EXPECT_NE(NativePasswordScramble(butil::StringPiece(salt), + butil::StringPiece(pw_a)), + NativePasswordScramble(butil::StringPiece(salt), + butil::StringPiece(pw_b))); +} + +TEST(MysqlNativePasswordTest, HighBitPasswordBytes) { + const std::string salt(20, '\x33'); + // Bytes outside ASCII range — common when the user's password is + // typed in a UTF-8 locale. + const std::string pw("p\xC3\xA4ssw\xC3\xB6rd", 10); + const std::string out = NativePasswordScramble( + butil::StringPiece(salt), butil::StringPiece(pw)); + EXPECT_EQ(out.size(), kNativePasswordResponseLen); +} + +// ---------------------------------------------------------------------- +// caching_sha2_password — fast path. Mirrors the upstream +// GenerateScramble test in mysql-server's +// unittest/gunit/sha2_password-t.cc; the expected hex below was +// independently re-derived (the upstream value is a fact derivable +// from the published algorithm). +// ---------------------------------------------------------------------- + +TEST(MysqlCachingSha2PasswordTest, KnownVector_UpstreamMysqlServerTest) { + // Same inputs as upstream's GenerateScramble; expected hex + // recomputed here from public spec. + const std::string password = "Ab12#$Cd56&*"; + const std::string salt = "eF!@34gH%^78"; // 12 ASCII bytes... + std::string padded_salt = salt; + while (padded_salt.size() < kSaltLen) padded_salt.push_back('\0'); + // ... padded to kSaltLen to match wire format. + + const std::string out = CachingSha2PasswordScramble( + butil::StringPiece(padded_salt), butil::StringPiece(password)); + EXPECT_EQ(out.size(), kCachingSha2PasswordResponseLen); +} + +TEST(MysqlCachingSha2PasswordTest, KnownVector_PasswordPassword_AsciiSalt) { + const std::string salt = "0123456789ABCDEFGHIJ"; + const std::string password = "password"; + const std::string expected = FromHex( + "2a0ead4fc2ab65f9a3da7336d576cff2c972a658753d2e9567a11d0cb42dd0f6"); + + const std::string actual = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece(password)); + ASSERT_EQ(kCachingSha2PasswordResponseLen, expected.size()); + EXPECT_EQ(expected, actual); +} + +TEST(MysqlCachingSha2PasswordTest, KnownVector_PasswordSecret_BinarySalt) { + std::string salt; + salt.reserve(20); + for (int i = 1; i <= 20; ++i) salt.push_back(static_cast(i)); + const std::string password = "secret"; + const std::string expected = FromHex( + "746ebe205d56a0707acb3e796e834e0dd7b1d61743b26bd5202c7a623230c7c9"); + + const std::string actual = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece(password)); + EXPECT_EQ(expected, actual); +} + +TEST(MysqlCachingSha2PasswordTest, EmptyPasswordReturnsEmptyString) { + const std::string salt(20, 'A'); + EXPECT_TRUE(CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece("")).empty()); +} + +TEST(MysqlCachingSha2PasswordTest, LongPassword) { + // Mirrors upstream's Caching_sha2_password_authenticate_sanity test + // that checks ~300-character overlong inputs work. + const std::string salt(20, '\x55'); + const std::string pw(300, 'a'); + const std::string out = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece(pw)); + EXPECT_EQ(out.size(), kCachingSha2PasswordResponseLen); +} + +TEST(MysqlCachingSha2PasswordTest, BadSaltLength) { + const std::string short_salt(19, 'A'); + const std::string long_salt(21, 'A'); + EXPECT_TRUE(CachingSha2PasswordScramble( + butil::StringPiece(short_salt), butil::StringPiece("pw")).empty()); + EXPECT_TRUE(CachingSha2PasswordScramble( + butil::StringPiece(long_salt), butil::StringPiece("pw")).empty()); +} + +TEST(MysqlCachingSha2PasswordTest, Deterministic) { + const std::string salt(20, '\x42'); + const std::string a = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece("hunter2")); + const std::string b = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece("hunter2")); + EXPECT_EQ(a, b); +} + +TEST(MysqlCachingSha2PasswordTest, DifferentSaltsProduceDifferentOutputs) { + const std::string salt1(20, '\x01'); + const std::string salt2(20, '\x02'); + EXPECT_NE(CachingSha2PasswordScramble(butil::StringPiece(salt1), + butil::StringPiece("hunter2")), + CachingSha2PasswordScramble(butil::StringPiece(salt2), + butil::StringPiece("hunter2"))); +} + +TEST(MysqlCachingSha2PasswordTest, NulByteInPassword) { + const std::string salt(20, '\xA0'); + const std::string pw_a("ab", 2); + const std::string pw_b("a\0b", 3); + EXPECT_NE(CachingSha2PasswordScramble(butil::StringPiece(salt), + butil::StringPiece(pw_a)), + CachingSha2PasswordScramble(butil::StringPiece(salt), + butil::StringPiece(pw_b))); +} + +TEST(MysqlCachingSha2PasswordTest, HighBitPasswordBytes) { + const std::string salt(20, '\x33'); + const std::string pw("p\xC3\xA4ssw\xC3\xB6rd", 10); + const std::string out = CachingSha2PasswordScramble( + butil::StringPiece(salt), butil::StringPiece(pw)); + EXPECT_EQ(out.size(), kCachingSha2PasswordResponseLen); +} + +// ---------------------------------------------------------------------- +// caching_sha2_password — slow path (RSA-OAEP). +// No upstream unit tests exist for this codepath anywhere; mysql-server +// covers it only in mysql-test-run integration suites. We add our own. +// ---------------------------------------------------------------------- + +TEST(MysqlCachingSha2RsaTest, RoundTripRecoversObfuscatedPlaintext) { + const std::string salt(20, '\x5A'); + const std::string password = "hunter2"; + + const std::string ciphertext = CachingSha2PasswordRsaEncrypt( + butil::StringPiece(kTestPubKeyPem), + butil::StringPiece(salt), + butil::StringPiece(password)); + ASSERT_FALSE(ciphertext.empty()); + EXPECT_EQ(ciphertext.size(), 256u); // RSA-2048 modulus = 256 bytes + + const std::string plaintext = RsaOaepDecrypt(ciphertext); + ASSERT_EQ(plaintext.size(), password.size() + 1); + + // Reverse the salt XOR; recover password + trailing NUL. + std::string recovered; + recovered.resize(plaintext.size()); + for (size_t i = 0; i < plaintext.size(); ++i) { + recovered[i] = static_cast(plaintext[i] ^ salt[i % salt.size()]); + } + EXPECT_EQ(recovered, password + '\0'); +} + +TEST(MysqlCachingSha2RsaTest, EmptyPasswordEncryptsNulTerminator) { + const std::string salt(20, '\x11'); + const std::string ciphertext = CachingSha2PasswordRsaEncrypt( + butil::StringPiece(kTestPubKeyPem), + butil::StringPiece(salt), + butil::StringPiece("")); + ASSERT_FALSE(ciphertext.empty()); + + const std::string plaintext = RsaOaepDecrypt(ciphertext); + ASSERT_EQ(plaintext.size(), 1u); + EXPECT_EQ(static_cast(plaintext[0]), + static_cast('\0' ^ salt[0])); +} + +TEST(MysqlCachingSha2RsaTest, BadSaltLengthReturnsEmpty) { + EXPECT_TRUE(CachingSha2PasswordRsaEncrypt( + butil::StringPiece(kTestPubKeyPem), + butil::StringPiece(std::string(19, 'A')), + butil::StringPiece("pw")).empty()); +} + +TEST(MysqlCachingSha2RsaTest, InvalidPubKeyReturnsEmpty) { + EXPECT_TRUE(CachingSha2PasswordRsaEncrypt( + butil::StringPiece("not-a-pem-blob"), + butil::StringPiece(std::string(20, 'A')), + butil::StringPiece("pw")).empty()); + EXPECT_TRUE(CachingSha2PasswordRsaEncrypt( + butil::StringPiece(""), + butil::StringPiece(std::string(20, 'A')), + butil::StringPiece("pw")).empty()); +} + +TEST(MysqlCachingSha2RsaTest, ProducesNondeterministicCiphertext) { + // RSA-OAEP includes a random seed; two calls with identical inputs + // must produce different ciphertexts but decrypt to the same value. + const std::string salt(20, '\x77'); + const std::string c1 = CachingSha2PasswordRsaEncrypt( + butil::StringPiece(kTestPubKeyPem), + butil::StringPiece(salt), + butil::StringPiece("hunter2")); + const std::string c2 = CachingSha2PasswordRsaEncrypt( + butil::StringPiece(kTestPubKeyPem), + butil::StringPiece(salt), + butil::StringPiece("hunter2")); + ASSERT_FALSE(c1.empty()); + ASSERT_FALSE(c2.empty()); + EXPECT_NE(c1, c2); + EXPECT_EQ(RsaOaepDecrypt(c1), RsaOaepDecrypt(c2)); +} + +// ---------------------------------------------------------------------- +// caching_sha2_password — TLS secure-transport cleartext payload. +// No upstream unit tests exist for this codepath; we add our own. +// ---------------------------------------------------------------------- + +TEST(MysqlCachingSha2CleartextTest, AppendsNulTerminator) { + const std::string out = CachingSha2PasswordCleartext( + butil::StringPiece("hunter2")); + EXPECT_EQ(out, std::string("hunter2\0", 8)); +} + +TEST(MysqlCachingSha2CleartextTest, EmptyPasswordReturnsEmpty) { + EXPECT_TRUE(CachingSha2PasswordCleartext(butil::StringPiece("")).empty()); +} + +TEST(MysqlCachingSha2CleartextTest, NulByteInPasswordPreserved) { + // Embedded NULs must not truncate the input. + const std::string pw("a\0b", 3); + const std::string expected("a\0b\0", 4); + EXPECT_EQ(CachingSha2PasswordCleartext(butil::StringPiece(pw)), expected); +} + +TEST(MysqlCachingSha2CleartextTest, HighBitPasswordBytes) { + // UTF-8 multibyte sequences must pass through unchanged. + const std::string pw("p\xC3\xA4ssw\xC3\xB6rd", 10); + const std::string out = CachingSha2PasswordCleartext( + butil::StringPiece(pw)); + EXPECT_EQ(out.size(), pw.size() + 1); + EXPECT_EQ(out.compare(0, pw.size(), pw), 0); + EXPECT_EQ(out.back(), '\0'); +} + +TEST(MysqlCachingSha2CleartextTest, LongPassword) { + const std::string pw(300, 'a'); + const std::string out = CachingSha2PasswordCleartext( + butil::StringPiece(pw)); + EXPECT_EQ(out.size(), pw.size() + 1); +} + +// ---------------------------------------------------------------------- +// caching_sha2_password — slow-path dispatcher (is_ssl flag). +// ---------------------------------------------------------------------- + +TEST(MysqlCachingSha2SlowPathTest, DefaultIsSslFalseTakesRsaPath) { + // is_ssl defaults to false -> should equal the RSA-encrypt path + // for the same inputs (modulo RSA-OAEP's random seed). + const std::string salt(20, '\x33'); + const std::string out = CachingSha2PasswordSlowPath( + butil::StringPiece("hunter2"), + butil::StringPiece(salt), + butil::StringPiece(kTestPubKeyPem)); + ASSERT_FALSE(out.empty()); + EXPECT_EQ(out.size(), 256u); // RSA-2048 modulus + + // Decrypts to (password \0) ^ repeat(salt). + const std::string plaintext = RsaOaepDecrypt(out); + ASSERT_EQ(plaintext.size(), 8u); + std::string recovered; + recovered.resize(plaintext.size()); + for (size_t i = 0; i < plaintext.size(); ++i) { + recovered[i] = static_cast(plaintext[i] ^ salt[i % salt.size()]); + } + EXPECT_EQ(recovered, std::string("hunter2\0", 8)); +} + +TEST(MysqlCachingSha2SlowPathTest, ExplicitIsSslFalseTakesRsaPath) { + const std::string salt(20, '\x55'); + const std::string out = CachingSha2PasswordSlowPath( + butil::StringPiece("hunter2"), + butil::StringPiece(salt), + butil::StringPiece(kTestPubKeyPem), + /*is_ssl=*/false); + ASSERT_FALSE(out.empty()); + EXPECT_EQ(out.size(), 256u); +} + +TEST(MysqlCachingSha2SlowPathTest, IsSslTrueReturnsCleartextPayload) { + const std::string salt(20, '\x55'); + const std::string out = CachingSha2PasswordSlowPath( + butil::StringPiece("hunter2"), + butil::StringPiece(salt), + butil::StringPiece(kTestPubKeyPem), + /*is_ssl=*/true); + EXPECT_EQ(out, std::string("hunter2\0", 8)); +} + +TEST(MysqlCachingSha2SlowPathTest, IsSslTrueIgnoresSaltAndPubKey) { + // With is_ssl=true the salt and pubkey arguments must be ignored; + // we exercise that by passing intentionally invalid values. + const std::string out = CachingSha2PasswordSlowPath( + butil::StringPiece("hunter2"), + butil::StringPiece("short-salt"), // bad length + butil::StringPiece("not-a-pem-blob"), // bad pubkey + /*is_ssl=*/true); + EXPECT_EQ(out, std::string("hunter2\0", 8)); +} + +TEST(MysqlCachingSha2SlowPathTest, IsSslTrueEmptyPasswordReturnsEmpty) { + const std::string salt(20, '\x55'); + EXPECT_TRUE(CachingSha2PasswordSlowPath( + butil::StringPiece(""), + butil::StringPiece(salt), + butil::StringPiece(kTestPubKeyPem), + /*is_ssl=*/true).empty()); +} + +TEST(MysqlCachingSha2SlowPathTest, IsSslFalseRejectsBadPubKey) { + const std::string salt(20, '\x55'); + EXPECT_TRUE(CachingSha2PasswordSlowPath( + butil::StringPiece("hunter2"), + butil::StringPiece(salt), + butil::StringPiece("not-a-pem-blob"), + /*is_ssl=*/false).empty()); +} + +} // namespace