Skip to content

Commit

Permalink
Fix collective initialization and finalization
Browse files Browse the repository at this point in the history
Instead of allowing each collective module to present a list of
functions it provides, let them register the functions they provide and
save the context of the previous collective if they choose to.

There are two major benefits to this approach:
- tighter memory management in the collective modules themselves. Each
  collective's enable and disable is called exactly once per communicator,
  to chain or unchain itself from the collective function pointers
  struct. The disable is called in the reverse order of the enable,
  allowing for proper chaining of collectives.
- they only install the functions they want. So instead of checking in
  the coll_select all the functions for all modules, each module can now
  selectively iterate over only the functions it provides.

What is still broken is the ability of a particular collective module to
unchain itself in the middle of the execution. Instead, a properly
implemented module will have an enable/disable flag, and it should act
as a passthrough if it chooses to deactivate.

Signed-off-by: George Bosilca <bosilca@icl.utk.edu>
  • Loading branch information
bosilca committed Apr 30, 2024
1 parent 49382c3 commit f2dfbba
Show file tree
Hide file tree
Showing 63 changed files with 1,734 additions and 1,421 deletions.
3 changes: 1 addition & 2 deletions ompi/mca/coll/accelerator/Makefile.am
Expand Up @@ -2,15 +2,14 @@
# Copyright (c) 2014 The University of Tennessee and The University
# of Tennessee Research Foundation. All rights
# reserved.
# Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2017 IBM Corporation. All rights reserved.
# $COPYRIGHT$
#
# Additional copyrights may follow
#
# $HEADER$
#
dist_ompidata_DATA = help-mpi-coll-accelerator.txt

sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
Expand Down
5 changes: 1 addition & 4 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -38,9 +38,6 @@ mca_coll_base_module_t
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
int *priority);

int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

int
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
Expand Down
127 changes: 67 additions & 60 deletions ompi/mca/coll/accelerator/coll_accelerator_module.c
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014-2017 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
Expand Down Expand Up @@ -32,30 +32,21 @@
#include "ompi/mca/coll/base/base.h"
#include "coll_accelerator.h"

static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

/*
 * Object constructor: clear the module's c_coll table so that every
 * function and module pointer stored in it starts out as zero bytes.
 */
static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
{
    memset(&module->c_coll, 0, sizeof module->c_coll);
}

/*
 * Object destructor: release the references this module holds on the
 * collective modules recorded in module->c_coll.
 *
 * allreduce, reduce, reduce_scatter_block and scatter are released
 * unconditionally; exscan and scan are released only when the exscan
 * module pointer is non-NULL.
 *
 * NOTE(review): the unconditional releases assume all four module pointers
 * were set before destruction -- confirm OBJ_RELEASE is safe if one of them
 * is still NULL.
 */
static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
{
OBJ_RELEASE(module->c_coll.coll_allreduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
OBJ_RELEASE(module->c_coll.coll_scatter_module);
/* If the exscan module is not NULL, then this was an
intracommunicator, and therefore scan will have a module as
well. */
if (NULL != module->c_coll.coll_exscan_module) {
OBJ_RELEASE(module->c_coll.coll_exscan_module);
OBJ_RELEASE(module->c_coll.coll_scan_module);
}
}

OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
mca_coll_accelerator_module_construct,
mca_coll_accelerator_module_destruct);
NULL);


/*
Expand Down Expand Up @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,

/* Choose whether to use [intra|inter] */
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;

accelerator_module->super.coll_allgather = NULL;
accelerator_module->super.coll_allgatherv = NULL;
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
accelerator_module->super.coll_alltoall = NULL;
accelerator_module->super.coll_alltoallv = NULL;
accelerator_module->super.coll_alltoallw = NULL;
accelerator_module->super.coll_barrier = NULL;
accelerator_module->super.coll_bcast = NULL;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
accelerator_module->super.coll_gather = NULL;
accelerator_module->super.coll_gatherv = NULL;
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
accelerator_module->super.coll_reduce_scatter = NULL;
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_scatter = NULL;
accelerator_module->super.coll_scatterv = NULL;
if (!OMPI_COMM_IS_INTER(comm)) {
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
}

return &(accelerator_module->super);
}


