Skip to content

Commit

Permalink
Updated the sketched restart to more efficiently recomputing the Q an…
Browse files Browse the repository at this point in the history
…d R factors of the sketched basis
  • Loading branch information
Heatherms27 committed Mar 26, 2024
1 parent 9d43bbc commit 84ad177
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions src/eigs/restart.c
Expand Up @@ -1781,16 +1781,17 @@ STATIC int restart_sketched(SCALAR *V, int ldV, SCALAR *W, int ldW, SCALAR *SV,

int i; /* Loop variable */
SCALAR *V_temp;
//SCALAR *SV_temp;
SCALAR *SV_temp;
SCALAR *VecNorms;
SCALAR *hVecs_temp;
SCALAR *q;

int old_basisSize = *basisSize;

CHKERR(Num_malloc_Sprimme(ldV*restartSize, &V_temp, ctx));
//CHKERR(Num_malloc_Sprimme(ldSV*restartSize, &SV_temp, ctx));
CHKERR(Num_malloc_Sprimme(ldSV*restartSize, &SV_temp, ctx));
CHKERR(Num_malloc_Sprimme(restartSize, &VecNorms, ctx));
CHKERR(Num_malloc_Sprimme(old_basisSize*restartSize, &hVecs_temp, ctx));
CHKERR(Num_malloc_Sprimme(old_basisSize*restartSize, &q, ctx));

//CHKERR(Num_malloc_Sprimme(old_basisSize*restartSize, &hVecs_temp, ctx));

/* RESTART ALGORITHM ----------------------------
* 1. Orthonormalize hVecs
Expand All @@ -1805,10 +1806,7 @@ STATIC int restart_sketched(SCALAR *V, int ldV, SCALAR *W, int ldW, SCALAR *SV,
* ----------------------------------------------- */

// Orthogonalize hVecs
CHKERR(Num_copy_matrix_Sprimme(hVecs, old_basisSize, restartSize, ldhVecs, hVecs_temp, old_basisSize, ctx));
CHKERR(ortho_Sprimme(hVecs_temp, old_basisSize, NULL, 0, 0, restartSize-1, NULL, 0, 0, old_basisSize, primme->iseed, ctx));
CHKERR(Num_copy_matrix_Sprimme(hVecs_temp, old_basisSize, restartSize, old_basisSize, hVecs, ldhVecs, ctx));
CHKERR(Num_free_Sprimme(hVecs_temp, ctx));
CHKERR(ortho_Sprimme(hVecs, old_basisSize, NULL, 0, 0, restartSize-1, NULL, 0, 0, old_basisSize, primme->iseed, ctx));

// V = V * hVecs
CHKERR(Num_gemm_SHprimme("N", "N", ldV, restartSize, old_basisSize, 1.0, V, ldV, hVecs, ldhVecs, 0.0, V_temp, ldV, ctx));
Expand All @@ -1820,14 +1818,17 @@ STATIC int restart_sketched(SCALAR *V, int ldV, SCALAR *W, int ldW, SCALAR *SV,
CHKERR(Num_free_Sprimme(V_temp, ctx));

// SV = SV * hVecs
//CHKERR(Num_gemm_SHprimme("N", "N", ldSV, restartSize, old_basisSize, 1.0, SV, ldSV, hVecs, ldhVecs, 0.0, SV_temp, ldSV, ctx));
//CHKERR(Num_copy_matrix_Sprimme(SV_temp, ldSV, restartSize, ldSV, SV, ldSV, ctx)); // Copy the temporary matrix back into SV
CHKERR(Num_gemm_SHprimme("N", "N", ldSV, restartSize, old_basisSize, 1.0, SV, ldSV, hVecs, ldhVecs, 0.0, SV_temp, ldSV, ctx));
CHKERR(Num_copy_matrix_Sprimme(SV_temp, ldSV, restartSize, ldSV, SV, ldSV, ctx)); // Copy the temporary matrix back into SV

// SW = SW * hVecs
assert(ldSV == ldSW);
//CHKERR(Num_gemm_SHprimme("N", "N", ldSV, restartSize, old_basisSize, 1.0, SW, ldSW, hVecs, ldhVecs, 0.0, SV_temp, ldSW, ctx));
//CHKERR(Num_copy_matrix_Sprimme(SV_temp, ldSV, restartSize, ldSW, SW, ldSW, ctx)); // Copy the temporary matrix back into SW
//CHKERR(Num_free_Sprimme(SV_temp, ctx));
CHKERR(Num_gemm_SHprimme("N", "N", ldSV, restartSize, old_basisSize, 1.0, SW, ldSW, hVecs, ldhVecs, 0.0, SV_temp, ldSW, ctx));
CHKERR(Num_copy_matrix_Sprimme(SV_temp, ldSV, restartSize, ldSW, SW, ldSW, ctx)); // Copy the temporary matrix back into SW
CHKERR(Num_free_Sprimme(SV_temp, ctx));

// q = T * hVecs
CHKERR(Num_gemm_Sprimme("N", "N", old_basisSize, restartSize, old_basisSize, 1.0, T, ldT, hVecs, ldhVecs, 0.0, q, old_basisSize, ctx));

// Normalize the matrices
for(i = 0; i < restartSize; i++) VecNorms[i] = sqrt(Num_dot_Sprimme(primme->nLocal, &V[i*ldV], 1, &V[i*ldV], 1, ctx));
Expand All @@ -1836,16 +1837,23 @@ STATIC int restart_sketched(SCALAR *V, int ldV, SCALAR *W, int ldW, SCALAR *SV,
for (i = 0; i < restartSize; i++) {
CHKERR(Num_scal_Sprimme(ldV, 1.0/VecNorms[i], &V[i*ldV], 1, ctx));
CHKERR(Num_scal_Sprimme(ldW, 1.0/VecNorms[i], &W[i*ldW], 1, ctx));
//CHKERR(Num_scal_Sprimme(ldSV, 1.0/VecNorms[i], &SV[i*ldSV], 1, ctx));
//CHKERR(Num_scal_Sprimme(ldSW, 1.0/VecNorms[i], &SW[i*ldSW], 1, ctx));
CHKERR(Num_scal_Sprimme(ldSV, 1.0/VecNorms[i], &SV[i*ldSV], 1, ctx));
CHKERR(Num_scal_Sprimme(ldSW, 1.0/VecNorms[i], &SW[i*ldSW], 1, ctx));
CHKERR(Num_scal_Sprimme(old_basisSize, 1.0/VecNorms[i], &q[i*old_basisSize], 1, ctx));
}


// Recompute Q and R factors
CHKERR(ortho_Sprimme(q, old_basisSize, T, ldT, 0, restartSize-1, NULL, 0, 0, old_basisSize, primme->iseed, ctx)); // [q, r] = qr(T*hVecs*Vn)
SCALAR *Q_temp;
CHKERR(Num_malloc_Sprimme(ldQ*old_basisSize, &Q_temp, ctx));
CHKERR(Num_copy_matrix_Sprimme(Q, ldQ, old_basisSize, ldQ, Q_temp, ldQ, ctx)); // Temporary matrix to not overwrite Q
CHKERR(Num_gemm_SHprimme("N", "N", ldQ, restartSize, old_basisSize, 1.0, Q_temp, ldQ, q, old_basisSize, 0.0, Q, ldQ, ctx)); // Q = Q*q

CHKERR(Num_free_Sprimme(VecNorms, ctx));

CHKERR(Num_free_Sprimme(Q_temp, ctx));
CHKERR(Num_free_Sprimme(q, ctx));

// Update eiganpairs and residuals
CHKERR(sketch_basis_Sprimme(V, ldV, SV, ldSV, Q, ldQ, T, ldT, 0, restartSize, S_rows, S_vals, ctx));
CHKERR(sketch_basis_Sprimme(W, ldW, SW, ldSW, NULL, 0, NULL, 0, 0, restartSize, S_rows, S_vals, ctx));
CHKERR(sketched_RR_Sprimme(Q, ldQ, T, ldT, SW, ldSW, hVecs, restartSize, hVals, restartSize, ctx));

(*basisSize) = restartSize;
Expand Down

0 comments on commit 84ad177

Please sign in to comment.