Skip to content

Commit

Permalink
Do not slice memoryviews in _compute_dist_middle_terms
Browse files Browse the repository at this point in the history
See the reasons here:
scikit-learn#17299
  • Loading branch information
jjerphan committed Oct 17, 2022
1 parent 8ddef01 commit f2e917b
Showing 1 changed file with 9 additions and 11 deletions.
Expand Up @@ -175,8 +175,6 @@ cdef class GEMMTermComputer{{name_suffix}}:
ITYPE_t thread_num,
) nogil:
cdef:
const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()

# Careful: LDA, LDB and LDC are given for F-ordered arrays
Expand All @@ -187,9 +185,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
BLAS_Order order = RowMajor
BLAS_Trans ta = NoTrans
BLAS_Trans tb = Trans
ITYPE_t m = X_c.shape[0]
ITYPE_t n = Y_c.shape[0]
ITYPE_t K = X_c.shape[1]
ITYPE_t m = X_end - X_start
ITYPE_t n = Y_end - Y_start
ITYPE_t K = self.n_features
DTYPE_t alpha = - 2.
{{if upcast_to_float64}}
DTYPE_t * A = self.X_c_upcast[thread_num].data()
Expand All @@ -198,15 +196,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
# Casting for A and B to remove the const is needed because APIs exposed via
# scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
# See: https://github.com/scipy/scipy/issues/14262
DTYPE_t * A = <DTYPE_t *> &X_c[0, 0]
DTYPE_t * B = <DTYPE_t *> &Y_c[0, 0]
DTYPE_t * A = <DTYPE_t *> &self.X[X_start, 0]
DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start, 0]
{{endif}}
ITYPE_t lda = X_c.shape[1]
ITYPE_t ldb = X_c.shape[1]
ITYPE_t lda = self.n_features
ITYPE_t ldb = self.n_features
DTYPE_t beta = 0.
ITYPE_t ldc = Y_c.shape[0]
ITYPE_t ldc = Y_end - Y_start

# dist_middle_terms = `-2 * X_c @ Y_c.T`
# dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end].T`
_gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)

return dist_middle_terms
Expand Down

0 comments on commit f2e917b

Please sign in to comment.