Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions src/brpc/rdma/rdma_endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,16 @@ static const size_t RESERVED_WR_NUM = 3;
// block size (4B)
// sq size (2B)
// rq size (2B)
// lid size (2B)
// GID (16B)
// QP number (4B)
// mtu type (2B)
static const char* MAGIC_STR = "RDMA";
static const size_t MAGIC_STR_LEN = 4;
static const size_t HELLO_MSG_LEN_MIN = 40;
static const size_t HELLO_MSG_LEN_MIN = 42;
// static const size_t HELLO_MSG_LEN_MAX = 4096;
static const size_t ACK_MSG_LEN = 4;
static uint16_t g_rdma_hello_msg_len = 40; // In Byte
static uint16_t g_rdma_hello_msg_len = 42; // In Byte
static uint16_t g_rdma_hello_version = 2;
Comment on lines +94 to 95
static uint16_t g_rdma_impl_version = 1;
static uint32_t g_rdma_recv_block_size = 0;
Expand Down Expand Up @@ -118,6 +120,7 @@ struct HelloMessage {
uint16_t lid;
ibv_gid gid;
uint32_t qp_num;
uint16_t mtu_type;
};

void HelloMessage::Serialize(void* data) const {
Expand All @@ -132,8 +135,11 @@ void HelloMessage::Serialize(void* data) const {
*(current_pos++) = butil::HostToNet16(rq_size);
*(current_pos++) = butil::HostToNet16(lid);
memcpy(current_pos, gid.raw, 16);
uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16);
current_pos += 8;
uint32_t* qp_num_pos = (uint32_t*)(current_pos);
*qp_num_pos = butil::HostToNet32(qp_num);
current_pos += 2;
*(current_pos) = butil::HostToNet16(mtu_type);
}

void HelloMessage::Deserialize(void* data) {
Expand All @@ -147,7 +153,10 @@ void HelloMessage::Deserialize(void* data) {
rq_size = butil::NetToHost16(*current_pos++);
lid = butil::NetToHost16(*current_pos++);
memcpy(gid.raw, current_pos, 16);
qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16));
current_pos += 8;
qp_num = butil::NetToHost32(*(uint32_t*)(current_pos));
current_pos += 2;
mtu_type = butil::NetToHost16(*current_pos);
}

RdmaResource::~RdmaResource() {
Expand Down Expand Up @@ -435,6 +444,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
<< "Start handshake on " << s->_local_side;

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

// First initialize CQ and QP resources
ep->_state = C_ALLOC_QPCQ;
Expand Down Expand Up @@ -463,6 +473,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) {
Expand Down Expand Up @@ -534,7 +545,9 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
ep->_local_window_capacity, butil::memory_order_relaxed);

ep->_state = C_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_msg.mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description();
Comment on lines +599 to 602
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
Expand Down Expand Up @@ -582,6 +595,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< "Start handshake on " << s->description();

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

ep->_state = S_HELLO_WAIT;
if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) {
Expand Down Expand Up @@ -652,7 +666,9 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->_state = S_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_msg.mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
<< s->description();
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
Expand Down Expand Up @@ -681,6 +697,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
}
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
Expand Down Expand Up @@ -1232,12 +1249,27 @@ int RdmaEndpoint::AllocateResources() {
return 0;
}

int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type) {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// For UT
return 0;
}

if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "negotiated mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "negotiated mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "negotiated mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "negotiated mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "negotiated mtu is 4096";
Comment on lines +1334 to +1342
} else {
LOG(ERROR) << "unknown mtu " << mtu_type;
return -1;
}

ibv_qp_attr attr;

attr.qp_state = IBV_QPS_INIT;
Expand Down Expand Up @@ -1275,7 +1307,7 @@ int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
}

attr.qp_state = IBV_QPS_RTR;
attr.path_mtu = IBV_MTU_1024; // TODO: support more mtu in future
attr.path_mtu = ibv_mtu(mtu_type);
attr.ah_attr.grh.dgid = gid;
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.sgid_index = GetRdmaGidIndex();
Expand Down
3 changes: 2 additions & 1 deletion src/brpc/rdma/rdma_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,11 @@ friend class Socket;
// lid: remote LID
// gid: remote GID
// qp_num: remote QP number
// mtu_type: the minimum of local mtu_type and remote mtu_type
// Return:
// 0: success
// -1: failed, errno set
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num);
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type);

// Get event from comp channel and ack the events
int GetAndAckEvents(SocketUniquePtr& s);
Expand Down
41 changes: 41 additions & 0 deletions src/brpc/rdma/rdma_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ static int g_comp_vector_index = 0;

butil::atomic<bool> g_rdma_available(false);

static uint16_t local_mtu_type = IBV_MTU_4096;

DEFINE_int32(rdma_max_sge, 0, "Max SGE num in a WR");
DEFINE_string(rdma_device, "", "The name of the HCA device used "
"(Empty means using the first active device)");
Expand Down Expand Up @@ -455,6 +457,36 @@ static ibv_context* OpenDevice(int num_total, int* num_available_devices) {
return ret_context;
}

static uint16_t detect_mtu(struct ibv_context* ctx, int port_num) {
struct ibv_port_attr port_attr;

if (ibv_query_port(ctx, port_num, &port_attr)) {
LOG(ERROR) << "ibv_query_port failed";
return 0;
}

LOG(INFO) << "local active mtu type:" << port_attr.active_mtu
<< ", max mtu type:" << port_attr.max_mtu;

uint16_t mtu_type = port_attr.active_mtu;
if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "local mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "local mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "local mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "local mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "local mtu is 4096";
} else {
LOG(ERROR) << "unknown mtu type " << mtu_type;
return 0;
}

return mtu_type;
}

static void GlobalRdmaInitializeOrDieImpl() {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// Just for UT
Expand Down Expand Up @@ -549,6 +581,11 @@ static void GlobalRdmaInitializeOrDieImpl() {
g_max_sge = attr.max_sge;
}

local_mtu_type = detect_mtu(g_context, g_port_num);
if (!local_mtu_type) {
PLOG(ERROR) << "Fail to get local mtu type";
ExitWithError();
}
// Initialize RDMA memory pool (block_pool)
if (!InitBlockPool(RdmaRegisterMemory)) {
PLOG(ERROR) << "Fail to initialize RDMA memory pool";
Expand Down Expand Up @@ -701,6 +738,10 @@ bool SupportedByRdma(std::string protocol) {
return false;
}

uint16_t GetLocalMtuType() {
return local_mtu_type;
}

bool InitPollingModeWithTag(bthread_tag_t tag,
std::function<void(void)> callback,
std::function<void(void)> init_fn,
Expand Down
1 change: 1 addition & 0 deletions src/brpc/rdma/rdma_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void GlobalDisableRdma();
// If the given protocol supported by RDMA
bool SupportedByRdma(std::string protocol);

uint16_t GetLocalMtuType();
} // namespace rdma
} // namespace brpc
#else
Expand Down
Loading