Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix collective modules initialization and finalization #12429

Merged
merged 1 commit into from May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions ompi/mca/coll/accelerator/Makefile.am
Expand Up @@ -2,15 +2,14 @@
# Copyright (c) 2014 The University of Tennessee and The University
# of Tennessee Research Foundation. All rights
# reserved.
# Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
# Copyright (c) 2017 IBM Corporation. All rights reserved.
# $COPYRIGHT$
#
# Additional copyrights may follow
#
# $HEADER$
#
dist_ompidata_DATA = help-mpi-coll-accelerator.txt

sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
Expand Down
5 changes: 1 addition & 4 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -38,9 +38,6 @@ mca_coll_base_module_t
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
int *priority);

int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm);

int
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
Expand Down
127 changes: 67 additions & 60 deletions ompi/mca/coll/accelerator/coll_accelerator_module.c
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2014-2017 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
Expand Down Expand Up @@ -32,30 +32,21 @@
#include "ompi/mca/coll/base/base.h"
#include "coll_accelerator.h"

/*
 * Forward declarations of the per-communicator enable/disable hooks.
 * Both are file-local; they are handed out through the module's
 * coll_module_enable / coll_module_disable function pointers.
 */
static int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
                                              struct ompi_communicator_t *comm);
static int mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
                                               struct ompi_communicator_t *comm);

/*
 * Object constructor: zero-fill the c_coll member so that every saved
 * collective function pointer and module pointer in it starts out as NULL.
 */
static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
{
    memset(&module->c_coll, 0, sizeof module->c_coll);
}

/*
 * Object destructor: drop the references this module holds on the
 * underlying collective modules recorded in c_coll.
 *
 * The constructor zero-fills c_coll, and a slot is only filled in (and
 * retained) when the corresponding collective was found on the
 * communicator, so any of these pointers may still be NULL when the object
 * is torn down.  Each pointer is therefore checked on its own before being
 * released; in particular scan is not assumed to be set just because
 * exscan is.
 */
static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
{
    if (NULL != module->c_coll.coll_allreduce_module) {
        OBJ_RELEASE(module->c_coll.coll_allreduce_module);
    }
    if (NULL != module->c_coll.coll_reduce_module) {
        OBJ_RELEASE(module->c_coll.coll_reduce_module);
    }
    if (NULL != module->c_coll.coll_reduce_scatter_block_module) {
        OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
    }
    if (NULL != module->c_coll.coll_scatter_module) {
        OBJ_RELEASE(module->c_coll.coll_scatter_module);
    }
    /* exscan/scan are only recorded for intracommunicators. */
    if (NULL != module->c_coll.coll_exscan_module) {
        OBJ_RELEASE(module->c_coll.coll_exscan_module);
    }
    if (NULL != module->c_coll.coll_scan_module) {
        OBJ_RELEASE(module->c_coll.coll_scan_module);
    }
}

OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
mca_coll_accelerator_module_construct,
mca_coll_accelerator_module_destruct);
NULL);


/*
Expand Down Expand Up @@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,

/* Choose whether to use [intra|inter] */
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;

accelerator_module->super.coll_allgather = NULL;
accelerator_module->super.coll_allgatherv = NULL;
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
accelerator_module->super.coll_alltoall = NULL;
accelerator_module->super.coll_alltoallv = NULL;
accelerator_module->super.coll_alltoallw = NULL;
accelerator_module->super.coll_barrier = NULL;
accelerator_module->super.coll_bcast = NULL;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
accelerator_module->super.coll_gather = NULL;
accelerator_module->super.coll_gatherv = NULL;
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
accelerator_module->super.coll_reduce_scatter = NULL;
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_scatter = NULL;
accelerator_module->super.coll_scatterv = NULL;
if (!OMPI_COMM_IS_INTER(comm)) {
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
}

return &(accelerator_module->super);
}


