Skip to content

Commit

Permalink
mca/coll: Add any radix k for alltoall bruck algorithm
Browse files Browse the repository at this point in the history
This method extends ompi_coll_base_alltoall_intra_bruck to handle
any radix k.

Signed-off-by: Jessie Yang <jiaxiyan@amazon.com>
  • Loading branch information
jiaxiyan committed May 7, 2024
1 parent e44cd58 commit 87e6eb8
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 42 deletions.
187 changes: 148 additions & 39 deletions ompi/mca/coll/base/coll_base_alltoall.c
Expand Up @@ -235,19 +235,96 @@ int ompi_coll_base_alltoall_intra_pairwise(const void *sbuf, int scount,
return err;
}


int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
struct ompi_datatype_t *sdtype,
void* rbuf, int rcount,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
/*
*
* Function: ompi_coll_base_alltoall_intra_k_bruck using O(logk(N)) steps
* Accepts: Same arguments as MPI_Alltoall
* Returns: MPI_SUCCESS or error code
*
* Description: This method extends ompi_coll_base_alltoall_intra_bruck to handle any
* radix k(k >= 2).
*
* Example on 6 ranks with k = 4
* # 0 1 2 3 4 5
* [00] [10] [20] [30] [40] [50]
* [01] [11] [21] [31] [41] [51]
* [02] [12] [22] [32] [42] [52]
* [03] [13] [23] [33] [43] [53]
* [04] [14] [24] [34] [44] [54]
* [05] [15] [25] [35] [45] [55]
* After local rotation
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [01] [12] [23] [34] [45] [50]
* [02] [13] [24] [35] [40] [51]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [05] [10] [21] [32] [43] [54]
* Phase 0: send message to (rank + k^0 * i), receive message from (rank - k^0 * i)
* send the data block whose least significant bit is i in base k representation
* for i between [1, k-1]
* i = 1: send the data block at offset i, i + k to (rank + 1)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [02] [13] [24] [35] [40] [51]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* i = 2: send the data block at offset i to (rank + 2)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* i = 3: send the data block at offset i to (rank + 3)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [30] [41] [52] [03] [14] [25]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* Phase 1: send message to (rank + k^1 * i), receive message from (rank - k^1 * i)
* send the data block whose second bit is i in base k representation
* for i between [1, k-1]
* i = 1: send the data block at offset k with size of min(k, size-i*k)=2 to (rank + 4)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [30] [41] [52] [03] [14] [25]
* [20] [31] [42] [53] [04] [15]
* [10] [21] [32] [43] [54] [05]
* i = 2: nothing is to be sent
* i = 3: nothing is to be sent
* After local inverse rotation
* # 0 1 2 3 4 5
* [00] [01] [02] [03] [04] [05]
* [10] [11] [12] [13] [14] [15]
* [20] [21] [22] [23] [24] [25]
* [30] [31] [32] [33] [34] [35]
* [40] [41] [42] [43] [44] [45]
* [50] [51] [52] [53] [54] [55]
*
*/
int ompi_coll_base_alltoall_intra_k_bruck(const void *sbuf, int scount,
struct ompi_datatype_t *sdtype,
void* rbuf, int rcount,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module,
int radix)
{
int i, line = -1, rank, size, err = 0;
int i, j, line = -1, rank, size, err = 0;
int sendto, recvfrom, distance, *displs = NULL;
char *tmpbuf = NULL, *tmpbuf_free = NULL;
ptrdiff_t sext, rext, span, gap = 0;
struct ompi_datatype_t *new_ddt;
ompi_request_t **reqs;
int num_reqs, max_reqs = 0;

if (MPI_IN_PLACE == sbuf) {
return mca_coll_base_alltoall_intra_basic_inplace (rbuf, rcount, rdtype,
Expand All @@ -257,8 +334,12 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
size = ompi_comm_size(comm);
rank = ompi_comm_rank(comm);

if (radix < 2) {
line = __LINE__; err = -1; goto err_hndl;
}

OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
"coll:base:alltoall_intra_bruck rank %d", rank));
"coll:base:alltoall_intra_k_bruck radix %d rank %d", radix, rank));

err = ompi_datatype_type_extent (sdtype, &sext);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
Expand Down Expand Up @@ -297,42 +378,57 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
}

