Skip to content

Commit

Permalink
Add GPU versions of globalSum and broadcast.
Browse files Browse the repository at this point in the history
Add an alternative version of classical Gram-Schmidt (GS) orthogonalization that reduces CPU-GPU traffic
  • Loading branch information
eromero-vlc committed Dec 5, 2023
1 parent f08356a commit 8e1d430
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 83 deletions.
70 changes: 42 additions & 28 deletions src/eigs/auxiliary_eigs.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ int applyPreconditioner_Sprimme(SCALAR *V, PRIMME_INT nLocal, PRIMME_INT ldV,
return 0;
}

#ifdef USE_HOST

TEMPLATE_PLEASE
int globalSum_Sprimme(SCALAR *buffer, int count, primme_context ctx) {

Expand All @@ -389,37 +387,46 @@ int broadcast_Sprimme(SCALAR *buffer, int count, primme_context ctx) {
#ifdef USE_DOUBLE

TEMPLATE_PLEASE
int globalSum_Tprimme(
void *buffer, primme_op_datatype buffert, int count, primme_context ctx) {
int globalSum_Tprimme(void *buffer_, primme_op_datatype buffert, int count,
primme_context ctx) {

primme_params *primme = ctx.primme;
REAL *buffer = buffer_;

/* Quick exit */

if (!primme || primme->numProcs == 1 || !primme->globalSumReal) {
return 0;
}

if (!primme || primme->numProcs == 1 || !primme->globalSumReal) { return 0; }

double t0 = primme_wTimer();

/* Transfer the buffer to host */

REAL *buffer0 = NULL;
CHKERR(Num_matrix_on_cpu_Rprimme(
buffer, 1, count, 1, &buffer0, NULL, 1 /* alloc */, ctx));

/* Cast buffer */

void *buffer0 = NULL;
CHKERR(Num_matrix_astype_Rprimme(buffer, 1, count, 1, buffert, &buffer0,
void *buffer1 = NULL;
CHKERR(Num_matrix_astype_Rprimme(buffer0, 1, count, 1, buffert, &buffer1,
NULL, primme->globalSumReal_type, 1 /* alloc */, 1 /* copy */, ctx));

int ierr = 0;
CHKERRM(
(primme->globalSumReal(buffer0, buffer0, &count, primme, &ierr), ierr),
(primme->globalSumReal(buffer1, buffer1, &count, primme, &ierr), ierr),
PRIMME_USER_FAILURE, "Error returned by 'globalSumReal' %d", ierr);

/* Copy back buffer0 */
/* Copy back buffer1 */

CHKERR(Num_matrix_astype_Rprimme(buffer0, 1, count, 1,
primme->globalSumReal_type, (void **)&buffer, NULL, buffert,
CHKERR(Num_matrix_astype_Rprimme(buffer1, 1, count, 1,
primme->globalSumReal_type, (void **)&buffer0, NULL, buffert,
-1 /* dealloc */, 1 /* copy */, ctx));

/* Copy back to gpu */

CHKERR(Num_matrix_on_cpu_Rprimme(buffer, 1, count, 1, &buffer0, NULL,
-1 /* copy back and dealloc */, ctx));

primme->stats.numGlobalSum++;
primme->stats.timeGlobalSum += primme_wTimer() - t0;
primme->stats.volumeGlobalSum += count;
Expand All @@ -428,37 +435,46 @@ int globalSum_Tprimme(
}

TEMPLATE_PLEASE
int broadcast_Tprimme(
void *buffer, primme_op_datatype buffert, int count, primme_context ctx) {
int broadcast_Tprimme(void *buffer_, primme_op_datatype buffert, int count,
primme_context ctx) {

primme_params *primme = ctx.primme;
REAL *buffer = buffer_;
int ierr;

/* Quick exit */

if (!primme || primme->numProcs == 1) {
return 0;
}
if (!primme || primme->numProcs == 1) { return 0; }

double t0 = primme_wTimer();

if (primme && primme->broadcastReal) {
/* Transfer the buffer to host */

REAL *buffer0 = NULL;
CHKERR(Num_matrix_on_cpu_Rprimme(
buffer, 1, count, 1, &buffer0, NULL, 1 /* alloc */, ctx));

/* Cast buffer */

void *buffer0 = NULL;
CHKERR(Num_matrix_astype_dprimme(buffer, 1, count, 1, buffert,
(void **)&buffer0, NULL, primme->broadcastReal_type, 1 /* alloc */,
void *buffer1 = NULL;
CHKERR(Num_matrix_astype_dprimme(buffer0, 1, count, 1, buffert,
(void **)&buffer1, NULL, primme->broadcastReal_type, 1 /* alloc */,
1 /* copy */, ctx));

CHKERRM((primme->broadcastReal(buffer0, &count, primme, &ierr), ierr),
CHKERRM((primme->broadcastReal(buffer1, &count, primme, &ierr), ierr),
PRIMME_USER_FAILURE, "Error returned by 'broadcastReal' %d", ierr);

/* Copy back buffer0 */
/* Copy back buffer1 */

CHKERR(Num_matrix_astype_Sprimme(buffer0, 1, count, 1,
primme->broadcastReal_type, (void **)&buffer, NULL, buffert,
CHKERR(Num_matrix_astype_Sprimme(buffer1, 1, count, 1,
primme->broadcastReal_type, (void **)&buffer0, NULL, buffert,
-1 /* dealloc */, 1 /* copy */, ctx));

/* Copy back to gpu */

CHKERR(Num_matrix_on_cpu_Rprimme(buffer, 1, count, 1, &buffer0, NULL,
-1 /* copy back and dealloc */, ctx));
} else {
if (primme->procID != 0) {
CHKERR(Num_zero_matrix_Tprimme(buffer, buffert, 1, count, 1, ctx));
Expand All @@ -482,8 +498,6 @@ int broadcast_iprimme(int *buffer, int count, primme_context ctx) {

#endif /* USE_DOUBLE */

#endif /* USE_HOST */

/*******************************************************************************
* Subroutine machineEpsMatrix - return the machine epsilon considering the
* precision of the matrixMatvec and massMatrixMatvec, and the working
Expand Down
28 changes: 26 additions & 2 deletions src/eigs/auxiliary_eigs.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ int broadcast_dprimme(dummy_type_dprimme *buffer, int count, primme_context ctx)
# define globalSum_TprimmeRHqprimme CONCAT(globalSum_Tprimme,CONCAT(CONCAT(CONCAT(,q),primme),))
#endif
int globalSum_Tprimme(
void *buffer, primme_op_datatype buffert, int count, primme_context ctx);
void *buffer_, primme_op_datatype buffert, int count, primme_context ctx);
#if !defined(CHECK_TEMPLATE) && !defined(broadcast_TprimmeSprimme)
# define broadcast_TprimmeSprimme CONCAT(broadcast_Tprimme,SCALAR_SUF)
#endif
Expand Down Expand Up @@ -858,7 +858,7 @@ int globalSum_Tprimme(
# define broadcast_TprimmeRHqprimme CONCAT(broadcast_Tprimme,CONCAT(CONCAT(CONCAT(,q),primme),))
#endif
int broadcast_Tprimme(
void *buffer, primme_op_datatype buffert, int count, primme_context ctx);
void *buffer_, primme_op_datatype buffert, int count, primme_context ctx);
#if !defined(CHECK_TEMPLATE) && !defined(broadcast_iprimmeSprimme)
# define broadcast_iprimmeSprimme CONCAT(broadcast_iprimme,SCALAR_SUF)
#endif
Expand Down Expand Up @@ -1608,6 +1608,8 @@ int massMatrixMatvec_magma_hprimme(dummy_type_magma_hprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_hprimme(dummy_type_magma_hprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_hprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_hprimme(dummy_type_magma_hprimme *buffer, int count, primme_context ctx);
int broadcast_magma_hprimme(dummy_type_magma_hprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_hprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_hprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_magma_hprimme(
Expand All @@ -1626,6 +1628,8 @@ int massMatrixMatvec_magma_kprimme(dummy_type_magma_kprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_kprimme(dummy_type_magma_kprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_kprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_kprimme(dummy_type_magma_kprimme *buffer, int count, primme_context ctx);
int broadcast_magma_kprimme(dummy_type_magma_kprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_kprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_kprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_magma_kprimme(
Expand All @@ -1644,6 +1648,8 @@ int massMatrixMatvec_magma_sprimme(dummy_type_magma_sprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_sprimme(dummy_type_magma_sprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_sprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_sprimme(dummy_type_magma_sprimme *buffer, int count, primme_context ctx);
int broadcast_magma_sprimme(dummy_type_magma_sprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_sprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_sprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_magma_sprimme(
Expand All @@ -1662,6 +1668,8 @@ int massMatrixMatvec_magma_cprimme(dummy_type_magma_cprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_cprimme(dummy_type_magma_cprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_cprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_cprimme(dummy_type_magma_cprimme *buffer, int count, primme_context ctx);
int broadcast_magma_cprimme(dummy_type_magma_cprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_cprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_cprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_magma_cprimme(
Expand All @@ -1680,6 +1688,8 @@ int massMatrixMatvec_magma_dprimme(dummy_type_magma_dprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_dprimme(dummy_type_magma_dprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_dprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_dprimme(dummy_type_magma_dprimme *buffer, int count, primme_context ctx);
int broadcast_magma_dprimme(dummy_type_magma_dprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_dprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_dprimme(double *eps, primme_context ctx);
dummy_type_dprimme problemNorm_magma_dprimme(
Expand All @@ -1698,6 +1708,8 @@ int massMatrixMatvec_magma_zprimme(dummy_type_magma_zprimme *V, PRIMME_INT nLoca
primme_context ctx);
int applyPreconditioner_magma_zprimme(dummy_type_magma_zprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_magma_zprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_magma_zprimme(dummy_type_magma_zprimme *buffer, int count, primme_context ctx);
int broadcast_magma_zprimme(dummy_type_magma_zprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_magma_zprimme(double *eps, primme_context ctx);
int machineEpsOrth_magma_zprimme(double *eps, primme_context ctx);
dummy_type_dprimme problemNorm_magma_zprimme(
Expand All @@ -1716,6 +1728,8 @@ int massMatrixMatvec_cublas_hprimme(dummy_type_cublas_hprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_hprimme(dummy_type_cublas_hprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_hprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_hprimme(dummy_type_cublas_hprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_hprimme(dummy_type_cublas_hprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_hprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_hprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_cublas_hprimme(
Expand All @@ -1734,6 +1748,8 @@ int massMatrixMatvec_cublas_kprimme(dummy_type_cublas_kprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_kprimme(dummy_type_cublas_kprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_kprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_kprimme(dummy_type_cublas_kprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_kprimme(dummy_type_cublas_kprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_kprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_kprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_cublas_kprimme(
Expand All @@ -1752,6 +1768,8 @@ int massMatrixMatvec_cublas_sprimme(dummy_type_cublas_sprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_sprimme(dummy_type_cublas_sprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_sprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_sprimme(dummy_type_cublas_sprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_sprimme(dummy_type_cublas_sprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_sprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_sprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_cublas_sprimme(
Expand All @@ -1770,6 +1788,8 @@ int massMatrixMatvec_cublas_cprimme(dummy_type_cublas_cprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_cprimme(dummy_type_cublas_cprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_cprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_cprimme(dummy_type_cublas_cprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_cprimme(dummy_type_cublas_cprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_cprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_cprimme(double *eps, primme_context ctx);
dummy_type_sprimme problemNorm_cublas_cprimme(
Expand All @@ -1788,6 +1808,8 @@ int massMatrixMatvec_cublas_dprimme(dummy_type_cublas_dprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_dprimme(dummy_type_cublas_dprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_dprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_dprimme(dummy_type_cublas_dprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_dprimme(dummy_type_cublas_dprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_dprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_dprimme(double *eps, primme_context ctx);
dummy_type_dprimme problemNorm_cublas_dprimme(
Expand All @@ -1806,6 +1828,8 @@ int massMatrixMatvec_cublas_zprimme(dummy_type_cublas_zprimme *V, PRIMME_INT nLo
primme_context ctx);
int applyPreconditioner_cublas_zprimme(dummy_type_cublas_zprimme *V, PRIMME_INT nLocal, PRIMME_INT ldV,
dummy_type_cublas_zprimme *W, PRIMME_INT ldW, int blockSize, primme_context ctx);
int globalSum_cublas_zprimme(dummy_type_cublas_zprimme *buffer, int count, primme_context ctx);
int broadcast_cublas_zprimme(dummy_type_cublas_zprimme *buffer, int count, primme_context ctx);
int machineEpsMatrix_cublas_zprimme(double *eps, primme_context ctx);
int machineEpsOrth_cublas_zprimme(double *eps, primme_context ctx);
dummy_type_dprimme problemNorm_cublas_zprimme(
Expand Down

0 comments on commit 8e1d430

Please sign in to comment.