Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bcast Hierarchical Functions #6620

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ struct MPIR_Comm {
struct MPIR_Comm *node_comm; /* Comm of processes in this comm that are on
* the same node as this process. */
struct MPIR_Comm *node_roots_comm; /* Comm of root processes for other nodes. */

//int* node_weights_table; /* provides the weight of comm i in node_roots_comm */

int *intranode_table; /* intranode_table[i] gives the rank in
* node_comm of rank i in this comm or -1 if i
* is not in this process' node_comm.
Expand Down
9 changes: 9 additions & 0 deletions src/mpi/coll/bcast/bcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@

#include "mpiimpl.h"


/* Scatter helper used by the broadcast algorithms.
 * NOTE(review): the definition is not visible here; presumably it scatters
 * nbytes of data from root across comm_ptr, using tmp_buf as the staging
 * buffer when is_contig is 0 -- confirm against the implementation. */
int MPII_Scatter_for_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype,
int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, void *tmp_buf,
int is_contig, MPIR_Errflag_t errflag);

/* Variant of MPII_Scatter_for_bcast that additionally takes a rank list
 * (group, group_size).
 * NOTE(review): presumably restricts the scatter to the ranks listed in
 * group -- confirm against the implementation. */
int MPII_Scatter_for_bcast_group(void *buffer, MPI_Aint count, MPI_Datatype datatype,
int root, MPIR_Comm * comm_ptr, int* group, int group_size, MPI_Aint nbytes, void *tmp_buf,
int is_contig, MPIR_Errflag_t errflag);



/* Looks up rank and root in group[0 .. group_size-1] and reports their
 * positions through *group_rank and *group_root. Callers treat a false
 * return as "rank is not a member of the group".
 * NOTE(review): the state of the output parameters on a false return is not
 * visible from this header -- do not read them in that case. */
bool find_local_rank_linear(int* group, int group_size, int rank, int root, int* group_rank, int* group_root);

#endif /* BCAST_H_INCLUDED */
150 changes: 150 additions & 0 deletions src/mpi/coll/bcast/bcast_intra_adaptive.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright (C) by Argonne National Laboratory
* See COPYRIGHT in top-level directory
*/

#include "mpiimpl.h"
#include "bcast.h"

/*
 * MPIR_Bcast_intra_adaptive
 *
 * Intended design: build a priority queue of ranks (index 0 is the most
 * important rank to send to, index comm_size - 1 the least important) and
 * broadcast along a binomial tree whose sends are issued in reversed order
 * (smallest mask first). If tuned properly this is meant to perform at least
 * as well as the SMP algorithm.
 *
 * Current state: after building the queue, rank 0 prints it to stdout and
 * every rank jumps straight to fn_exit, so NO data is broadcast yet. All of
 * the code between that jump and the fn_exit label is unreachable.
 *
 * Parameters follow the usual bcast signature: buffer/count/datatype describe
 * the message, root is the originating rank in comm_ptr, errflag is forwarded
 * to MPIC_Send.
 *
 * Returns mpi_errno_ret: MPI_SUCCESS, or the error recorded on the fn_fail
 * path / by the MPIR_ERR_COLL_* macros.
 */

