Skip to content

Commit

Permalink
Merge pull request #12429 from bosilca/topic/fix_collective_init_fini
Browse files Browse the repository at this point in the history
Fix collective modules initialization and finalization
  • Loading branch information
bosilca committed May 1, 2024
2 parents 49382c3 + f2dfbba commit bf2068a
Show file tree
Hide file tree
Showing 63 changed files with 1,734 additions and 1,421 deletions.
3 changes: 1 addition & 2 deletions ompi/mca/coll/accelerator/Makefile.am
Expand Up @@ -2,15 +2,14 @@
# Copyright (c) 2014 The University of Tennessee and The University
# of Tennessee Research Foundation. All rights
# reserved.
# Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2017 IBM Corporation. All rights reserved.
# $COPYRIGHT$
#
# Additional copyrights may follow
#
# $HEADER$
#
dist_ompidata_DATA = help-mpi-coll-accelerator.txt

sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
Expand Down
5 changes: 1 addition & 4 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -38,9 +38,6 @@ mca_coll_base_module_t
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
int *priority);

int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

int
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
Expand Down
127 changes: 67 additions & 60 deletions ompi/mca/coll/accelerator/coll_accelerator_module.c
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014-2017 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
Expand Down Expand Up @@ -32,30 +32,21 @@
#include "ompi/mca/coll/base/base.h"
#include "coll_accelerator.h"

/*
 * Forward declarations of the per-communicator enable/disable hooks.
 * Both are file-local; they are handed to the framework through
 * super.coll_module_enable / super.coll_module_disable in
 * mca_coll_accelerator_comm_query() and defined further down in this file.
 */
static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

/*
 * Constructor registered for mca_coll_accelerator_module_t through
 * OBJ_CLASS_INSTANCE: a new module starts with its whole c_coll member
 * (the saved collective function/module slots) zero-filled.
 */
static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
{
    memset(&module->c_coll, 0, sizeof module->c_coll);
}

static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
{
OBJ_RELEASE(module->c_coll.coll_allreduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_module);
OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
OBJ_RELEASE(module->c_coll.coll_scatter_module);
/* If the exscan module is not NULL, then this was an
intracommunicator, and therefore scan will have a module as
well. */
if (NULL != module->c_coll.coll_exscan_module) {
OBJ_RELEASE(module->c_coll.coll_exscan_module);
OBJ_RELEASE(module->c_coll.coll_scan_module);
}
}

OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
mca_coll_accelerator_module_construct,
mca_coll_accelerator_module_destruct);
NULL);


/*
Expand Down Expand Up @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,

/* Choose whether to use [intra|inter] */
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;

accelerator_module->super.coll_allgather = NULL;
accelerator_module->super.coll_allgatherv = NULL;
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
accelerator_module->super.coll_alltoall = NULL;
accelerator_module->super.coll_alltoallv = NULL;
accelerator_module->super.coll_alltoallw = NULL;
accelerator_module->super.coll_barrier = NULL;
accelerator_module->super.coll_bcast = NULL;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
accelerator_module->super.coll_gather = NULL;
accelerator_module->super.coll_gatherv = NULL;
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
accelerator_module->super.coll_reduce_scatter = NULL;
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_scatter = NULL;
accelerator_module->super.coll_scatterv = NULL;
if (!OMPI_COMM_IS_INTER(comm)) {
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
}

return &(accelerator_module->super);
}


/*
 * Interpose the accelerator implementation of collective <__api> on __comm.
 *
 * When the communicator already has a function for <__api>, the current
 * entry is handed to MCA_COLL_SAVE_API together with the module's
 * c_coll.coll_<__api> / c_coll.coll_<__api>_module slots, and
 * mca_coll_accelerator_<__api> is then installed with this module.
 * When the communicator has no function for <__api>, nothing is installed
 * and the "comm-select:missing collective" help message is shown instead.
 *
 *   __comm    communicator whose c_coll table is inspected and updated
 *   __module  pointer to the mca_coll_accelerator_module_t being enabled
 *   __api     collective name without the coll_ prefix (e.g. allreduce)
 *
 * The component name reported in the help message is "accelerator", the
 * same name passed to MCA_COLL_SAVE_API / MCA_COLL_INSTALL_API.
 */
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api)                                                                           \
    do {                                                                                                                                \
        if ((__comm)->c_coll->coll_##__api) {                                                                                           \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &(__module)->super, "accelerator");                       \
        } else {                                                                                                                        \
            opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true,                                            \
                           "accelerator", #__api, ompi_process_info.nodename,                                                           \
                           mca_coll_accelerator_component.priority);                                                                    \
        }                                                                                                                               \
    } while (0)