/*
 * Interpose the accelerator implementation of collective __api on __comm.
 *
 * When the communicator already provides a function for __api, the current
 * function/module pair is handed to MCA_COLL_SAVE_API together with the
 * (__module)->c_coll slots (presumably recording it there so that it can be
 * put back later), and mca_coll_accelerator_<api> is then installed with
 * __module as the owning module.
 *
 * When the communicator has no function for __api, nothing is installed and
 * the "comm-select:missing collective" help message is shown, naming this
 * component, the collective, the node name and the component priority.
 */
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if (NULL != (__comm)->c_coll->coll_##__api) { \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &(__module)->super, "accelerator"); \
        } else { \
            opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
                           "accelerator", #__api, ompi_process_info.nodename, \
                           mca_coll_accelerator_component.priority); \
        } \
    } while (0)

/*
 * Undo ACCELERATOR_INSTALL_COLL_API for collective __api on __comm.
 *
 * Acts only if __module is still the module registered for __api on the
 * communicator: the function/module pair kept in (__module)->c_coll is
 * installed back, and both saved slots are then reset to NULL.
 */
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            (__module)->c_coll.coll_##__api = NULL; \
            (__module)->c_coll.coll_##__api##_module = NULL; \
        } \
    } while (0)

/*
* Init module on the communicator
* Init/Fini module on the communicator
*/
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
static int
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
bool good = true;
char *msg = NULL;
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

#define CHECK_AND_RETAIN(src, dst, name) \
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
good = false; \
msg = #name; \
} else if (good) { \
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
}

CHECK_AND_RETAIN(comm, s, allreduce);
CHECK_AND_RETAIN(comm, s, reduce);
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
CHECK_AND_RETAIN(comm, s, scatter);
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
CHECK_AND_RETAIN(comm, s, exscan);
CHECK_AND_RETAIN(comm, s, scan);
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
}

/* All done */
if (good) {
return OMPI_SUCCESS;
}
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
ompi_process_info.nodename,
mca_coll_accelerator_component.priority, msg);
return OMPI_ERR_NOT_FOUND;
return OMPI_SUCCESS;
}

/*
 * Fini module on the communicator.
 *
 * For each collective this component interposes, put the saved
 * function/module pair back on the communicator if this module is still the
 * one registered for it (see ACCELERATOR_UNINSTALL_COLL_API).
 *
 * @param module  the accelerator module being disabled
 * @param comm    the communicator it was enabled on
 * @return        always OMPI_SUCCESS
 */
static int
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
                                    struct ompi_communicator_t *comm)
{
    mca_coll_accelerator_module_t *accel = (mca_coll_accelerator_module_t *) module;

    ACCELERATOR_UNINSTALL_COLL_API(comm, accel, allreduce);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel, reduce);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel, reduce_scatter_block);

    /* MPI does not define scan/exscan on intercommunicators */
    if (OMPI_COMM_IS_INTER(comm)) {
        return OMPI_SUCCESS;
    }

    ACCELERATOR_UNINSTALL_COLL_API(comm, accel, exscan);
    ACCELERATOR_UNINSTALL_COLL_API(comm, accel, scan);

    return OMPI_SUCCESS;
}
29 changes: 0 additions & 29 deletions ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

This file was deleted.

87 changes: 52 additions & 35 deletions ompi/mca/coll/adapt/coll_adapt_module.c
Expand Up @@ -5,6 +5,7 @@
* Copyright (c) 2021 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
*
* $COPYRIGHT$
*
Expand Down Expand Up @@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
adapt_module_construct,
adapt_module_destruct);

/*
* In this macro, the following variables are supposed to have been declared
* in the caller:
* . ompi_communicator_t *comm
* . mca_coll_adapt_module_t *adapt_module
*/
#define ADAPT_SAVE_PREV_COLL_API(__api) \
do { \
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
"(%s/%s): no underlying " # __api"; disqualifying myself", \
ompi_comm_print_cid(comm), comm->c_name); \
return OMPI_ERROR; \
} \
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
} while(0)

/*
 * Install adapt's implementation of collective __api on __comm, provided
 * the module actually supplies one (its super.coll_<api> slot is non-NULL).
 * Whatever was installed on the communicator beforehand is not saved.
 *
 * Macro parameters are parenthesized so that arbitrary pointer expressions
 * can be passed safely.
 */
#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if (NULL != (__module)->super.coll_##__api) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, &(__module)->super, "adapt"); \
        } \
    } while (0)
