mca/coll: Add any radix k for alltoall bruck algorithm #12453

Open

wants to merge 1 commit into base: main
163 changes: 124 additions & 39 deletions ompi/mca/coll/base/coll_base_alltoall.c
@@ -235,15 +235,90 @@ int ompi_coll_base_alltoall_intra_pairwise(const void *sbuf, size_t scount,
return err;
}


int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, size_t scount,
struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
/*
*
* Function: ompi_coll_base_alltoall_intra_k_bruck, using O(log_k(N)) steps
* Accepts: Same arguments as MPI_Alltoall
* Returns: MPI_SUCCESS or error code
*
* Description: This method extends ompi_coll_base_alltoall_intra_bruck to handle any
* radix k (k >= 2).
*
* Example on 6 ranks with k = 4
* # 0 1 2 3 4 5
* [00] [10] [20] [30] [40] [50]
* [01] [11] [21] [31] [41] [51]
* [02] [12] [22] [32] [42] [52]
* [03] [13] [23] [33] [43] [53]
* [04] [14] [24] [34] [44] [54]
* [05] [15] [25] [35] [45] [55]
* After local rotation
Member comment:
Why do you really need the local rotation? I understand it makes the code easier to maintain, but it has a significant and ultimately unnecessary cost, because in the end you are building the datatype by hand without taking advantage of its contiguity in memory.

The second remark is related to the cost of creating and committing the datatype. I'm almost certain that this cost is expensive, especially for the mid-range messages where the k-ary Bruck algorithm is supposed to behave best. The result is that you pay a high price to prepare the datatype, and you end up with a non-contiguous datatype, which leads to lower-performance communications (because non-contiguous data usually leads to copy-in/copy-out protocols). If, instead of building the datatype, you copy the data into a contiguous buffer, you avoid the cost of the datatype construction and communicate from contiguous to contiguous memory, with a better outcome. The only potential drawback is the extra local copy before the send (similar to a pack).
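The packing alternative suggested above can be sketched with plain memcpy, independently of MPI. This is a hypothetical standalone sketch, not part of the patch: `pack_phase_blocks` and its parameters are illustrative names that mirror the `j`/`nblocks` loop in the diff.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Sketch of the reviewer's suggestion: instead of building a non-contiguous
 * datatype over the blocks at offsets j, j + radix*distance, ..., memcpy
 * those blocks into a contiguous staging buffer before the send.
 * block_bytes, size, radix, distance, and i mirror the variables in the
 * patch but are plain parameters here. Returns the number of bytes packed. */
static size_t pack_phase_blocks(const char *tmpbuf, char *packbuf,
                                size_t block_bytes, int size,
                                int radix, int distance, int i)
{
    size_t packed = 0;
    for (int j = i * distance; j < size; j += radix * distance) {
        int nblocks = distance;
        if (j + distance >= size) {
            nblocks = size - j;          /* last group may be shorter */
        }
        memcpy(packbuf + packed, tmpbuf + (size_t)j * block_bytes,
               (size_t)nblocks * block_bytes);
        packed += (size_t)nblocks * block_bytes;
    }
    return packed;
}
```

On the 6-rank, k = 4 example from the commit message, phase 0 with i = 1 packs the blocks at offsets 1 and 5, exactly the rows that move in the diagram.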

* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [01] [12] [23] [34] [45] [50]
* [02] [13] [24] [35] [40] [51]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [05] [10] [21] [32] [43] [54]
* Phase 0: send message to (rank + k^0 * i), receive message from (rank - k^0 * i)
* send the data blocks whose least significant digit is i in base-k representation,
* for i in [1, k-1]
* i = 1: send the data blocks at offsets i and i + k to (rank + 1)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [02] [13] [24] [35] [40] [51]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* i = 2: send the data block at offset i to (rank + 2)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [03] [14] [25] [30] [41] [52]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* i = 3: send the data block at offset i to (rank + 3)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [30] [41] [52] [03] [14] [25]
* [04] [15] [20] [31] [42] [53]
* [54] [05] [10] [21] [32] [43]
* Phase 1: send message to (rank + k^1 * i), receive message from (rank - k^1 * i)
* send the data blocks whose second least significant digit is i in base-k representation,
* for i in [1, k-1]
* i = 1: send the data block at offset i*k with size min(k, size - i*k) = 2 to (rank + 4)
* # 0 1 2 3 4 5
* [00] [11] [22] [33] [44] [55]
* [50] [01] [12] [23] [34] [45]
* [40] [51] [02] [13] [24] [35]
* [30] [41] [52] [03] [14] [25]
* [20] [31] [42] [53] [04] [15]
* [10] [21] [32] [43] [54] [05]
* i = 2: nothing is to be sent
* i = 3: nothing is to be sent
* After local inverse rotation
* # 0 1 2 3 4 5
* [00] [01] [02] [03] [04] [05]
* [10] [11] [12] [13] [14] [15]
* [20] [21] [22] [23] [24] [25]
* [30] [31] [32] [33] [34] [35]
* [40] [41] [42] [43] [44] [45]
* [50] [51] [52] [53] [54] [55]
*
*/
int ompi_coll_base_alltoall_intra_k_bruck(const void *sbuf, size_t scount,
struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module,
int radix)
{
int i, line = -1, rank, size, err = 0;
int i, j, line = -1, rank, size, err = 0;
int sendto, recvfrom, distance, *displs = NULL;
char *tmpbuf = NULL, *tmpbuf_free = NULL;
ptrdiff_t sext, rext, span, gap = 0;
@@ -257,8 +332,12 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, size_t scount,
size = ompi_comm_size(comm);
rank = ompi_comm_rank(comm);

if (radix < 2) {
line = __LINE__; err = -1; goto err_hndl;
}

OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
"coll:base:alltoall_intra_bruck rank %d", rank));
"coll:base:alltoall_intra_k_bruck radix %d rank %d", radix, rank));

err = ompi_datatype_type_extent (sdtype, &sext);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
@@ -297,42 +376,48 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, size_t scount,
}

/* perform communication step */
for (distance = 1; distance < size; distance<<=1) {

sendto = (rank + distance) % size;
recvfrom = (rank - distance + size) % size;
for (distance = 1; distance < size; distance *= radix) {
for (i = 1; i < radix; i++) {

new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used));

/* Create datatype describing data sent/received */
for (i = distance; i < size; i += 2*distance) {
int nblocks = distance;
if (i + distance >= size) {
nblocks = size - i;
if (distance * i >= size) {
break;
}
ompi_datatype_add(new_ddt, rdtype, rcount * nblocks,
i * rcount * rext, rext);
}

/* Commit the new datatype */
err = ompi_datatype_commit(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
sendto = (rank + distance * i) % size;
recvfrom = (rank - distance * i + size) % size;

/* Sendreceive */
err = ompi_coll_base_sendrecv ( tmpbuf, 1, new_ddt, sendto,
MCA_COLL_BASE_TAG_ALLTOALL,
rbuf, 1, new_ddt, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALL,
comm, MPI_STATUS_IGNORE, rank );
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used));

/* Copy back new data from recvbuf to tmpbuf */
err = ompi_datatype_copy_content_same_ddt(new_ddt, 1,tmpbuf, (char *) rbuf);
if (err < 0) { line = __LINE__; err = -1; goto err_hndl; }
/* Create datatype describing data sent/received */
for (j = i * distance; j < size; j += radix * distance) {
int nblocks = distance;
if (j + distance >= size) {
nblocks = size - j;
}
ompi_datatype_add(new_ddt, rdtype, rcount * nblocks,
j * rcount * rext, rext);
}

/* free ddt */
err = ompi_datatype_destroy(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
/* Commit the new datatype */
err = ompi_datatype_commit(&new_ddt);
Member comment:
I think the real reason you are not seeing benefits from using nonblocking communications is the cost of the datatype creation. During this time, the posted communications are not able to progress, which means the bandwidth is wasted until all the datatypes are built and you reach the waitall part.
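As a plain-arithmetic illustration of this remark (a standalone sketch, not part of the patch): the patch creates and commits one datatype per (phase, i) exchange, so there are up to (k - 1) commits in each of the ceil(log_k(size)) phases, and that serialized construction cost is what the comment describes.

```c
#include <assert.h>

/* Count how many ompi_datatype_create/commit pairs the patch performs:
 * one per (phase, i) exchange, with the same early exit as the patch. */
static int count_datatype_commits(int size, int radix)
{
    int commits = 0;
    for (int distance = 1; distance < size; distance *= radix) {
        for (int i = 1; i < radix; i++) {
            if (distance * i >= size) {
                break;                /* mirrors the early exit in the patch */
            }
            commits++;                /* one datatype built and committed */
        }
    }
    return commits;
}
```

For the 6-rank, k = 4 example in the commit message this gives 4 commits (three in phase 0, one in phase 1), versus 3 for the classic radix-2 Bruck on 8 ranks.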

if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

/* Sendreceive */
err = ompi_coll_base_sendrecv ( tmpbuf, 1, new_ddt, sendto,
MCA_COLL_BASE_TAG_ALLTOALL,
rbuf, 1, new_ddt, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALL,
comm, MPI_STATUS_IGNORE, rank );
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

/* Copy back new data from recvbuf to tmpbuf */
err = ompi_datatype_copy_content_same_ddt(new_ddt, 1, tmpbuf, (char *) rbuf);
if (err < 0) { line = __LINE__; err = -1; goto err_hndl; }

/* free ddt */
err = ompi_datatype_destroy(&new_ddt);
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
}
} /* end of for (distance = 1... */

/* Step 3 - local rotation - */
2 changes: 1 addition & 1 deletion ompi/mca/coll/base/coll_base_functions.h
@@ -215,7 +215,7 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);

/* AlltoAll */
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_bruck(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_k_bruck(ALLTOALL_ARGS, int radix);
int ompi_coll_base_alltoall_intra_basic_linear(ALLTOALL_ARGS);
int ompi_coll_base_alltoall_intra_linear_sync(ALLTOALL_ARGS, int max_requests);
int ompi_coll_base_alltoall_intra_two_procs(ALLTOALL_ARGS);
2 changes: 1 addition & 1 deletion ompi/mca/coll/tuned/coll_tuned_alltoall_decision.c
@@ -173,7 +173,7 @@ int ompi_coll_tuned_alltoall_intra_do_this(const void *sbuf, size_t scount,
case (2):
return ompi_coll_base_alltoall_intra_pairwise(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module);
case (3):
return ompi_coll_base_alltoall_intra_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module);
return ompi_coll_base_alltoall_intra_k_bruck(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, faninout);
case (4):
return ompi_coll_base_alltoall_intra_linear_sync(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, module, max_requests);
case (5):
3 changes: 2 additions & 1 deletion ompi/mca/coll/tuned/coll_tuned_decision_fixed.c
@@ -404,10 +404,11 @@ int ompi_coll_tuned_alltoall_intra_dec_fixed(const void *sbuf, size_t scount,
}
}

int faninout = 2;
return ompi_coll_tuned_alltoall_intra_do_this (sbuf, scount, sdtype,
rbuf, rcount, rdtype,
comm, module,
alg, 0, 0, ompi_coll_tuned_alltoall_max_requests);
alg, faninout, 0, ompi_coll_tuned_alltoall_max_requests);
}

/*