Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions src/brpc/policy/mysql_auth/mysql_auth_handshake.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "brpc/policy/mysql_auth/mysql_auth_handshake.h"

#include <cstring>

#include "brpc/policy/mysql_auth/mysql_auth_packet.h"
#include "brpc/policy/mysql_auth/mysql_auth_scramble.h"

namespace brpc {
namespace policy {
namespace mysql_auth {

namespace {

// MySQL HandshakeV10 fixed-size pieces and constants.
const size_t kAuthPluginDataPart1Len = 8;
const size_t kReservedAfterCapsLen = 10;
const size_t kFillerAfterPart1Len = 1;
const size_t kReservedInResponseLen = 23;

// Reads N little-endian bytes from |buf| at |off| into |out|.
template <typename T>
bool ReadLE(const butil::StringPiece& buf, size_t off, size_t n, T* out) {
if (off + n > buf.size()) return false;
T v = 0;
for (size_t i = 0; i < n; ++i) {
v |= static_cast<T>(static_cast<unsigned char>(buf[off + i])) << (8 * i);
}
*out = v;
return true;
}

template <typename T>
void WriteLE(T value, size_t n, std::string* out) {
for (size_t i = 0; i < n; ++i) {
out->push_back(static_cast<char>((value >> (8 * i)) & 0xff));
}
}

} // namespace

bool ParseHandshakeV10(const butil::StringPiece& payload, HandshakeV10* out) {
if (payload.empty()) return false;

size_t off = 0;
out->protocol_version = static_cast<uint8_t>(payload[off++]);
if (out->protocol_version != kHandshakeV10Tag) {
return false;
}

// server_version: NUL-terminated string
std::string version;
{
const butil::StringPiece rest(payload.data() + off,
payload.size() - off);
const size_t consumed = DecodeNullTerminatedString(rest, &version);
if (consumed == 0) return false;
off += consumed;
}
out->server_version = std::move(version);

// connection_id: 4 LE bytes
if (!ReadLE<uint32_t>(payload, off, 4, &out->connection_id)) return false;
off += 4;

// auth-plugin-data-part-1: 8 bytes
if (off + kAuthPluginDataPart1Len > payload.size()) return false;
std::string salt(payload.data() + off, kAuthPluginDataPart1Len);
off += kAuthPluginDataPart1Len;

// filler 0x00
if (off + kFillerAfterPart1Len > payload.size()) return false;
off += kFillerAfterPart1Len;

// capability flags (lower 2 bytes)
uint16_t caps_lo = 0;
if (!ReadLE<uint16_t>(payload, off, 2, &caps_lo)) return false;
off += 2;
out->capability_flags = caps_lo;

if (off == payload.size()) {
// Pre-4.1 server. We don't support these — bail.
return false;
}

// character_set
if (off >= payload.size()) return false;
out->character_set = static_cast<uint8_t>(payload[off++]);

// status_flags
if (!ReadLE<uint16_t>(payload, off, 2, &out->status_flags)) return false;
off += 2;

// capability flags upper 2 bytes
uint16_t caps_hi = 0;
if (!ReadLE<uint16_t>(payload, off, 2, &caps_hi)) return false;
off += 2;
out->capability_flags |= static_cast<uint32_t>(caps_hi) << 16;

// length of auth-plugin-data (or 0x00 when CLIENT_PLUGIN_AUTH is absent)
if (off >= payload.size()) return false;
const uint8_t apd_total_len = static_cast<uint8_t>(payload[off++]);

// 10 reserved bytes (all 0x00)
if (off + kReservedAfterCapsLen > payload.size()) return false;
off += kReservedAfterCapsLen;

if (out->capability_flags & CLIENT_SECURE_CONNECTION) {
// auth-plugin-data-part-2: max(13, apd_total_len - 8) bytes. Modern
// servers send 13 (12 salt bytes + 1 NUL filler).
const size_t part2_len = apd_total_len > kAuthPluginDataPart1Len
? static_cast<size_t>(apd_total_len) - kAuthPluginDataPart1Len
: static_cast<size_t>(13);
const size_t want = part2_len < 13 ? 13 : part2_len;
if (off + want > payload.size()) return false;
// Concat salt parts; trim trailing NUL filler so callers see the
// raw 20-byte salt.
salt.append(payload.data() + off, want);
off += want;
if (!salt.empty() && salt.back() == '\0') {
salt.pop_back();
}
}
if (salt.size() != kSaltLen) {
return false;
}
out->auth_plugin_data = std::move(salt);

if (out->capability_flags & CLIENT_PLUGIN_AUTH) {
std::string name;
const butil::StringPiece rest(payload.data() + off,
payload.size() - off);
const size_t consumed = DecodeNullTerminatedString(rest, &name);
// Some servers omit the trailing NUL; tolerate by treating the
// remainder of the payload as the plugin name.
if (consumed == 0) {
out->auth_plugin_name.assign(rest.data(), rest.size());
} else {
out->auth_plugin_name = std::move(name);
}
}

return true;
}

void BuildHandshakeResponse41(const HandshakeResponse41& req, std::string* out) {
WriteLE<uint32_t>(req.capability_flags, 4, out);
WriteLE<uint32_t>(req.max_packet_size, 4, out);
out->push_back(static_cast<char>(req.character_set));
out->append(kReservedInResponseLen, '\0');
out->append(req.username);
out->push_back('\0');

if (req.capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
EncodeLengthEncodedString(req.auth_response, out);
} else if (req.capability_flags & CLIENT_SECURE_CONNECTION) {
// Auth response length must fit in one byte under this scheme.
// Callers using payloads >255 bytes (e.g., RSA ciphertext) must
// set CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA instead.
const uint8_t len = static_cast<uint8_t>(
req.auth_response.size() > 0xff ? 0xff : req.auth_response.size());
out->push_back(static_cast<char>(len));
out->append(req.auth_response.data(), len);
} else {
out->append(req.auth_response);
out->push_back('\0');
}

if (req.capability_flags & CLIENT_CONNECT_WITH_DB) {
out->append(req.database);
out->push_back('\0');
}

if (req.capability_flags & CLIENT_PLUGIN_AUTH) {
out->append(req.auth_plugin_name);
out->push_back('\0');
}
}

bool ParseAuthSwitchRequest(const butil::StringPiece& payload,
AuthSwitchRequest* out) {
if (payload.empty() ||
static_cast<uint8_t>(payload[0]) != kAuthSwitchRequestTag) {
return false;
}
size_t off = 1;
std::string name;
const butil::StringPiece rest(payload.data() + off, payload.size() - off);
const size_t consumed = DecodeNullTerminatedString(rest, &name);
if (consumed == 0) return false;
off += consumed;
out->auth_plugin_name = std::move(name);

// Remainder is auth-plugin-data; trim a single trailing NUL filler.
out->auth_plugin_data.assign(payload.data() + off, payload.size() - off);
if (!out->auth_plugin_data.empty() && out->auth_plugin_data.back() == '\0') {
out->auth_plugin_data.pop_back();
}
return true;
}

bool ParseAuthMoreData(const butil::StringPiece& payload, AuthMoreData* out) {
if (payload.empty() ||
static_cast<uint8_t>(payload[0]) != kAuthMoreDataTag) {
return false;
}
out->data.assign(payload.data() + 1, payload.size() - 1);
return true;
}

} // namespace mysql_auth
} // namespace policy
} // namespace brpc
131 changes: 131 additions & 0 deletions src/brpc/policy/mysql_auth/mysql_auth_handshake.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Codec for the four MySQL connection-phase packets the client touches
// during authentication. All functions operate on raw packet payloads
// (without the 4-byte packet header); the caller is responsible for
// framing. Specifications:
// HandshakeV10:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
// HandshakeResponse41:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html
// AuthSwitchRequest / AuthMoreData:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_more_data.html

