Skip to content

Commit

Permalink
Merge pull request #6647 from abrooks98/gpu_rma_rdma
Browse files Browse the repository at this point in the history
ch4/gpu: Add GPU RDMA support for RMA

Approved-by: Hui Zhou
Approved-by: Ken Raffenetti
  • Loading branch information
hzhou committed Oct 3, 2023
2 parents 7053743 + 51b28cc commit 69522c2
Show file tree
Hide file tree
Showing 12 changed files with 230 additions and 50 deletions.
10 changes: 10 additions & 0 deletions src/include/mpir_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ MPL_STATIC_INLINE_PREFIX bool MPIR_GPU_query_pointer_is_dev(const void *ptr)
return false;
}

/* Ask MPL whether ptr is a "strict" device pointer.
 *
 * ptr  - address to classify; a NULL ptr always yields false.
 * attr - pointer attributes forwarded untouched to
 *        MPL_gpu_query_pointer_is_strict_dev().
 *
 * MPL is consulted only when ENABLE_GPU evaluates true and ptr is non-NULL;
 * in every other case the result is false and MPL is not called. */
MPL_STATIC_INLINE_PREFIX bool MPIR_GPU_query_pointer_is_strict_dev(const void *ptr,
                                                                   MPL_pointer_attr_t * attr)
{
    bool is_strict_dev = false;

    if (ENABLE_GPU && ptr != NULL) {
        is_strict_dev = MPL_gpu_query_pointer_is_strict_dev(ptr, attr);
    }

    return is_strict_dev;
}

/* GPU registration or pinning has huge latency (~500us), so the following
 * functions should be avoided on all critical paths. Use an unregistered
 * buffer (MPL_malloc) instead. */
Expand Down
64 changes: 54 additions & 10 deletions src/mpid/ch4/netmod/ofi/ofi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_get_buffered(int vci, struct fi_cq_tagged
}

MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_register_memory(char *send_buf, size_t data_sz,
MPL_pointer_attr_t attr, int ctx_idx,
struct fid_mr **mr)
MPL_pointer_attr_t * attr, int ctx_idx,
uint64_t rkey, struct fid_mr **mr)
{
struct fi_mr_attr mr_attr;
struct fi_mr_attr mr_attr = { 0 };
struct iovec iov;
int mpi_errno = MPI_SUCCESS;

Expand All @@ -696,20 +696,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_register_memory(char *send_buf, size_t da
mr_attr.mr_iov = &iov;
mr_attr.iov_count = 1;
mr_attr.access = FI_REMOTE_READ | FI_REMOTE_WRITE;
mr_attr.requested_key = 1;
mr_attr.requested_key = rkey;
mr_attr.offset = 0;
mr_attr.context = NULL;
#ifdef MPL_HAVE_CUDA
mr_attr.iface = (attr.type != MPL_GPU_POINTER_DEV) ? FI_HMEM_SYSTEM : FI_HMEM_CUDA;
mr_attr.iface = (attr->type != MPL_GPU_POINTER_DEV) ? FI_HMEM_SYSTEM : FI_HMEM_CUDA;
mr_attr.device.cuda =
(attr.type != MPL_GPU_POINTER_DEV) ? 0 : MPL_gpu_get_dev_id_from_attr(&attr);
(attr->type != MPL_GPU_POINTER_DEV) ? 0 : MPL_gpu_get_dev_id_from_attr(attr);
#elif defined MPL_HAVE_ZE
mr_attr.iface = (attr.type != MPL_GPU_POINTER_DEV) ? FI_HMEM_SYSTEM : FI_HMEM_ZE;
/* OFI does not support tiles yet, need to pass the root device. */
mr_attr.iface = (attr->type != MPL_GPU_POINTER_DEV) ? FI_HMEM_SYSTEM : FI_HMEM_ZE;
mr_attr.device.ze =
(attr.type !=
MPL_GPU_POINTER_DEV) ? 0 : MPL_gpu_get_root_device(MPL_gpu_get_dev_id_from_attr(&attr));
(attr->type !=
MPL_GPU_POINTER_DEV) ? 0 : MPL_gpu_get_root_device(MPL_gpu_get_dev_id_from_attr(attr));
#endif
MPIDI_OFI_CALL(fi_mr_regattr
(MPIDI_OFI_global.ctx[ctx_idx].domain, &mr_attr, 0, &(*mr)), mr_regattr);
(MPIDI_OFI_global.ctx[ctx_idx].domain, &mr_attr, 0, mr), mr_regattr);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;

fn_fail:
goto fn_exit;
}

MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_register_memory_and_bind(char *send_buf, size_t data_sz,
MPL_pointer_attr_t * attr,
int ctx_idx, struct fid_mr **mr)
{
int mpi_errno = MPI_SUCCESS;

MPIR_FUNC_ENTER;

mpi_errno = MPIDI_OFI_register_memory(send_buf, data_sz, attr, ctx_idx, 0, mr);
MPIR_ERR_CHECK(mpi_errno);

if (*mr != NULL) {
mpi_errno = MPIDI_OFI_mr_bind(MPIDI_OFI_global.prov_use[0], *mr,
Expand Down Expand Up @@ -753,6 +774,29 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_unregister_am_bufs(void)
}
}

/* Try to register an RMA buffer with the OFI provider and hand back its
 * memory descriptor.
 *
 * buffer - start of the memory to register.
 * size   - length in bytes passed to the registration call.
 * attr   - pointer attributes for buffer; when NULL they are looked up here
 *          with MPIR_GPU_query_pointer_attr().
 * win    - window whose am_vci selects the OFI context.
 * nic    - NIC index used together with the window's am_vci to pick the context.
 * desc   - out: set to fi_mr_desc() of the new registration, or NULL when no
 *          registration was made.
 *
 * Registration is attempted only when MPIDI_OFI_ENABLE_HMEM and
 * MPIDI_OFI_ENABLE_MR_HMEM are both true and
 * MPIR_GPU_query_pointer_is_strict_dev() accepts the buffer. */
MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_gpu_rma_register(const void *buffer, size_t size,
                                                         MPL_pointer_attr_t * attr, MPIR_Win * win,
                                                         int nic, void **desc)
{
    MPL_pointer_attr_t queried_attr;
    struct fid_mr *mem_region = NULL;

    /* Default result: callers see NULL unless a registration succeeds below. */
    *desc = NULL;

    int context_index = MPIDI_OFI_get_ctx_index(MPIDI_WIN(win, am_vci), nic);

    /* Caller supplied no attributes, so query them for this buffer. */
    if (attr == NULL) {
        MPIR_GPU_query_pointer_attr(buffer, &queried_attr);
        attr = &queried_attr;
    }

    if (!(MPIDI_OFI_ENABLE_HMEM && MPIDI_OFI_ENABLE_MR_HMEM &&
          MPIR_GPU_query_pointer_is_strict_dev(buffer, attr))) {
        return;
    }

    /* The returned error code is not inspected; success is judged solely by
     * whether a memory region handle came back.
     * NOTE(review): the fid_mr handle is not stored anywhere in this function;
     * confirm where (or whether) it is eventually closed. */
    MPIDI_OFI_register_memory_and_bind((char *) buffer, size, attr, context_index, &mem_region);
    if (mem_region != NULL) {
        *desc = fi_mr_desc(mem_region);
    }
}

#undef CQ_S_LIST
#undef CQ_S_HEAD
#undef CQ_S_TAIL
Expand Down
4 changes: 2 additions & 2 deletions src/mpid/ch4/netmod/ofi/ofi_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf,
MPIR_Datatype_add_ref_if_not_builtin(datatype);

recv_buf = MPIR_get_contig_ptr(buf, dt_true_lb);
MPL_pointer_attr_t attr;
MPL_pointer_attr_t attr = {.type = MPL_GPU_POINTER_UNREGISTERED_HOST };
MPIR_GPU_query_pointer_attr(recv_buf, &attr);

if (MPIDI_OFI_ENABLE_HMEM && data_sz >= MPIR_CVAR_CH4_OFI_GPU_RDMA_THRESHOLD &&
Expand All @@ -193,7 +193,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf,
}

if (register_mem) {
MPIDI_OFI_register_memory(recv_buf, data_sz, attr, ctx_idx, &mr);
MPIDI_OFI_register_memory_and_bind(recv_buf, data_sz, &attr, ctx_idx, &mr);
if (mr != NULL) {
desc = fi_mr_desc(mr);
}
Expand Down
13 changes: 11 additions & 2 deletions src/mpid/ch4/netmod/ofi/ofi_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ int MPIDI_OFI_nopack_putget(const void *origin_addr, MPI_Aint origin_count,
struct fi_msg_rma msg;
struct fi_rma_iov riov;
struct iovec iov;
size_t origin_bytes;

/* used for GPU buffer registration */
MPIR_Datatype_get_size_macro(origin_datatype, origin_bytes);
origin_bytes *= origin_count;

/* allocate target iovecs */
struct iovec *target_iov;
Expand All @@ -88,6 +93,11 @@ int MPIDI_OFI_nopack_putget(const void *origin_addr, MPI_Aint origin_count,
flags = FI_DELIVERY_COMPLETE;
}

void *desc = NULL;
int nic = MPIDI_OFI_get_pref_nic(win->comm_ptr, target_rank);;

MPIDI_OFI_gpu_rma_register(origin_addr, origin_bytes, NULL, win, nic, &desc);

int i = 0, j = 0;
size_t msg_len;
while (i < total_origin_iov_len && j < total_target_iov_len) {
Expand All @@ -103,8 +113,7 @@ int MPIDI_OFI_nopack_putget(const void *origin_addr, MPI_Aint origin_count,
msg_len = MPL_MIN(origin_iov[origin_cur].iov_len, target_iov[target_cur].iov_len);

int vci = MPIDI_WIN(win, am_vci);
int nic = MPIDI_OFI_get_pref_nic(win->comm_ptr, target_rank);;
msg.desc = NULL;
msg.desc = desc;
msg.addr = MPIDI_OFI_av_to_phys(addr, nic, vci);
msg.context = NULL;
msg.data = 0;
Expand Down

0 comments on commit 69522c2

Please sign in to comment.