/* perform communication step */
for (distance = 1; distance < size; distance<<=1) {

sendto = (rank + distance) % size;
recvfrom = (rank - distance + size) % size;

new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used));
max_reqs = 2 * (radix - 1);
reqs = ompi_coll_base_comm_get_reqs(module->base_data, max_reqs);
for (distance = 1; distance < size; distance *= radix) {
num_reqs = 0;
for (i = 1; i < radix; i++) {

/* Create datatype describing data sent/received */
for (i = distance; i < size; i += 2*distance) {
int nblocks = distance;
if (i + distance >= size) {
nblocks = size - i;
if (distance * i >= size) {
break;
}
ompi_datatype_add(new_ddt, rdtype, rcount * nblocks,
i * rcount * rext, rext);
}

/* Commit the new datatype */
err = ompi_datatype_commit(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
sendto = (rank + distance * i) % size;
recvfrom = (rank - distance * i + size) % size;

/* Sendreceive */
err = ompi_coll_base_sendrecv ( tmpbuf, 1, new_ddt, sendto,
MCA_COLL_BASE_TAG_ALLTOALL,
rbuf, 1, new_ddt, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALL,
comm, MPI_STATUS_IGNORE, rank );
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used));

/* Copy back new data from recvbuf to tmpbuf */
err = ompi_datatype_copy_content_same_ddt(new_ddt, 1,tmpbuf, (char *) rbuf);
if (err < 0) { line = __LINE__; err = -1; goto err_hndl; }
/* Create datatype describing data sent/received */
for (j = i * distance; j < size; j += radix * distance) {
int nblocks = distance;
if (j + distance >= size) {
nblocks = size - j;
}
ompi_datatype_add(new_ddt, rdtype, rcount * nblocks,
j * rcount * rext, rext);
}

/* free ddt */
err = ompi_datatype_destroy(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
/* Commit the new datatype */
err = ompi_datatype_commit(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

err = MCA_PML_CALL(irecv(rbuf, 1, new_ddt, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALL,
comm,
&reqs[num_reqs++]));
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
err = MCA_PML_CALL(isend(tmpbuf, 1, new_ddt, sendto,
MCA_COLL_BASE_TAG_ALLTOALL,
MCA_PML_BASE_SEND_STANDARD,
comm,
&reqs[num_reqs++]));
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

/* Copy back new data from recvbuf to tmpbuf */
err = ompi_datatype_copy_content_same_ddt(new_ddt, 1, tmpbuf, (char *) rbuf);
if (err < 0) { line = __LINE__; err = -1; goto err_hndl; }

/* free ddt */
err = ompi_datatype_destroy(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
}
err = ompi_request_wait_all(num_reqs, reqs, MPI_STATUSES_IGNORE);
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
} /* end of for (distance = 1... */

/* Step 3 - local rotation - */
Expand All @@ -349,6 +445,19 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
return OMPI_SUCCESS;

err_hndl:
if( NULL != reqs ) {
if (MPI_ERR_IN_STATUS == err) {
for( num_reqs = 0; num_reqs < max_reqs; num_reqs++ ) {
if (MPI_REQUEST_NULL == reqs[num_reqs]) continue;
if (MPI_ERR_PENDING == reqs[num_reqs]->req_status.MPI_ERROR) continue;
if (reqs[num_reqs]->req_status.MPI_ERROR != MPI_SUCCESS) {
err = reqs[num_reqs]->req_status.MPI_ERROR;
break;
}
}
}
ompi_coll_base_free_reqs(reqs, max_reqs);
}
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
"%s:%4d\tError occurred %d, rank %2d", __FILE__, line, err,
rank));
Expand Down
2 changes: 1 addition & 1 deletion ompi/mca/coll/base/coll_base_functions.h
Expand Up @@ -215,7 +215,7 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);

/* AlltoAll */
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_bruck(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_k_bruck(ALLTOALL_ARGS, int radix);
int ompi_coll_base_alltoall_intra_basic_linear(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_linear_sync(ALLTOALL_ARGS, int max_requests);
int ompi_coll_base_alltoall_intra_two_procs(ALLTOALL_ARGS);
Expand Down
2 changes: 1 addition & 1 deletion ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c
Expand Up @@ -173,7 +173,7 @@ int ompi_coll_tuned_alltoall_intra_do_this(const void *sbuf, int scount,
case (2):
return ompi_coll_base_alltoall_intra_pairwise(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module);
case (3):
return ompi_coll_base_alltoall_intra_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module);
return ompi_coll_base_alltoall_intra_k_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, faninout);
case (4):
return ompi_coll_base_alltoall_intra_linear_sync(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, max_requests);
case (5):
Expand Down
3 changes: 2 additions & 1 deletion ompi/mca/coll/tuned/coll_tuned_decision_fixed.c
Expand Up @@ -404,10 +404,11 @@ int ompi_coll_tuned_alltoall_intra_dec_fixed(const void *sbuf, int scount,
}
}

int faninout = 2;
return ompi_coll_tuned_alltoall_intra_do_this (sbuf, scount, sdtype,
rbuf, rcount, rdtype,
comm, module,
alg, 0, 0, ompi_coll_tuned_alltoall_max_requests);
alg, faninout, 0, ompi_coll_tuned_alltoall_max_requests);
}

/*
Expand Down

0 comments on commit 87e6eb8

Please sign in to comment.