int MPIR_Bcast_intra_adaptive(void* buffer, MPI_Aint count, MPI_Datatype datatype, int root,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) {

int rank, comm_size, src, dst;
int relative_rank, mask;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
MPI_Aint nbytes = 0;
MPI_Status *status_p;
/* With error checking enabled a real status object is needed so the received
 * byte count can be validated; otherwise the status is ignored. */
#ifdef HAVE_ERROR_CHECKING
MPI_Status status;
status_p = &status;
MPI_Aint recvd_size;
#else
status_p = MPI_STATUS_IGNORE;
#endif
int is_contig;
MPI_Aint type_size;
void *tmp_buf = NULL;
/* Up to three tracked allocations: ranks, priority_queue and tmp_buf. */
MPIR_CHKLMEM_DECL(3);

comm_size = comm_ptr->local_size;
rank = comm_ptr->rank;


int break_point;
struct Rank_Info * ranks = NULL;
int* priority_queue = NULL;

/* One entry per rank of the communicator in both arrays. */
MPIR_CHKLMEM_MALLOC(ranks, struct Rank_Info * , sizeof(struct Rank_Info) * comm_size, mpi_errno, "ranks", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(priority_queue, int*, sizeof(int) * comm_size, mpi_errno, "priority_queue", MPL_MEM_BUFFER);
/* NOTE(review): retrieve_weights, build_queue and struct Rank_Info are not
 * declared in the headers visible in this change; presumably they fill
 * ranks[] with per-rank weights and priority_queue[] with ranks ordered by
 * priority -- confirm. Any status they return is not checked here. */
retrieve_weights(comm_ptr, ranks);
build_queue(comm_ptr, ranks, priority_queue);


/* Debug output: rank 0 dumps the queue to stdout, one rank per line. */
if (!rank) {
for (int i = 0; i < comm_size; i++) {
printf("%d\n", priority_queue[i]);
}
}
/* NOTE(review): unconditional early exit -- everything from here down to the
 * fn_exit label is currently dead code and the function broadcasts nothing.
 * This and the printf above look like debugging scaffolding that must be
 * removed before the algorithm can be used. */
goto fn_exit;

if (HANDLE_IS_BUILTIN(datatype))
is_contig = 1;
else {
MPIR_Datatype_is_contig(datatype, &is_contig);
}

MPIR_Datatype_get_size_macro(datatype, type_size);

nbytes = type_size * count;
if (nbytes == 0)
goto fn_exit; /* nothing to do */

/* Non-contiguous data is packed into tmp_buf and transferred as MPI_BYTE. */
if (!is_contig) {
MPIR_CHKLMEM_MALLOC(tmp_buf, void *, nbytes, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);

/* TODO: Pipeline the packing and communication */
if (rank == root) {
mpi_errno = MPIR_Localcopy(buffer, count, datatype, tmp_buf, nbytes, MPI_BYTE);
MPIR_ERR_CHECK(mpi_errno);
}
}

/* Position of this rank when the ranks are rotated so that root is 0. */
relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;

/* Receive phase: the lowest set bit of relative_rank selects the parent. */
mask = 0x1;
while (mask < comm_size) {
if (relative_rank & mask) {

/* Because the mask bit is set in relative_rank, the difference is not
 * negative for a non-negative relative_rank; the wrap-around below is
 * defensive. */
src = relative_rank - mask;
if (src < 0)
src += comm_size;
/* Translate the tree position into an actual rank via the queue. */
src = priority_queue[src];

/* NOTE(review): this leaves the loop WITHOUT posting a receive when the
 * parent turns out to be root -- confirm this is intended, since the
 * send loop below still targets this rank from the parent's side. */
if (root == src) break;

if (!is_contig)
mpi_errno = MPIC_Recv(tmp_buf, nbytes, MPI_BYTE, src,
MPIR_BCAST_TAG, comm_ptr, status_p);
else
mpi_errno = MPIC_Recv(buffer, count, datatype, src,
MPIR_BCAST_TAG, comm_ptr, status_p);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
#ifdef HAVE_ERROR_CHECKING
/* check that we received as much as we expected */
MPIR_Get_count_impl(status_p, MPI_BYTE, &recvd_size);
MPIR_ERR_COLL_CHECK_SIZE(recvd_size, nbytes, errflag, mpi_errno_ret);
#endif
break;
}
mask <<= 1;
}

/* Send phase. break_point is the largest mask this rank sends with; unlike
 * the classic binomial bcast the masks are walked upwards from 1 ("sends
 * reversed"). */
mask >>= 1;
break_point = mask;
mask = 0x1;

while (mask <= break_point) {
if (relative_rank + mask < comm_size) {

/* The guard above already ensures dst < comm_size, so the wrap-around
 * below never triggers. */
dst = relative_rank + mask;
if (dst >= comm_size)
dst -= comm_size;
/* Translate the tree position into an actual rank via the queue. */
dst = priority_queue[dst];

if (root == dst) goto end_if; /* No need to send to root */

if (!is_contig)
mpi_errno = MPIC_Send(tmp_buf, nbytes, MPI_BYTE, dst,
MPIR_BCAST_TAG, comm_ptr, errflag);
else
mpi_errno = MPIC_Send(buffer, count, datatype, dst,
MPIR_BCAST_TAG, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
end_if:
mask <<= 1;
}

/* Unpack the staged bytes into the user buffer on non-root ranks. */
if (!is_contig) {
if (rank != root) {
mpi_errno = MPIR_Localcopy(tmp_buf, nbytes, MPI_BYTE, buffer, count, datatype);
MPIR_ERR_CHECK(mpi_errno);

}
}

fn_exit:
/* Release everything allocated through MPIR_CHKLMEM_MALLOC. */
MPIR_CHKLMEM_FREEALL();
return mpi_errno_ret;
fn_fail:
/* Reached from MPIR_ERR_CHECK / allocation failure: report that error. */
mpi_errno_ret = mpi_errno;
goto fn_exit;

}

126 changes: 126 additions & 0 deletions src/mpi/coll/bcast/bcast_intra_binomial_group.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright (C) by Argonne National Laboratory
* See COPYRIGHT in top-level directory
*/

#include "mpiimpl.h"
#include "bcast.h"

/*
 * MPIR_Bcast_intra_binomial_group - binomial-tree broadcast restricted to a
 * subset of the ranks of comm_ptr.
 *
 * buffer, count, datatype  message description (send side on root, receive
 *                          side elsewhere)
 * root                     rank in comm_ptr that owns the data
 * comm_ptr                 communicator used for all point-to-point traffic
 * group                    group[i] is the comm_ptr rank of group member i
 * group_size               number of entries in group
 * errflag                  forwarded to MPIC_Send
 *
 * A rank that find_local_rank_linear() does not locate in group returns
 * MPI_SUCCESS immediately without communicating. Members run the classic
 * binomial tree over their group indices: receive once from the parent, then
 * forward to children with decreasing masks. Non-contiguous datatypes are
 * packed into a temporary byte buffer on root and unpacked on the receivers.
 *
 * Returns MPI_SUCCESS or the error recorded during the operation.
 */
int MPIR_Bcast_intra_binomial_group(void *buffer,
                                    MPI_Aint count,
                                    MPI_Datatype datatype,
                                    int root, MPIR_Comm * comm_ptr, int *group, int group_size,
                                    MPIR_Errflag_t errflag)
{
    int rank, src, dst;
    int relative_rank, mask;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    MPI_Aint nbytes = 0;
    /* status_p must be declared unconditionally: it is assigned in both
     * branches below and passed to every MPIC_Recv. Declaring it only under
     * HAVE_ERROR_CHECKING breaks builds that disable error checking. */
    MPI_Status *status_p;
#ifdef HAVE_ERROR_CHECKING
    MPI_Status status;
    status_p = &status;
    MPI_Aint recvd_size;
#else
    status_p = MPI_STATUS_IGNORE;
#endif
    int is_contig;
    MPI_Aint type_size;
    void *tmp_buf = NULL;
    MPIR_CHKLMEM_DECL(1);

    rank = comm_ptr->rank;
    int group_rank = -1;        /* index of this process within group */
    int group_root = -1;        /* index of root within group */

    /* Processes outside the group do not take part in the broadcast. */
    if (!find_local_rank_linear(group, group_size, rank, root, &group_rank, &group_root))
        goto fn_exit;

    if (HANDLE_IS_BUILTIN(datatype))
        is_contig = 1;
    else {
        MPIR_Datatype_is_contig(datatype, &is_contig);
    }

    MPIR_Datatype_get_size_macro(datatype, type_size);

    nbytes = type_size * count;
    if (nbytes == 0)
        goto fn_exit;   /* nothing to do */

    /* Non-contiguous data travels as packed MPI_BYTE through tmp_buf. */
    if (!is_contig) {
        MPIR_CHKLMEM_MALLOC(tmp_buf, void *, nbytes, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);
        if (rank == root) {
            mpi_errno = MPIR_Localcopy(buffer, count, datatype, tmp_buf, nbytes, MPI_BYTE);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

    /* Group index rotated so that the group root sits at position 0. */
    relative_rank = (group_rank >= group_root) ?
        group_rank - group_root : group_rank - group_root + group_size;

    /* Receive phase: the lowest set bit of relative_rank identifies the
     * parent; the group root (relative_rank 0) never receives. */
    mask = 0x1;
    while (mask < group_size) {
        if (relative_rank & mask) {
            src = group_rank - mask;
            if (src < 0)
                src += group_size;

            /* group[] maps the group index back to a rank in comm_ptr. */
            if (!is_contig)
                mpi_errno = MPIC_Recv(tmp_buf, nbytes, MPI_BYTE, group[src],
                                      MPIR_BCAST_TAG, comm_ptr, status_p);
            else
                mpi_errno = MPIC_Recv(buffer, count, datatype, group[src],
                                      MPIR_BCAST_TAG, comm_ptr, status_p);
            MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
#ifdef HAVE_ERROR_CHECKING
            /* check that we received as much as we expected */
            MPIR_Get_count_impl(status_p, MPI_BYTE, &recvd_size);
            MPIR_ERR_COLL_CHECK_SIZE(recvd_size, nbytes, errflag, mpi_errno_ret);
#endif
            break;
        }
        mask <<= 1;
    }

    /* Send phase: forward to children, largest subtree first. */
    mask >>= 1;
    while (mask > 0) {
        if (relative_rank + mask < group_size) {
            dst = group_rank + mask;
            if (dst >= group_size)
                dst -= group_size;

            if (!is_contig)
                mpi_errno = MPIC_Send(tmp_buf, nbytes, MPI_BYTE, group[dst],
                                      MPIR_BCAST_TAG, comm_ptr, errflag);
            else
                mpi_errno = MPIC_Send(buffer, count, datatype, group[dst],
                                      MPIR_BCAST_TAG, comm_ptr, errflag);
            MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
        }
        mask >>= 1;
    }

    /* Unpack the staged bytes into the user buffer on non-root members. */
    if (!is_contig) {
        if (rank != root) {
            mpi_errno = MPIR_Localcopy(tmp_buf, nbytes, MPI_BYTE, buffer, count, datatype);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    return mpi_errno_ret;
  fn_fail:
    mpi_errno_ret = mpi_errno;
    goto fn_exit;
}