Skip to content

Commit

Permalink
Merge pull request #6998 from hzhou/2404_abi_mpix
Browse files Browse the repository at this point in the history
abi: add mpix op and errhandler delcarations in mpi_abi.h

Approved-by: Ken Raffenetti
  • Loading branch information
hzhou committed May 8, 2024
2 parents 5ae17c4 + be826b3 commit 81f5d34
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 67 deletions.
11 changes: 3 additions & 8 deletions maint/gen_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def load_mpi_abi_h(mpi_abi_h):
def dump_mpi_abi_internal_h(mpi_abi_internal_h):
define_constants = {}
def gen_mpi_abi_internal_h(out):
re_Handle = r'\bMPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win|KEYVAL_INVALID|TAG_UB|IO|HOST|WTIME_IS_GLOBAL|APPNUM|LASTUSEDCODE|UNIVERSE_SIZE|WIN_BASE|WIN_DISP_UNIT|WIN_SIZE|WIN_CREATE_FLAVOR|WIN_MODEL)\b'
re_Handle = r'\bMPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win|File|KEYVAL_INVALID|TAG_UB|IO|HOST|WTIME_IS_GLOBAL|APPNUM|LASTUSEDCODE|UNIVERSE_SIZE|WIN_BASE|WIN_DISP_UNIT|WIN_SIZE|WIN_CREATE_FLAVOR|WIN_MODEL)\b'
for line in G.abi_h_lines:
if RE.search(r'MPI_ABI_H_INCLUDED', line):
# skip the include guard, harmless
Expand All @@ -58,13 +58,8 @@ def gen_mpi_abi_internal_h(out):
elif T == "MPI_Op":
idx = int(val, 0) & G.op_mask
G.abi_ops[idx] = name

if T == "MPI_File":
# pass through
out.append(line.rstrip())
else:
# replace param prefix
out.append(re.sub(r'\bMPI_', 'ABI_', line.rstrip()))
# replace param prefix
out.append(re.sub(r'\bMPI_', 'ABI_', line.rstrip()))
elif RE.match(r'#define MPI_(LONG_LONG|C_COMPLEX)', line):
# datatype aliases
out.append(re.sub(r'\bMPI_', 'ABI_', line.rstrip()))
Expand Down
9 changes: 5 additions & 4 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,11 +1047,12 @@ def out_can_be_undefined(p):
return True
return False
# ----
re_Handle = r'MPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win|File)\b'
for p in func['c_parameters']:
skip_abi_swap = False
param_type = mapping[p['kind']]
name = p['name']
if RE.match(r'MPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win)\b', param_type):
if RE.match(re_Handle, param_type):
process_handle(p)
elif p['kind'] == 'KEYVAL' and p['param_direction'] == 'in':
pre_filters.append("int %s = ABI_KEYVAL_to_mpi(%s_abi);" % (name, name))
Expand All @@ -1070,7 +1071,7 @@ def out_can_be_undefined(p):

# MPI_Comm comm -> ABI_Comm comm_abi
param = get_C_param(p, func, mapping)
param = re.sub(r'MPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win)\b', r'ABI_\1', param)
param = re.sub(re_Handle, r'ABI_\1', param)
if not skip_abi_swap:
param = re.sub(r'\b' + name, name+"_abi", param)
param_list.append(param)
Expand All @@ -1081,7 +1082,7 @@ def out_can_be_undefined(p):
ret = "int"
if 'return' in func:
ret = mapping[func['return']]
ret = re.sub(r'MPI_(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win)\b', r'ABI_\1', ret)
ret = re.sub(re_Handle, r'ABI_\1', ret)

static_call = get_static_call_internal(func, is_large)

Expand All @@ -1098,7 +1099,7 @@ def out_can_be_undefined(p):