/*
 * Counterpart of ACCELERATOR_INSTALL_COLL_API.
 *
 * Only acts when this module is still the one registered on __comm for
 * <__api>: the function/module pair kept in (__module)->c_coll is put
 * back on the communicator and the two saved slots are then cleared.
 * If another module owns <__api>, the communicator is left untouched.
 */
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api)                                                                            \
    do {                                                                                                                                   \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) {                                                               \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            (__module)->c_coll.coll_##__api = NULL;                                                                                        \
            (__module)->c_coll.coll_##__api##_module = NULL;                                                                               \
        }                                                                                                                                  \
    } while (0)

/*
* Init module on the communicator
* Init/Fini module on the communicator
*/
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
bool good = true;
char *msg = NULL;
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

#define CHECK_AND_RETAIN(src, dst, name) \
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
good = false; \
msg = #name; \
} else if (good) { \
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
}

CHECK_AND_RETAIN(comm, s, allreduce);
CHECK_AND_RETAIN(comm, s, reduce);
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
CHECK_AND_RETAIN(comm, s, scatter);
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
CHECK_AND_RETAIN(comm, s, exscan);
CHECK_AND_RETAIN(comm, s, scan);
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
}

/* All done */
if (good) {
return OMPI_SUCCESS;
}
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
ompi_process_info.nodename,
mca_coll_accelerator_component.priority, msg);
return OMPI_ERR_NOT_FOUND;
return OMPI_SUCCESS;
}

/*
 * Fini module on the communicator.
 *
 * Registered as super.coll_module_disable in
 * mca_coll_accelerator_comm_query(). For each collective this component
 * interposes on, ACCELERATOR_UNINSTALL_COLL_API puts the saved
 * function/module pair back on the communicator, but only if this module
 * is still the one installed for that collective.
 *
 * module  the accelerator module being disabled (a
 *         mca_coll_accelerator_module_t seen through its base type)
 * comm    communicator the module was enabled on
 *
 * Always returns OMPI_SUCCESS.
 */
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

/* Same set of collectives, in the same order, as module_enable installs. */
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm))
{
/* MPI does not define scan/exscan on intercommunicators */
ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan);
}

return OMPI_SUCCESS;
}
29 changes: 0 additions & 29 deletions ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

This file was deleted.

87 changes: 52 additions & 35 deletions ompi/mca/coll/adapt/coll_adapt_module.c
Expand Up @@ -5,6 +5,7 @@
* Copyright (c) 2021 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
*
* $COPYRIGHT$
*
Expand Down Expand Up @@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
adapt_module_construct,
adapt_module_destruct);

/*
* In this macro, the following variables are supposed to have been declared
* in the caller:
* . ompi_communicator_t *comm
* . mca_coll_adapt_module_t *adapt_module
*/
#define ADAPT_SAVE_PREV_COLL_API(__api) \
do { \
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
"(%s/%s): no underlying " # __api"; disqualifying myself", \
ompi_comm_print_cid(comm), comm->c_name); \
return OMPI_ERROR; \
} \
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
} while(0)

/*
 * Install adapt's implementation of collective <__api> on __comm, provided
 * the module actually supplies one (super.coll_<__api> is non-NULL).
 * Whatever was registered on the communicator before is not recorded.
 *
 *   __comm    communicator to install on
 *   __module  pointer to the mca_coll_adapt_module_t being enabled
 *   __api     collective name without the coll_ prefix (e.g. bcast)
 *
 * Macro parameters are parenthesized so that any pointer-valued expression
 * can be passed for __module.
 */
#define ADAPT_INSTALL_COLL_API(__comm, __module, __api)                                                     \
    do {                                                                                                    \
        if ((__module)->super.coll_##__api) {                                                               \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, &(__module)->super, "adapt"); \
        }                                                                                                   \
    } while (0)