/*
 * Chain the accelerator implementation of collective __api onto __comm.
 *
 *   __comm    communicator whose c_coll table is updated
 *   __module  mca_coll_accelerator_module_t* being enabled
 *   __api     bare collective name (e.g. allreduce), token-pasted into
 *             coll_<api> / coll_<api>_module / mca_coll_accelerator_<api>
 *
 * When the communicator already has a function for __api, MCA_COLL_SAVE_API
 * is invoked with (__module)->c_coll.coll_<api> and its _module slot as the
 * destination, and mca_coll_accelerator_<api> is then installed with
 * &(__module)->super as the owning module.  When there is no function to
 * wrap, a "missing collective" help message is shown and the communicator
 * is left unchanged.
 *
 * The help message reports this component as "accelerator", matching the
 * name passed to the SAVE/INSTALL macros (it previously said "cuda").
 */
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api) { \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, \
                              (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, \
                                 &(__module)->super, "accelerator"); \
        } else { \
            opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
                           "accelerator", #__api, ompi_process_info.nodename, \
                           mca_coll_accelerator_component.priority); \
        } \
    } while (0)

/*
 * Unchain the accelerator implementation of collective __api from __comm.
 *
 * Acts only when the module currently recorded on the communicator for
 * __api is this module (&(__module)->super).  In that case the function and
 * module pointers held in (__module)->c_coll are installed back on the
 * communicator, and the module's own copies are then cleared to NULL.
 * If another module owns the slot, nothing is changed.
 */
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
do \
{ \
if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \
MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
(__module)->c_coll.coll_##__api##_module = NULL; \
(__module)->c_coll.coll_##__api = NULL; \
} \
} while (0)

/*
* Init module on the communicator
* Init/Fini module on the communicator
*/
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
bool good = true;
char *msg = NULL;
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

#define CHECK_AND_RETAIN(src, dst, name) \
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
good = false; \
msg = #name; \
} else if (good) { \
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
}

CHECK_AND_RETAIN(comm, s, allreduce);
CHECK_AND_RETAIN(comm, s, reduce);
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
CHECK_AND_RETAIN(comm, s, scatter);
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
CHECK_AND_RETAIN(comm, s, exscan);
CHECK_AND_RETAIN(comm, s, scan);
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
}

/* All done */
if (good) {
return OMPI_SUCCESS;
}
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
ompi_process_info.nodename,
mca_coll_accelerator_component.priority, msg);
return OMPI_ERR_NOT_FOUND;
return OMPI_SUCCESS;
}

/*
 * Fini module on the communicator: undo what enable installed.
 *
 * Each collective handled by this component is passed through
 * ACCELERATOR_UNINSTALL_COLL_API.  exscan and scan are only touched on
 * intracommunicators.  Always returns OMPI_SUCCESS.
 */
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
                                    struct ompi_communicator_t *comm)
{
    mca_coll_accelerator_module_t *accel_module = (mca_coll_accelerator_module_t *) module;

    ACCELERATOR_UNINSTALL_COLL_API(comm, accel_module, allreduce);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel_module, reduce);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel_module, reduce_scatter_block);

    /* MPI does not define scan/exscan on intercommunicators, so there is
     * nothing more to uninstall for them. */
    if (OMPI_COMM_IS_INTER(comm)) {
        return OMPI_SUCCESS;
    }

    ACCELERATOR_UNINSTALL_COLL_API(comm, accel_module, exscan);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel_module, scan);

    return OMPI_SUCCESS;
}
29 changes: 0 additions & 29 deletions ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

This file was deleted.

87 changes: 52 additions & 35 deletions ompi/mca/coll/adapt/coll_adapt_module.c
Expand Up @@ -5,6 +5,7 @@
* Copyright (c) 2021 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
*
* $COPYRIGHT$
*
Expand Down Expand Up @@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
adapt_module_construct,
adapt_module_destruct);

/*
 * Record the communicator's current implementation of collective __api in
 * adapt_module->previous_<api> / previous_<api>_module and retain that
 * module.
 *
 * In this macro, the following variables are supposed to have been declared
 * in the caller:
 * . ompi_communicator_t *comm
 * . mca_coll_adapt_module_t *adapt_module
 *
 * Caution: this macro contains a hidden `return OMPI_ERROR;` that exits the
 * ENCLOSING function (after a verbose log message) when the communicator
 * has no function or no module for __api.  The previous_* fields are
 * assigned before that check, so on the error path they hold whatever the
 * communicator had (possibly NULL) and no reference has been retained.
 */
#define ADAPT_SAVE_PREV_COLL_API(__api) \
do { \
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
"(%s/%s): no underlying " # __api"; disqualifying myself", \
ompi_comm_print_cid(comm), comm->c_name); \
return OMPI_ERROR; \
} \
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
} while(0)