if ret != 'int':
# MPI_Wtime, MPI_Aint_{add,diff}, MPI_{Comm,...}_{c2f,f2c}
if RE.match(r'..._(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win)\b', ret):
if RE.match(r'..._(Comm|Datatype|Errhandler|Group|Info|Message|Op|Request|Session|Win|File)\b', ret):
G.out.append("return ABI_%s_from_mpi(%s);" % (RE.m.group(1), static_call))
else:
G.out.append("return " + static_call + ";")
Expand Down
17 changes: 17 additions & 0 deletions src/binding/abi/mpi_abi.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ enum {

typedef void (MPI_User_function)(void *invec, void *inoutvec, int *len, MPI_Datatype *datatype);
typedef void (MPI_User_function_c)(void *invec, void *inoutvec, MPI_Count *len, MPI_Datatype *datatype);
typedef void (MPIX_User_function_x) ( void *invec, void *inoutvec, MPI_Count count, MPI_Datatype datatype, void *extra_state);

typedef int (MPI_Grequest_query_function)(void *extra_state, MPI_Status *status);
typedef int (MPI_Grequest_free_function)(void *extra_state);
Expand All @@ -474,11 +475,18 @@ typedef void (MPI_File_errhandler_function)(MPI_File *file, int *error_code, ...
typedef void (MPI_Win_errhandler_function)(MPI_Win *win, int *error_code, ...);
typedef void (MPI_Session_errhandler_function)(MPI_Session *session, int *error_code, ...);

typedef void (MPIX_Comm_errhandler_function_x)(MPI_Comm comm, int error_code, void *extra_state);
typedef void (MPIX_File_errhandler_function_x)(MPI_File file, int error_code, void *extra_state);
typedef void (MPIX_Win_errhandler_function_x)(MPI_Win win, int error_code, void *extra_state);
typedef void (MPIX_Session_errhandler_function_x)(MPI_Session session, int error_code, void *extra_state);

typedef MPI_Comm_errhandler_function MPI_Comm_errhandler_fn;
typedef MPI_File_errhandler_function MPI_File_errhandler_fn;
typedef MPI_Win_errhandler_function MPI_Win_errhandler_fn;
typedef MPI_Session_errhandler_function MPI_Session_errhandler_fn;

typedef void (MPIX_Destructor_function) (void *extra_state);

#define MPI_NULL_COPY_FN ((MPI_Copy_function*)0x0) /* deprecated: MPI-2.0 */
#define MPI_DUP_FN ((MPI_Copy_function*)0x1) /* deprecated: MPI-2.0 */
#define MPI_NULL_DELETE_FN ((MPI_Delete_function*)0x0) /* deprecated: MPI-2.0 */
Expand Down Expand Up @@ -650,6 +658,7 @@ int MPI_Comm_compare(MPI_Comm comm1, MPI_Comm comm2, int *result);
int MPI_Comm_connect(const char *port_name, MPI_Info info, int root, MPI_Comm comm, MPI_Comm *newcomm);
int MPI_Comm_create(MPI_Comm comm, MPI_Group group, MPI_Comm *newcomm);
int MPI_Comm_create_errhandler(MPI_Comm_errhandler_function *comm_errhandler_fn, MPI_Errhandler *errhandler);
int MPIX_Comm_create_errhandler_x(MPIX_Comm_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int MPI_Comm_create_from_group(MPI_Group group, const char *stringtag, MPI_Info info, MPI_Errhandler errhandler, MPI_Comm *newcomm);
int MPI_Comm_create_group(MPI_Comm comm, MPI_Group group, int tag, MPI_Comm *newcomm);
int MPI_Comm_create_keyval(MPI_Comm_copy_attr_function *comm_copy_attr_fn, MPI_Comm_delete_attr_function *comm_delete_attr_fn, int *comm_keyval, void *extra_state);
Expand Down Expand Up @@ -702,6 +711,7 @@ int MPI_Fetch_and_op(const void *origin_addr, void *result_addr, MPI_Datatype da
int MPI_File_call_errhandler(MPI_File fh, int errorcode);
int MPI_File_close(MPI_File *fh);
int MPI_File_create_errhandler(MPI_File_errhandler_function *file_errhandler_fn, MPI_Errhandler *errhandler);
int MPIX_File_create_errhandler_x(MPIX_File_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int MPI_File_delete(const char *filename, MPI_Info info);
int MPI_File_get_amode(MPI_File fh, int *amode);
int MPI_File_get_atomicity(MPI_File fh, int *flag);
Expand Down Expand Up @@ -941,6 +951,7 @@ int MPI_Neighbor_alltoallw_init_c(const void *sendbuf, const MPI_Count sendcount
int MPI_Op_commutative(MPI_Op op, int *commute);
int MPI_Op_create(MPI_User_function *user_fn, int commute, MPI_Op *op);
int MPI_Op_create_c(MPI_User_function_c *user_fn, int commute, MPI_Op *op);
int MPIX_Op_create_x(MPIX_User_function_x *user_fn_x, MPIX_Destructor_function *destructor_fn, int commute, void *extra_state, MPI_Op *op);
int MPI_Op_free(MPI_Op *op);
int MPI_Open_port(MPI_Info info, char *port_name);
int MPI_Pack(const void *inbuf, int incount, MPI_Datatype datatype, void *outbuf, int outsize, int *position, MPI_Comm comm);
Expand Down Expand Up @@ -1027,6 +1038,7 @@ int MPI_Session_attach_buffer(MPI_Session session, void *buffer, int size);
int MPI_Session_attach_buffer_c(MPI_Session session, void *buffer, MPI_Count size);
int MPI_Session_call_errhandler(MPI_Session session, int errorcode);
int MPI_Session_create_errhandler(MPI_Session_errhandler_function *session_errhandler_fn, MPI_Errhandler *errhandler);
int MPIX_Session_create_errhandler_x(MPIX_Session_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int MPI_Session_detach_buffer(MPI_Session session, void *buffer_addr, int *size);
int MPI_Session_detach_buffer_c(MPI_Session session, void *buffer_addr, MPI_Count *size);
int MPI_Session_finalize(MPI_Session *session);
Expand Down Expand Up @@ -1131,6 +1143,7 @@ int MPI_Win_create(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_
int MPI_Win_create_c(void *base, MPI_Aint size, MPI_Aint disp_unit, MPI_Info info, MPI_Comm comm, MPI_Win *win);
int MPI_Win_create_dynamic(MPI_Info info, MPI_Comm comm, MPI_Win *win);
int MPI_Win_create_errhandler(MPI_Win_errhandler_function *win_errhandler_fn, MPI_Errhandler *errhandler);
int MPIX_Win_create_errhandler_x(MPIX_Win_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int MPI_Win_create_keyval(MPI_Win_copy_attr_function *win_copy_attr_fn, MPI_Win_delete_attr_function *win_delete_attr_fn, int *win_keyval, void *extra_state);
int MPI_Win_delete_attr(MPI_Win win, int win_keyval);
int MPI_Win_detach(MPI_Win win, const void *base);
Expand Down Expand Up @@ -1318,6 +1331,7 @@ int PMPI_Comm_compare(MPI_Comm comm1, MPI_Comm comm2, int *result);
int PMPI_Comm_connect(const char *port_name, MPI_Info info, int root, MPI_Comm comm, MPI_Comm *newcomm);
int PMPI_Comm_create(MPI_Comm comm, MPI_Group group, MPI_Comm *newcomm);
int PMPI_Comm_create_errhandler(MPI_Comm_errhandler_function *comm_errhandler_fn, MPI_Errhandler *errhandler);
int PMPIX_Comm_create_errhandler_x(MPIX_Comm_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int PMPI_Comm_create_from_group(MPI_Group group, const char *stringtag, MPI_Info info, MPI_Errhandler errhandler, MPI_Comm *newcomm);
int PMPI_Comm_create_group(MPI_Comm comm, MPI_Group group, int tag, MPI_Comm *newcomm);
int PMPI_Comm_create_keyval(MPI_Comm_copy_attr_function *comm_copy_attr_fn, MPI_Comm_delete_attr_function *comm_delete_attr_fn, int *comm_keyval, void *extra_state);
Expand Down Expand Up @@ -1370,6 +1384,7 @@ int PMPI_Fetch_and_op(const void *origin_addr, void *result_addr, MPI_Datatype d
int PMPI_File_call_errhandler(MPI_File fh, int errorcode);
int PMPI_File_close(MPI_File *fh);
int PMPI_File_create_errhandler(MPI_File_errhandler_function *file_errhandler_fn, MPI_Errhandler *errhandler);
int PMPIX_File_create_errhandler_x(MPIX_File_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int PMPI_File_delete(const char *filename, MPI_Info info);
int PMPI_File_get_amode(MPI_File fh, int *amode);
int PMPI_File_get_atomicity(MPI_File fh, int *flag);
Expand Down Expand Up @@ -1695,6 +1710,7 @@ int PMPI_Session_attach_buffer(MPI_Session session, void *buffer, int size);
int PMPI_Session_attach_buffer_c(MPI_Session session, void *buffer, MPI_Count size);
int PMPI_Session_call_errhandler(MPI_Session session, int errorcode);
int PMPI_Session_create_errhandler(MPI_Session_errhandler_function *session_errhandler_fn, MPI_Errhandler *errhandler);
int PMPIX_Session_create_errhandler_x(MPIX_Session_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int PMPI_Session_detach_buffer(MPI_Session session, void *buffer_addr, int *size);
int PMPI_Session_detach_buffer_c(MPI_Session session, void *buffer_addr, MPI_Count *size);
int PMPI_Session_finalize(MPI_Session *session);
Expand Down Expand Up @@ -1799,6 +1815,7 @@ int PMPI_Win_create(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI
int PMPI_Win_create_c(void *base, MPI_Aint size, MPI_Aint disp_unit, MPI_Info info, MPI_Comm comm, MPI_Win *win);
int PMPI_Win_create_dynamic(MPI_Info info, MPI_Comm comm, MPI_Win *win);
int PMPI_Win_create_errhandler(MPI_Win_errhandler_function *win_errhandler_fn, MPI_Errhandler *errhandler);
int PMPIX_Win_create_errhandler_x(MPIX_Win_errhandler_function_x *comm_errhandler_fn_x, MPIX_Destructor_function *destructor_fn, void *extra_state, MPI_Errhandler *errhandler);
int PMPI_Win_create_keyval(MPI_Win_copy_attr_function *win_copy_attr_fn, MPI_Win_delete_attr_function *win_delete_attr_fn, int *win_keyval, void *extra_state);
int PMPI_Win_delete_attr(MPI_Win win, int win_keyval);
int PMPI_Win_detach(MPI_Win win, const void *base);
Expand Down
116 changes: 64 additions & 52 deletions src/binding/abi/mpi_abi_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,72 +49,72 @@ static inline ABI_Comm ABI_Comm_from_mpi(MPI_Comm in)

static inline MPIR_Comm *ABI_Comm_ptr(ABI_Comm comm_abi)
{
MPIR_Comm *comm_ptr = NULL;
if (comm_abi != ABI_COMM_NULL) {
MPI_Comm comm = ABI_Comm_to_mpi(comm_abi);
MPIR_Comm_get_ptr(comm, comm_ptr);
if (comm_ptr != NULL) {
if (MPIR_Object_get_ref(comm_ptr) <= 0) {
comm_ptr = NULL;
}
MPIR_Comm *comm_ptr = NULL;
if (comm_abi != ABI_COMM_NULL) {
MPI_Comm comm = ABI_Comm_to_mpi(comm_abi);
MPIR_Comm_get_ptr(comm, comm_ptr);
if (comm_ptr != NULL) {
if (MPIR_Object_get_ref(comm_ptr) <= 0) {
comm_ptr = NULL;
}
}
}
}
return comm_ptr;
return comm_ptr;
}

static inline int ABI_Comm_rank(ABI_Comm comm_abi)
{
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int rank = MPI_PROC_NULL;
if (comm_ptr != NULL) {
MPIR_Comm_rank_impl(comm_ptr, &rank);
}
return rank;
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int rank = MPI_PROC_NULL;
if (comm_ptr != NULL) {
MPIR_Comm_rank_impl(comm_ptr, &rank);
}
return rank;
}

static inline int ABI_Comm_peer_size(ABI_Comm comm_abi)
{
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int size = 0;
if (comm_ptr != NULL) {
int flag = 0;
MPIR_Comm_test_inter_impl(comm_ptr, &flag);
if (flag) {
MPIR_Comm_remote_size_impl(comm_ptr, &size);
} else {
MPIR_Comm_size_impl(comm_ptr, &size);
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int size = 0;
if (comm_ptr != NULL) {
int flag = 0;
MPIR_Comm_test_inter_impl(comm_ptr, &flag);
if (flag) {
MPIR_Comm_remote_size_impl(comm_ptr, &size);
} else {
MPIR_Comm_size_impl(comm_ptr, &size);
}
}
}
return size;
return size;
}

static inline void ABI_Comm_neighbors_count(ABI_Comm comm_abi, int *indegree, int *outdegree)
{
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int topo = MPI_UNDEFINED;
int rank = MPI_PROC_NULL;
int ival = 0;
if (comm_ptr != NULL) {
MPIR_Topo_test_impl(comm_ptr, &topo);
}
switch (topo) {
case MPI_CART:
MPIR_Cartdim_get_impl(comm_ptr, &ival);
*indegree = *outdegree = 2 * ival;
break;
case MPI_GRAPH:
MPIR_Comm_rank_impl(comm_ptr, &rank);
MPIR_Graph_neighbors_count_impl(comm_ptr, rank, &ival);
*indegree = *outdegree = ival;
break;
case MPI_DIST_GRAPH:
*indegree = *outdegree = 0;
MPIR_Dist_graph_neighbors_count_impl(comm_ptr, indegree, outdegree, &ival);
break;
default:
*indegree = *outdegree = 0;
break;
}
MPIR_Comm *comm_ptr = ABI_Comm_ptr(comm_abi);
int topo = MPI_UNDEFINED;
int rank = MPI_PROC_NULL;
int ival = 0;
if (comm_ptr != NULL) {
MPIR_Topo_test_impl(comm_ptr, &topo);
}
switch (topo) {
case MPI_CART:
MPIR_Cartdim_get_impl(comm_ptr, &ival);
*indegree = *outdegree = 2 * ival;
break;
case MPI_GRAPH:
MPIR_Comm_rank_impl(comm_ptr, &rank);
MPIR_Graph_neighbors_count_impl(comm_ptr, rank, &ival);
*indegree = *outdegree = ival;
break;
case MPI_DIST_GRAPH:
*indegree = *outdegree = 0;
MPIR_Dist_graph_neighbors_count_impl(comm_ptr, indegree, outdegree, &ival);
break;
default:
*indegree = *outdegree = 0;
break;
}
}

static inline MPI_Datatype ABI_Datatype_to_mpi(ABI_Datatype in)
Expand Down Expand Up @@ -323,6 +323,18 @@ static inline ABI_Win ABI_Win_from_mpi(MPI_Win in)
return (ABI_Win) (void *) ptr;
}

static inline MPI_File ABI_File_to_mpi(ABI_File in)
{
/* Both MPI_File in mpich and ABI_File are pointers */
return (MPI_File) in;
}

static inline ABI_File ABI_File_from_mpi(MPI_File in)
{
/* Both MPI_File in mpich and ABI_File are pointers */
return (ABI_File) in;
}

/* MPICH internal callbacks does not differentiate handle types, so we need
* a general conversion routine */
static inline void *ABI_Handle_from_mpi(int in)
Expand Down
2 changes: 1 addition & 1 deletion src/binding/fortran/mpif_h/mpi_fortimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ typedef void (FORT_CALL F77_OpFunction) (void *, void *, MPI_Fint *, MPI_Fint *)
typedef void (FORT_CALL F77_ErrFunction) (MPI_Fint *, MPI_Fint *);

void MPII_Keyval_set_f90_proxy(int keyval);
int MPII_op_create(MPI_User_function * opfn, MPI_Fint commute, MPI_Fint * op);
int MPII_op_create(F77_OpFunction * opfn, MPI_Fint commute, MPI_Fint * op);
int MPII_errhan_create(F77_ErrFunction * err_fn, MPI_Fint * errhandler, enum F77_handle_type type);

extern FORT_DLL_SPEC void FORT_CALL mpi_alloc_mem_cptr_(MPI_Aint * size, MPI_Fint * info,
Expand Down
2 changes: 0 additions & 2 deletions src/include/mpi.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,13 @@ typedef int MPI_Win;
typedef int MPI_Session;
#define MPI_SESSION_NULL ((MPI_Session)0x38000000)

#ifndef BUILD_MPI_ABI
/* File and IO */
/* This define lets ROMIO know that MPI_File has been defined */
#define MPI_FILE_DEFINED
/* ROMIO uses a pointer for MPI_File objects. This must be the same definition
as in src/mpi/romio/include/mpio.h.in */
typedef struct ADIOI_FileD *MPI_File;
#define MPI_FILE_NULL ((MPI_File)0)
#endif

/* Collective operations */
typedef int MPI_Op;
Expand Down

0 comments on commit 81f5d34

Please sign in to comment.