Skip to content

Commit

Permalink
Done compartmentalized sketching functions (I think)
Browse files Browse the repository at this point in the history
  • Loading branch information
Heatherms27 committed Sep 27, 2023
1 parent d5ffc31 commit f158f61
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
9 changes: 1 addition & 8 deletions src/eigs/lanczos.c
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ int lanczos_Sprimme(HEVAL *evals, SCALAR *evecs, PRIMME_INT ldevecs,
CHKERR(Num_malloc_Sprimme(ldSV*primme->numEvals, &SVhVecs, ctx));
CHKERR(Num_malloc_Sprimme(ldSV*primme->numEvals, &SWhVecs, ctx));
CHKERR(Num_malloc_Sprimme(ldV*(maxBasisSize+blockSize), &V_temp, ctx));

S_rows = (PRIMME_INT*)malloc(nnzPerCol*nLocal*sizeof(PRIMME_INT));

/* Build Sketch CSR Locally */
Expand Down Expand Up @@ -380,14 +379,8 @@ int lanczos_Sprimme(HEVAL *evals, SCALAR *evecs, PRIMME_INT ldevecs,
/* TEST 1 - SKETCHED RESIDUALS */
/* Get the projected basis */
CHKERR(Num_gemm_Sprimme("N", "N", ldSV, i+blockSize, i+2*blockSize, 1.0, SV, ldSV, H, ldH, 0.0, SW, ldSV, ctx)); // SW = A*SV = SV*H
CHKERR(Num_gemm_Sprimme("N", "N", ldSV, numEvals, i+blockSize, 1.0, SV, ldSV, hVecs, ldhVecs, 0.0, SVhVecs, ldSV, ctx)); // SV*hVecs
CHKERR(Num_gemm_Sprimme("N", "N", ldSV, numEvals, i+blockSize, 1.0, SW, ldSV, hVecs, ldhVecs, 0.0, SWhVecs, ldSV, ctx)); // SW*hVecs

CHKERR(Num_compute_residuals_Sprimme(sketchSize, numEvals, evals, SVhVecs, ldSV, SWhVecs, ldSV, rwork, ldrwork, ctx)); // rwork = SW*hVecs - evals*SV*hVecs

for(j = 0; j < numEvals; j++) resNorms[j] = Num_dot_Sprimme(sketchSize, &rwork[j*ldrwork], 1, &rwork[j*ldrwork], 1, ctx) / Num_dot_Sprimme(sketchSize, &SVhVecs[j*ldSV], 1, &SVhVecs[j*ldSV], 1, ctx);
CHKERR(globalSum_Rprimme(resNorms, numEvals, ctx));
for(j = 0; j < numEvals; j++) resNorms[j] = sqrt(resNorms[j]);
CHKERR(sketched_residuals_Sprimme(SV, ldSV, SW, ldSV, hVecs, ldhVecs, evals, i+blockSize, rwork, ldrwork, resNorms, ctx));

if(primme->procID == 0){
for(j = 0; j < primme->numEvals; j++)
Expand Down
37 changes: 29 additions & 8 deletions src/eigs/sketch.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,41 +131,62 @@ int build_sketch_Sprimme(PRIMME_INT *S_rows, SCALAR *S_vals, PRIMME_INT ldS, PRI
* subspace embedding and returns the result in SV
******************************************************************************/
TEMPLATE_PLEASE
int sketch_basis_Sprimme(SCALAR *V, PRIMME_INT ldV, SCALAR *SV, PRIMME_INT ldSV, PRIMME_INT blockSize, PRIMME_INT nnz_per_col, PRIMME_INT *S_rows, SCALAR *S_vals, primme_context ctx) {
int sketch_basis_Sprimme(SCALAR *V, PRIMME_INT ldV, SCALAR *SV, PRIMME_INT ldSV, PRIMME_INT basisSize, PRIMME_INT nnz_per_col, PRIMME_INT *S_rows, SCALAR *S_vals, primme_context ctx) {

primme_params *primme = ctx.primme;

SCALAR *V_row; /* Used to temporarily store a row in V to avoid numberous memory accesses */
PRIMME_INT i, j;

CHKERR(Num_malloc_Sprimme(blockSize, &V_row, ctx));
CHKERR(Num_zero_matrix_Sprimme(SV, ldSV, blockSize, ldSV, ctx));
CHKERR(Num_malloc_Sprimme(basisSize, &V_row, ctx));
CHKERR(Num_zero_matrix_Sprimme(SV, ldSV, basisSize, ldSV, ctx));

/* Sparse MM */
for(i = 0; i < primme->nLocal; i++) /* Traverse the rows of the basis V */
{
CHKERR(Num_copy_Sprimme(blockSize, &V[i], ldV, V_row, 1, ctx));
for(j = 0; j < nnz_per_col; j++) CHKERR(Num_axpy_Sprimme(blockSize, S_vals[i*nnz_per_col+j], V_row, 1, &SV[S_rows[i*nnz_per_col+j]], ldSV, ctx));
CHKERR(Num_copy_Sprimme(basisSize, &V[i], ldV, V_row, 1, ctx));
for(j = 0; j < nnz_per_col; j++) CHKERR(Num_axpy_Sprimme(basisSize, S_vals[i*nnz_per_col+j], V_row, 1, &SV[S_rows[i*nnz_per_col+j]], ldSV, ctx));
}
CHKERR(Num_free_Sprimme(V_row, ctx));

/* Find the sketched basis */
CHKERR(globalSum_Sprimme(SV, blockSize*ldSV, ctx));
CHKERR(globalSum_Sprimme(SV, basisSize*ldSV, ctx));

return 0;
}

/******************************************************************************
* Subroutine sketched_residuals - This routine finds the residuals of the
* sketched Ritz pairs along with the residual norms.
* \| SAV*x - SV*x*lambda \|_2 / \|SV*x\_x
* \| SW*x - SV*x*lambda \|_2 / \|SV*x\|_2
* SAV = SV*H
******************************************************************************/
TEMPLATE_PLEASE
int sketched_residuals_Sprimme(SCALAR *SV, PRIMME_INT ldSV, HSCALAR *H, PRIMME_INT ldH, primme_context ctx) {
int sketched_residuals_Sprimme(SCALAR *SV, PRIMME_INT ldSV, SCALAR *SW, PRIMME_INT ldSW, HSCALAR *hVecs, PRIMME_INT ldhVecs, HEVAL *evals, PRIMME_INT basisSize, SCALAR *residuals, PRIMME_INT ldresiduals, SCALAR *resNorms, primme_context ctx) {

primme_params *primme = ctx.primme;

SCALAR *SVhVecs, *SWhVecs; /* Temporary arrays used to compute residuals */
PRIMME_INT numEvals = min(primme->numEvals, basisSize);
PRIMME_INT i; /* Loop variable */

CHKERR(Num_malloc_Sprimme(ldSV*numEvals, &SVhVecs, ctx));
CHKERR(Num_malloc_Sprimme(ldSW*numEvals, &SWhVecs, ctx));

CHKERR(Num_gemm_Sprimme("N", "N", ldSV, numEvals, basisSize, 1.0, SV, ldSV, hVecs, ldhVecs, 0.0, SVhVecs, ldSV, ctx)); // SVhVecs
CHKERR(Num_gemm_Sprimme("N", "N", ldSW, numEvals, basisSize, 1.0, SW, ldSW, hVecs, ldhVecs, 0.0, SWhVecs, ldSW, ctx)); // SWhVecs

/* Compute residual vectors */
CHKERR(Num_compute_residuals_Sprimme(ldSV, numEvals, evals, SVhVecs, ldSV, SWhVecs, ldSW, residuals, ldresiduals, ctx));

/* Compute residual norms */
for(i = 0; i < numEvals; i++) resNorms[i] = Num_dot_Sprimme(ldSV, &residuals[i*ldresiduals], 1, &residuals[i*ldresiduals], 1, ctx) / Num_dot_Sprimme(ldSV, &SVhVecs[i*ldSV], 1, &SVhVecs[i*ldSV], 1, ctx);
CHKERR(globalSum_Rprimme(resNorms, numEvals, ctx));
for(i = 0; i < numEvals; i++) resNorms[i] = sqrt(resNorms[i]);

Num_free_Sprimme(SVhVecs, ctx);
Num_free_Sprimme(SWhVecs, ctx);

return 0;
}

Expand Down

0 comments on commit f158f61

Please sign in to comment.