Skip to content

Commit

Permalink
Fix mpi4py failures
Browse files Browse the repository at this point in the history
Corner cases are handled to fix mpi4py failures.

Signed-off-by: Nithya V S <Nithya.VS@amd.com>
  • Loading branch information
amd-nithyavs authored and wenduwan committed May 8, 2024
1 parent bccc940 commit 550ac58
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 29 deletions.
14 changes: 10 additions & 4 deletions ompi/mca/coll/acoll/coll_acoll_allgather.c
Expand Up @@ -268,6 +268,9 @@ static inline int mca_coll_acoll_allgather_intra(const void *sbuf, int scount,
data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt;
blk_ofst[0] = bcount;
} else if (sg_id == num_sgs - 1) {
if (last_subgrp_size < 2) {
return err;
}
num_data_blks = 1;
data_blk_size[0] = bcount * (num_sgs - 1);
blk_ofst[0] = 0;
Expand Down Expand Up @@ -329,8 +332,7 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
int i;
int err;
int size;
int rank, adj_rank;
int num_sgs;
int rank;
int sg_size, log2_sg_size;
int num_nodes, node_start, node_end, node_id;
int node_size, last_node_size;
Expand Down Expand Up @@ -388,7 +390,9 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
if (size <= 2) {
intra_comm = comm;
} else {
assert(subc->local_r_comm != NULL);
if (num_nodes > 1) {
assert(subc->local_r_comm != NULL);
}
intra_comm = num_nodes == 1 ? comm : subc->local_r_comm;
}
err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype,
Expand Down Expand Up @@ -454,12 +458,14 @@ int mca_coll_acoll_allgather(const void *sbuf, int scount, struct ompi_datatype_
} /* End of if inter leader */

/* Do intra node broadcast */
num_sgs = (node_size + sg_size - 1) >> log2_sg_size;
if (node_id == 0) {
num_data_blks = 1;
data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt;
blk_ofst[0] = bcount;
} else if (node_id == num_nodes - 1) {
if (last_node_size < 2) {
return err;
}
num_data_blks = 1;
data_blk_size[0] = bcount * (num_nodes - 1);
blk_ofst[0] = 0;
Expand Down
4 changes: 3 additions & 1 deletion ompi/mca/coll/acoll/coll_acoll_barrier.c
Expand Up @@ -125,7 +125,6 @@ static int mca_coll_acoll_barrier_send_subc(struct ompi_communicator_t *comm,
int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module)
{
int size, ssize, bsize;
int srank;
int err = MPI_SUCCESS;
int nreqs = 0;
ompi_request_t **reqs;
Expand All @@ -141,6 +140,9 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base

subc = &acoll_module->subc[cid];
size = ompi_comm_size(comm);
if (size == 1) {
return err;
}
if (!subc->initialized && size > 1) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
if (MPI_SUCCESS != err) {
Expand Down
4 changes: 2 additions & 2 deletions ompi/mca/coll/acoll/coll_acoll_bcast.c
Expand Up @@ -37,7 +37,7 @@ static int bcast_binomial(void *buff, int count, struct ompi_datatype_t *datatyp
struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs,
int world_rank)
{
int msb_pos, sub_rank, peer, err;
int msb_pos, sub_rank, peer, err = MPI_SUCCESS;
int size, rank, dim;
int i, mask;

Expand Down Expand Up @@ -83,7 +83,7 @@ static int bcast_flat_tree(void *buff, int count, struct ompi_datatype_t *dataty
int world_rank)
{
int peer;
int err;
int err = MPI_SUCCESS;
int rank = ompi_comm_rank(comm);
int size = ompi_comm_size(comm);

Expand Down
34 changes: 14 additions & 20 deletions ompi/mca/coll/acoll/coll_acoll_gather.c
Expand Up @@ -43,17 +43,16 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
int i, err, rank, size;
char *wkg = NULL, *workbuf = NULL;
MPI_Status status;
MPI_Aint incr, extent, lb;
MPI_Aint sextent, sgap = 0, ssize;
MPI_Aint rextent, rgap = 0, rsize;
MPI_Aint rextent;
int total_recv = 0;
int sg_cnt, node_cnt;
int cur_sg, root_sg;
int cur_node, root_node;
int is_base, is_local_root;
int startr, endr, inc;
int startn, endn, incn;
int num_nodes, node_id;
int startn, endn;
int num_nodes;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_reserve_mem_t *reserve_mem_gather = &(acoll_module->reserve_mem_s);

Expand All @@ -70,17 +69,13 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
num_nodes = 1;
}

ompi_datatype_get_extent(rdtype, &lb, &extent);
incr = extent * (ptrdiff_t) rcount;

/* Setup root for reveive */
/* Setup root for receive */
if (rank == root) {
ompi_datatype_type_extent(rdtype, &rextent);
rsize = opal_datatype_span(&rdtype->super, (int64_t) rcount * size, &rgap);
/* Just use the recv buffer */
wkg = (char *) rbuf;
if (sbuf != MPI_IN_PLACE) {
MPI_Aint root_ofst = extent * (ptrdiff_t) (rcount * root);
MPI_Aint root_ofst = rextent * (ptrdiff_t) (rcount * root);
err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, wkg + (ptrdiff_t) root_ofst,
rcount, rdtype);
if (MPI_SUCCESS != err) {
Expand All @@ -100,7 +95,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
is_local_root = (rank % node_cnt == 0) && (cur_node != root_node);
startn = (rank / node_cnt) * node_cnt;

if (is_base || (rank == root)) {
if (is_base) {
int64_t buf_size = is_local_root ? (int64_t) scount * node_cnt : (int64_t) scount * sg_cnt;
ompi_datatype_type_extent(sdtype, &sextent);
ssize = opal_datatype_span(&sdtype->super, buf_size, &sgap);
Expand All @@ -111,7 +106,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
return OMPI_ERR_OUT_OF_RESOURCE;
}
wkg = workbuf - sgap;
tmprecv = wkg + extent * (ptrdiff_t) (rcount * (rank - startr));
tmprecv = wkg + sextent * (ptrdiff_t) (rcount * (rank - startr));
/* local copy to workbuf */
err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, tmprecv, scount, sdtype);
if (MPI_SUCCESS != err) {
Expand All @@ -123,7 +118,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
rcount = scount;
rextent = sextent;
total_recv = rcount;
} else {
} else if (rank != root) {
wkg = (char *) sbuf;
total_recv = scount;
}
Expand All @@ -141,9 +136,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
continue;
}
if (rank == root) {
tmprecv = wkg + extent * (ptrdiff_t) (rcount * i);
tmprecv = wkg + rextent * (ptrdiff_t) (rcount * i);
} else {
tmprecv = wkg + extent * (ptrdiff_t) (rcount * (i - startr));
tmprecv = wkg + rextent * (ptrdiff_t) (rcount * (i - startr));
}
err = MCA_PML_CALL(
recv(tmprecv, rcount, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status));
Expand All @@ -161,10 +156,9 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
if (endn > size) {
endn = size;
}
incn = (rank == root) ? ((root != startn) ? 0 : sg_cnt) : sg_cnt;
if (sg_cnt < size) {
int local_root = (root_node == cur_node) ? root : startn;
for (i = startn + incn; i < endn; i += sg_cnt) {
for (i = startn; i < endn; i += sg_cnt) {
int i_sg = i / sg_cnt;
if ((rank != local_root) && (rank == i) && is_base) {
err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, local_root,
Expand All @@ -173,7 +167,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
}
if ((rank == local_root) && (rank != i) && (i_sg != root_sg)) {
int recv_amt = (i + sg_cnt > size) ? rcount * (size - i) : rcount * sg_cnt;
MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * (i - startn));
MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * (i - startn));

err = MCA_PML_CALL(recv(wkg + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i,
MCA_COLL_BASE_TAG_GATHER, comm, &status));
Expand All @@ -189,7 +183,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
}

/* All local roots ranks send to root */
if (node_cnt < size) {
if (node_cnt < size && num_nodes > 1) {
for (i = 0; i < size; i += node_cnt) {
int i_node = i / node_cnt;
if ((rank != root) && (rank == i) && is_base) {
Expand All @@ -199,7 +193,7 @@ int mca_coll_acoll_gather_intra(const void *sbuf, int scount, struct ompi_dataty
}
if ((rank == root) && (rank != i) && (i_node != root_node)) {
int recv_amt = (i + node_cnt > size) ? rcount * (size - i) : rcount * node_cnt;
MPI_Aint rcv_ofst = extent * (ptrdiff_t) (rcount * i);
MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * i);

err = MCA_PML_CALL(recv((char *) rbuf + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i,
MCA_COLL_BASE_TAG_GATHER, comm, &status));
Expand Down
9 changes: 9 additions & 0 deletions ompi/mca/coll/acoll/coll_acoll_module.c
Expand Up @@ -41,6 +41,15 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co
return NULL;
}

if (OMPI_COMM_IS_INTER(comm)) {
*priority = 0;
return NULL;
}
if (OMPI_COMM_IS_INTRA(comm) && ompi_comm_size(comm) < 2) {
*priority = 0;
return NULL;
}

*priority = mca_coll_acoll_priority;

/* Set topology params */
Expand Down
4 changes: 2 additions & 2 deletions ompi/mca/coll/acoll/coll_acoll_reduce.c
Expand Up @@ -382,11 +382,11 @@ int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, int count,
module);
} else {
return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op,
root, comm, module, 0, 0);
root, comm, module, 0, 0);
}
#else
return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root,
comm, module, 0, 0);
comm, module, 0, 0);
#endif
}
} else {
Expand Down
36 changes: 36 additions & 0 deletions ompi/mca/coll/acoll/coll_acoll_utils.h
Expand Up @@ -262,6 +262,9 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
mca_coll_base_module_allreduce_fn_t coll_allreduce_org = (comm)->c_coll->coll_allreduce;
mca_coll_base_module_allgather_fn_t coll_allgather_org = (comm)->c_coll->coll_allgather;
mca_coll_base_module_bcast_fn_t coll_bcast_org = (comm)->c_coll->coll_bcast;
mca_coll_base_module_allreduce_fn_t coll_allreduce_loc, coll_allreduce_soc;
mca_coll_base_module_allgather_fn_t coll_allgather_loc, coll_allgather_soc;
mca_coll_base_module_bcast_fn_t coll_bcast_loc, coll_bcast_soc;
coll_acoll_subcomms_t *subc;
int err;
int size = ompi_comm_size(comm);
Expand Down Expand Up @@ -362,6 +365,21 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
subc->base_root[MCA_COLL_ACOLL_L3CACHE][i] = -1;
subc->base_root[MCA_COLL_ACOLL_NUMA][i] = -1;
}
/* Store original collectives for local and socket comms */
coll_allreduce_loc = (subc->local_comm)->c_coll->coll_allreduce;
coll_allgather_loc = (subc->local_comm)->c_coll->coll_allgather;
coll_bcast_loc = (subc->local_comm)->c_coll->coll_bcast;
(subc->local_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring;
(subc->local_comm)->c_coll->coll_allreduce
= ompi_coll_base_allreduce_intra_recursivedoubling;
(subc->local_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear;
coll_allreduce_soc = (subc->socket_comm)->c_coll->coll_allreduce;
coll_allgather_soc = (subc->socket_comm)->c_coll->coll_allgather;
coll_bcast_soc = (subc->socket_comm)->c_coll->coll_bcast;
(subc->socket_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring;
(subc->socket_comm)->c_coll->coll_allreduce
= ompi_coll_base_allreduce_intra_recursivedoubling;
(subc->socket_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear;
}

/* Further subcommunicators based on root */
Expand Down Expand Up @@ -519,6 +537,14 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
}
}

/* Restore originals for local and socket comms */
(subc->local_comm)->c_coll->coll_allreduce = coll_allreduce_loc;
(subc->local_comm)->c_coll->coll_allgather = coll_allgather_loc;
(subc->local_comm)->c_coll->coll_bcast = coll_bcast_loc;
(subc->socket_comm)->c_coll->coll_allreduce = coll_allreduce_soc;
(subc->socket_comm)->c_coll->coll_allgather = coll_allgather_soc;
(subc->socket_comm)->c_coll->coll_bcast = coll_bcast_soc;

/* For collectives where order is important (like gather, allgather),
* split based on ranks. This is optimal for global communicators with
* equal split among nodes, but suboptimal for other cases.
Expand Down Expand Up @@ -590,6 +616,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
data = (coll_acoll_data_t *) malloc(sizeof(coll_acoll_data_t));
if (NULL == data) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
size = ompi_comm_size(comm);
Expand All @@ -601,6 +628,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
data->scratch = (char *) malloc(subc->xpmem_buf_size);
if (NULL == data->scratch) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
} else {
Expand All @@ -611,41 +639,49 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica
data->allseg_id = (xpmem_segid_t *) malloc(sizeof(xpmem_segid_t) * size);
if (NULL == data->allseg_id) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->all_apid = (xpmem_apid_t *) malloc(sizeof(xpmem_apid_t) * size);
if (NULL == data->all_apid) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->allshm_sbuf = (void **) malloc(sizeof(void *) * size);
if (NULL == data->allshm_sbuf) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->allshm_rbuf = (void **) malloc(sizeof(void *) * size);
if (NULL == data->allshm_rbuf) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->xpmem_saddr = (void **) malloc(sizeof(void *) * size);
if (NULL == data->xpmem_saddr) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->xpmem_raddr = (void **) malloc(sizeof(void *) * size);
if (NULL == data->xpmem_raddr) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
data->rcache = (mca_rcache_base_module_t **) malloc(sizeof(mca_rcache_base_module_t *) * size);
if (NULL == data->rcache) {
line = __LINE__;
ret = OMPI_ERR_OUT_OF_RESOURCE;
goto error_hndl;
}
seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666);
if (seg_id == -1) {
line = __LINE__;
ret = -1;
goto error_hndl;
}

Expand Down

0 comments on commit 550ac58

Please sign in to comment.