/*
 * Install adapt's implementation of collective __api on __comm.
 *
 * Does nothing unless the module provides a non-NULL super.coll_<api>.
 * Whatever was installed on the communicator beforehand is NOT saved by
 * this macro.  Macro arguments are parenthesized so that any pointer
 * expression may be passed for __module.
 */
#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__module)->super.coll_##__api) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, \
                                 &(__module)->super, "adapt"); \
        } \
    } while (0)
/*
 * Remove adapt's implementation of collective __api from __comm.
 *
 * Acts only when the module currently recorded on the communicator for
 * __api is this module; the slot is then reset to a NULL function and a
 * NULL module (no previous implementation is restored).  Macro arguments
 * are parenthesized so that any pointer expression may be passed.
 */
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \
        } \
    } while (0)
/*
 * Chain adapt's implementation of collective __api onto __comm.
 *
 * Acts only when the communicator already has both a function and a module
 * for __api.  MCA_COLL_SAVE_API is invoked with the module's
 * previous_<api> / previous_<api>_module fields as the destination, and
 * the module's super.coll_<api> is then installed with &(__module)->super
 * as owner.  Otherwise the communicator is left unchanged.  Macro
 * arguments are parenthesized so that any pointer expression may be passed.
 */
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api && (__comm)->c_coll->coll_##__api##_module) { \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->previous_##__api, \
                              (__module)->previous_##__api##_module, "adapt"); \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, \
                                 &(__module)->super, "adapt"); \
        } \
    } while (0)
/*
 * Unchain adapt's implementation of collective __api from __comm.
 *
 * Acts only when the module currently recorded on the communicator for
 * __api is this module.  The function and module held in the module's
 * previous_<api> / previous_<api>_module fields are installed back on the
 * communicator, and those fields are then cleared to NULL.  Macro
 * arguments are parenthesized so that any pointer expression may be passed.
 */
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->previous_##__api, \
                                 (__module)->previous_##__api##_module, "adapt"); \
            (__module)->previous_##__api = NULL; \
            (__module)->previous_##__api##_module = NULL; \
        } \
    } while (0)

/*
* Init module on the communicator
Expand All @@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
{
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;

ADAPT_SAVE_PREV_COLL_API(reduce);
ADAPT_SAVE_PREV_COLL_API(ireduce);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
/*
 * Fini module on the communicator.
 *
 * reduce and ireduce go through ADAPT_UNINSTALL_AND_RESTORE_COLL_API;
 * bcast and ibcast go through ADAPT_UNINSTALL_COLL_API.  The calls are
 * made in the order reduce, bcast, ireduce, ibcast.
 * Always returns OMPI_SUCCESS.
 */
static int adapt_module_disable(mca_coll_base_module_t *module,
                                struct ompi_communicator_t *comm)
{
    mca_coll_adapt_module_t *adapt = (mca_coll_adapt_module_t *) module;

    /* Blocking collectives. */
    ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt, reduce);
    ADAPT_UNINSTALL_COLL_API(comm, adapt, bcast);

    /* Non-blocking collectives. */
    ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt, ireduce);
    ADAPT_UNINSTALL_COLL_API(comm, adapt, ibcast);

    return OMPI_SUCCESS;
}
/*
* Initial query function that is invoked during MPI_INIT, allowing
* this component to disqualify itself if it doesn't support the
Expand Down Expand Up @@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *

/* All is good -- return a module */
adapt_module->super.coll_module_enable = adapt_module_enable;
adapt_module->super.coll_allgather = NULL;
adapt_module->super.coll_allgatherv = NULL;
adapt_module->super.coll_allreduce = NULL;
adapt_module->super.coll_alltoall = NULL;
adapt_module->super.coll_alltoallw = NULL;
adapt_module->super.coll_barrier = NULL;
adapt_module->super.coll_module_disable = adapt_module_disable;
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
adapt_module->super.coll_exscan = NULL;
adapt_module->super.coll_gather = NULL;
adapt_module->super.coll_gatherv = NULL;
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
adapt_module->super.coll_reduce_scatter = NULL;
adapt_module->super.coll_scan = NULL;
adapt_module->super.coll_scatter = NULL;
adapt_module->super.coll_scatterv = NULL;
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
adapt_module->super.coll_iallreduce = NULL;

opal_output_verbose(10, ompi_coll_base_framework.framework_output,
"coll:adapt:comm_query (%s/%s): pick me! pick me!",
Expand Down

0 comments on commit f2dfbba

Please sign in to comment.