/*
 * Remove adapt's implementation of collective __api from __comm.
 *
 * Acts only if __module is still the module registered for __api on the
 * communicator.  No previous implementation is kept for these collectives,
 * so the slot is reset to a NULL function / NULL module pair.
 *
 * Macro parameters are parenthesized so that arbitrary pointer expressions
 * can be passed safely.
 */
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \
        } \
    } while (0)
/*
 * Install adapt's implementation of collective __api on __comm while
 * keeping track of the implementation it replaces.
 *
 * Acts only if the communicator currently has both a function and a module
 * for __api.  The current pair is handed to MCA_COLL_SAVE_API together with
 * the module's previous_<api> / previous_<api>_module fields (presumably
 * recording it there), then the module's own super.coll_<api> is installed
 * with __module as the owning module.
 *
 * Macro parameters are parenthesized so that arbitrary pointer expressions
 * can be passed safely.
 */
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \
    do { \
        if (NULL != (__comm)->c_coll->coll_##__api && NULL != (__comm)->c_coll->coll_##__api##_module) { \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->previous_##__api, (__module)->previous_##__api##_module, "adapt"); \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->super.coll_##__api, &(__module)->super, "adapt"); \
        } \
    } while (0)
/*
 * Remove adapt's implementation of collective __api from __comm and put
 * back the implementation kept in the module's previous_<api> /
 * previous_<api>_module fields.
 *
 * Acts only if __module is still the module registered for __api on the
 * communicator.  After the previous pair has been installed, both saved
 * fields are reset to NULL.
 *
 * Macro parameters are parenthesized so that arbitrary pointer expressions
 * can be passed safely.
 */
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->previous_##__api, (__module)->previous_##__api##_module, "adapt"); \
            (__module)->previous_##__api = NULL; \
            (__module)->previous_##__api##_module = NULL; \
        } \
    } while (0)

/*
* Init module on the communicator
Expand All @@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
{
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;

ADAPT_SAVE_PREV_COLL_API(reduce);
ADAPT_SAVE_PREV_COLL_API(ireduce);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);

return OMPI_SUCCESS;
}
/*
 * Fini module on the communicator.
 *
 * reduce and ireduce go through the "uninstall and restore" macro, which
 * puts back the implementation kept in the module's previous_* fields;
 * bcast and ibcast go through the plain uninstall macro, which resets the
 * slot to NULL.  Each macro only acts if this module is still the one
 * registered for that collective on the communicator.
 *
 * @param module  the adapt module being disabled
 * @param comm    the communicator it was enabled on
 * @return        always OMPI_SUCCESS
 */
static int adapt_module_disable(mca_coll_base_module_t *module,
                                struct ompi_communicator_t *comm)
{
    mca_coll_adapt_module_t *self = (mca_coll_adapt_module_t *) module;

    /* blocking collectives */
    ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, self, reduce);
    ADAPT_UNINSTALL_COLL_API(comm, self, bcast);

    /* non-blocking collectives */
    ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, self, ireduce);
    ADAPT_UNINSTALL_COLL_API(comm, self, ibcast);

    return OMPI_SUCCESS;
}
/*
* Initial query function that is invoked during MPI_INIT, allowing
* this component to disqualify itself if it doesn't support the
Expand Down Expand Up @@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *

/* All is good -- return a module */
adapt_module->super.coll_module_enable = adapt_module_enable;
adapt_module->super.coll_allgather = NULL;
adapt_module->super.coll_allgatherv = NULL;
adapt_module->super.coll_allreduce = NULL;
adapt_module->super.coll_alltoall = NULL;
adapt_module->super.coll_alltoallw = NULL;
adapt_module->super.coll_barrier = NULL;
adapt_module->super.coll_module_disable = adapt_module_disable;
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
adapt_module->super.coll_exscan = NULL;
adapt_module->super.coll_gather = NULL;
adapt_module->super.coll_gatherv = NULL;
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
adapt_module->super.coll_reduce_scatter = NULL;
adapt_module->super.coll_scan = NULL;
adapt_module->super.coll_scatter = NULL;
adapt_module->super.coll_scatterv = NULL;
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
adapt_module->super.coll_iallreduce = NULL;

opal_output_verbose(10, ompi_coll_base_framework.framework_output,
"coll:adapt:comm_query (%s/%s): pick me! pick me!",
Expand Down