/*
 * Remove adapt's implementation of collective <__api> from __comm.
 *
 * Only acts when this module is the one currently registered for <__api>.
 * In that case the entry is replaced with a NULL function and NULL module:
 * ADAPT_INSTALL_COLL_API keeps no previous entry, so there is nothing to
 * restore here.
 *
 *   __comm    communicator to uninstall from
 *   __module  pointer to the mca_coll_adapt_module_t being disabled
 *   __api     collective name without the coll_ prefix (e.g. bcast)
 */
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api)                    \
    do {                                                                     \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt");        \
        }                                                                    \
    } while (0)
/*
 * Install adapt's implementation of collective <__api> on __comm while
 * keeping track of the entry it replaces.
 *
 * Requires the communicator to already have both a function and a module
 * for <__api>; the current entry is handed to MCA_COLL_SAVE_API together
 * with the module's previous_<__api> / previous_<__api>_module slots before
 * adapt's function is installed. If either the function or the module is
 * missing on the communicator, the macro does nothing.
 *
 *   __comm    communicator to install on
 *   __module  pointer to the mca_coll_adapt_module_t being enabled
 *   __api     collective name without the coll_ prefix (e.g. reduce)
 */
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api)                                                            \
    do {                                                                                                                    \
        if ((__comm)->c_coll->coll_##__api && (__comm)->c_coll->coll_##__api##_module) {                                    \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->previous_##__api, (__module)->previous_##__api##_module, "adapt"); \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, &(__module)->super, "adapt");               \
        }                                                                                                                   \
    } while (0)
/*
 * Counterpart of ADAPT_INSTALL_AND_SAVE_COLL_API.
 *
 * Only acts when this module is the one currently registered on __comm for
 * <__api>: the function/module pair held in previous_<__api> /
 * previous_<__api>_module is installed back on the communicator, and both
 * slots are then reset to NULL.
 *
 *   __comm    communicator to uninstall from
 *   __module  pointer to the mca_coll_adapt_module_t being disabled
 *   __api     collective name without the coll_ prefix (e.g. reduce)
 */
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api)                                                          \
    do {                                                                                                                       \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) {                                                   \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->previous_##__api, (__module)->previous_##__api##_module, "adapt"); \
            (__module)->previous_##__api = NULL;                                                                               \
            (__module)->previous_##__api##_module = NULL;                                                                      \
        }                                                                                                                      \
    } while (0)

/*
* Init module on the communicator
Expand All @@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
{
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;

ADAPT_SAVE_PREV_COLL_API(reduce);
ADAPT_SAVE_PREV_COLL_API(ireduce);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
/*
 * Fini module on the communicator.
 *
 * Registered as super.coll_module_disable in ompi_coll_adapt_comm_query().
 * Mirrors adapt_module_enable: reduce and ireduce, which were installed
 * with ADAPT_INSTALL_AND_SAVE_COLL_API, get their previous function/module
 * put back; bcast and ibcast, installed without saving a predecessor, are
 * simply removed. Each macro is a no-op when this module is not the one
 * currently registered for that collective.
 *
 * module  the adapt module being disabled (a mca_coll_adapt_module_t seen
 *         through its base type)
 * comm    communicator the module was enabled on
 *
 * Always returns OMPI_SUCCESS.
 */
static int adapt_module_disable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module;

ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce);
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce);
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
/*
* Initial query function that is invoked during MPI_INIT, allowing
* this component to disqualify itself if it doesn't support the
Expand Down Expand Up @@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *

/* All is good -- return a module */
adapt_module->super.coll_module_enable = adapt_module_enable;
adapt_module->super.coll_allgather = NULL;
adapt_module->super.coll_allgatherv = NULL;
adapt_module->super.coll_allreduce = NULL;
adapt_module->super.coll_alltoall = NULL;
adapt_module->super.coll_alltoallw = NULL;
adapt_module->super.coll_barrier = NULL;
adapt_module->super.coll_module_disable = adapt_module_disable;
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
adapt_module->super.coll_exscan = NULL;
adapt_module->super.coll_gather = NULL;
adapt_module->super.coll_gatherv = NULL;
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
adapt_module->super.coll_reduce_scatter = NULL;
adapt_module->super.coll_scan = NULL;
adapt_module->super.coll_scatter = NULL;
adapt_module->super.coll_scatterv = NULL;
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
adapt_module->super.coll_iallreduce = NULL;

opal_output_verbose(10, ompi_coll_base_framework.framework_output,
"coll:adapt:comm_query (%s/%s): pick me! pick me!",
Expand Down

0 comments on commit bf2068a

Please sign in to comment.