#ifndef BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H
#define BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H

#include <stdint.h>

#include <string>

#include "butil/strings/string_piece.h"

namespace brpc {
namespace policy {
namespace mysql_auth {

// Subset of MySQL capability flags we recognize.
enum CapabilityFlag : uint32_t {
CLIENT_LONG_PASSWORD = 0x00000001,
CLIENT_LONG_FLAG = 0x00000004,
CLIENT_CONNECT_WITH_DB = 0x00000008,
CLIENT_PROTOCOL_41 = 0x00000200,
CLIENT_TRANSACTIONS = 0x00002000,
CLIENT_SECURE_CONNECTION = 0x00008000,
CLIENT_PLUGIN_AUTH = 0x00080000,
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000,
CLIENT_DEPRECATE_EOF = 0x01000000,
};

// The leading status byte of an authentication-related packet. Used
// by callers to dispatch a packet payload to the right parser before
// invoking any of the functions below.
enum PacketTag : uint8_t {
kHandshakeV10Tag = 0x0a,
kAuthSwitchRequestTag = 0xfe,
kAuthMoreDataTag = 0x01,
kOkPacketTag = 0x00,
kErrPacketTag = 0xff,
};

// Parsed HandshakeV10 (server greeting).
struct HandshakeV10 {
uint8_t protocol_version; // always 10
std::string server_version; // human-readable, NUL-terminated on wire
uint32_t connection_id;
std::string auth_plugin_data; // 20-byte salt (parts 1 + 2 concatenated)
uint32_t capability_flags; // upper 16 bits OR'd in when present
uint8_t character_set;
uint16_t status_flags;
std::string auth_plugin_name; // e.g., "mysql_native_password"
};

// Parses |payload| (a packet body without the 4-byte header) as a
// HandshakeV10. Returns true on success. Rejects packets whose
// protocol_version is not 10 or whose salt is not 20 bytes long.
bool ParseHandshakeV10(const butil::StringPiece& payload, HandshakeV10* out);

// Inputs for building a HandshakeResponse41 payload. The caller is
// expected to have already negotiated capability_flags against the
// server's advertised flags and computed the scrambled auth_response.
struct HandshakeResponse41 {
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
std::string username;
std::string auth_response; // bytes from NativePasswordScramble,
// CachingSha2PasswordScramble, etc.
std::string database; // omitted when CLIENT_CONNECT_WITH_DB
// is not in capability_flags
std::string auth_plugin_name; // included when CLIENT_PLUGIN_AUTH
// is in capability_flags
};

// Appends a HandshakeResponse41 payload (no header) to |out|.
// auth_response encoding obeys capability_flags:
// - CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA -> length-encoded string
// - CLIENT_SECURE_CONNECTION -> 1-byte length + data
// - neither -> NUL-terminated
void BuildHandshakeResponse41(const HandshakeResponse41& req, std::string* out);

// Parsed AuthSwitchRequest (server asks client to switch plugins).
struct AuthSwitchRequest {
std::string auth_plugin_name;
std::string auth_plugin_data; // 20-byte salt; trailing NUL stripped
};

// Parses an AuthSwitchRequest payload. Returns true on success. The
// caller must have already verified payload[0] == kAuthSwitchRequestTag.
bool ParseAuthSwitchRequest(const butil::StringPiece& payload,
AuthSwitchRequest* out);

// Parsed AuthMoreData (server sends RSA pubkey or fast-auth status).
struct AuthMoreData {
std::string data; // 0x03=fast-auth-ok, 0x04=request-pubkey, or PEM
};

// Parses an AuthMoreData payload. Returns true on success. The
// caller must have already verified payload[0] == kAuthMoreDataTag.
bool ParseAuthMoreData(const butil::StringPiece& payload, AuthMoreData* out);

} // namespace mysql_auth
} // namespace policy
} // namespace brpc

#endif // BRPC_POLICY_MYSQL_AUTH_MYSQL_AUTH_HANDSHAKE_H
Loading