Skip to content

Commit

Permalink
Fix collective initialization and finalization
Browse files Browse the repository at this point in the history
Instead of allowing each collective module to present a list of
functions it provide, let them register the functions they provide and
save the context of the previous collective if they choose to.

There are two major benefits to this approach:
- tighter memory management in the collective module themselves. Each
  collective enable and disable is called exactly once per communicator,
  to chain or unchain themselves from the collective function pointers
  struct. The disable is called in the reverse order of the enable,
  allowing for proper chaining of collectives.
- they only install the functions they want. So instead of checking in
  the coll_select all the functions for all modules, each module can now
  selectively iterate over only the functions it provides.

What is still broken is the ability of a particular collective module to
unchain itself in the middle of the execution. Instead, a properly
implemented module will have an enable/disable flag, and it should act
as a passthrough if it chooses to desactivate.

Signed-off-by: George Bosilca <bosilca@icl.utk.edu>
  • Loading branch information
bosilca committed Mar 22, 2024
1 parent 97eac28 commit c03b06c
Show file tree
Hide file tree
Showing 62 changed files with 1,710 additions and 1,416 deletions.
1 change: 0 additions & 1 deletion ompi/mca/coll/accelerator/Makefile.am
Expand Up @@ -10,7 +10,6 @@
#
# $HEADER$
#
dist_ompidata_DATA = help-mpi-coll-accelerator.txt

sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
Expand Down
3 changes: 0 additions & 3 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Expand Up @@ -38,9 +38,6 @@ mca_coll_base_module_t
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
int *priority);

int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

int
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
Expand Down
125 changes: 66 additions & 59 deletions ompi/mca/coll/accelerator/coll_accelerator_module.c
Expand Up @@ -32,30 +32,21 @@
#include "ompi/mca/coll/base/base.h"
#include "coll_accelerator.h"

static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
{
memset(&(module->c_coll), 0, sizeof(module->c_coll));
}

static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
{
OBJ_RELEASE(module->c_coll.coll_allreduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
OBJ_RELEASE(module->c_coll.coll_scatter_module);
/* If the exscan module is not NULL, then this was an
intracommunicator, and therefore scan will have a module as
well. */
if (NULL != module->c_coll.coll_exscan_module) {
OBJ_RELEASE(module->c_coll.coll_exscan_module);
OBJ_RELEASE(module->c_coll.coll_scan_module);
}
}

OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
mca_coll_accelerator_module_construct,
mca_coll_accelerator_module_destruct);
NULL);


/*
Expand Down Expand Up @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,

/* Choose whether to use [intra|inter] */
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;

accelerator_module->super.coll_allgather = NULL;
accelerator_module->super.coll_allgatherv = NULL;
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
accelerator_module->super.coll_alltoall = NULL;
accelerator_module->super.coll_alltoallv = NULL;
accelerator_module->super.coll_alltoallw = NULL;
accelerator_module->super.coll_barrier = NULL;
accelerator_module->super.coll_bcast = NULL;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
accelerator_module->super.coll_gather = NULL;
accelerator_module->super.coll_gatherv = NULL;
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
accelerator_module->super.coll_reduce_scatter = NULL;
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_scatter = NULL;
accelerator_module->super.coll_scatterv = NULL;
if (!OMPI_COMM_IS_INTER(comm)) {
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
}

return &(accelerator_module->super);
}


#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
do \
{ \
if ((__comm)->c_coll->coll_##__api) \
{ \
MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \
} \
else \
{ \
opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
"cuda", #__api, ompi_process_info.nodename, \
mca_coll_accelerator_component.priority); \
} \
} while (0)

#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
do \
{ \
if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \
MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
(__module)->c_coll.coll_##__api##_module = NULL; \
(__module)->c_coll.coll_##__api = NULL; \
} \
} while (0)

/*
* Init module on the communicator
* Init/Fini module on the communicator
*/
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
bool good = true;
char *msg = NULL;
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

#define CHECK_AND_RETAIN(src, dst, name) \
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
good = false; \
msg = #name; \
} else if (good) { \
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
}

CHECK_AND_RETAIN(comm, s, allreduce);
CHECK_AND_RETAIN(comm, s, reduce);
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
CHECK_AND_RETAIN(comm, s, scatter);
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
CHECK_AND_RETAIN(comm, s, exscan);
CHECK_AND_RETAIN(comm, s, scan);
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
}

/* All done */
if (good) {
return OMPI_SUCCESS;
}
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
ompi_process_info.nodename,
mca_coll_accelerator_component.priority, msg);
return OMPI_ERR_NOT_FOUND;
return OMPI_SUCCESS;
}

static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm))
{
/* MPI does not define scan/exscan on intercommunicators */
ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan);
}

return OMPI_SUCCESS;
}
29 changes: 0 additions & 29 deletions ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

This file was deleted.

86 changes: 51 additions & 35 deletions ompi/mca/coll/adapt/coll_adapt_module.c
Expand Up @@ -83,25 +83,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
adapt_module_construct,
adapt_module_destruct);

/*
* In this macro, the following variables are supposed to have been declared
* in the caller:
* . ompi_communicator_t *comm
* . mca_coll_adapt_module_t *adapt_module
*/
#define ADAPT_SAVE_PREV_COLL_API(__api) \
do { \
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
"(%s/%s): no underlying " # __api"; disqualifying myself", \
ompi_comm_print_cid(comm), comm->c_name); \
return OMPI_ERROR; \
} \
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
} while(0)

#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \
do \
{ \
if (__module->super.coll_##__api) \
{ \
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
} \
} while (0)
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \
do \
{ \
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
{ \
MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \
} \
} while (0)
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \
do \
{ \
if (__comm->c_coll->coll_##__api && __comm->c_coll->coll_##__api##_module) \
{ \
MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
} \
} while (0)
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \
do \
{ \
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
{ \
MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
__module->previous_##__api = NULL; \
__module->previous_##__api##_module = NULL; \
} \
} while (0)

/*
* Init module on the communicator
Expand All @@ -111,12 +127,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
{
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;

ADAPT_SAVE_PREV_COLL_API(reduce);
ADAPT_SAVE_PREV_COLL_API(ireduce);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
static int adapt_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module;

ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce);
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce);
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
/*
* Initial query function that is invoked during MPI_INIT, allowing
* this component to disqualify itself if it doesn't support the
Expand Down Expand Up @@ -165,24 +194,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *

/* All is good -- return a module */
adapt_module->super.coll_module_enable = adapt_module_enable;
adapt_module->super.coll_allgather = NULL;
adapt_module->super.coll_allgatherv = NULL;
adapt_module->super.coll_allreduce = NULL;
adapt_module->super.coll_alltoall = NULL;
adapt_module->super.coll_alltoallw = NULL;
adapt_module->super.coll_barrier = NULL;
adapt_module->super.coll_module_disable = adapt_module_disable;
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
adapt_module->super.coll_exscan = NULL;
adapt_module->super.coll_gather = NULL;
adapt_module->super.coll_gatherv = NULL;
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
adapt_module->super.coll_reduce_scatter = NULL;
adapt_module->super.coll_scan = NULL;
adapt_module->super.coll_scatter = NULL;
adapt_module->super.coll_scatterv = NULL;
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
adapt_module->super.coll_iallreduce = NULL;

opal_output_verbose(10, ompi_coll_base_framework.framework_output,
"coll:adapt:comm_query (%s/%s): pick me! pick me!",
Expand Down

0 comments on commit c03b06c

Please sign in to comment.