gemm! alias #202

Open
harrisonritz opened this issue Jan 21, 2024 · 5 comments

Comments

@harrisonritz

love the package!

BLAS.gemm! fails for any PDMat arguments unless you pass a.mat.
Maybe something like this could be more general:

pd_gemm!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha,
               A isa AbstractPDMat ? A.mat : A,
               B isa AbstractPDMat ? B.mat : B,
               beta,
               C isa AbstractPDMat ? C.mat : C);

Benchmarks show the wrapper running just as fast as calling BLAS.gemm! on the .mat field directly.
Minimal example:

using LinearAlgebra, PDMats, BenchmarkTools

ix = randn(20, 20);
xx = PDMat(Hermitian(ix' * ix));
aa = randn(20, 20);

# unwrap any PDMat argument to its dense `mat` field before calling BLAS
pd_gemm!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha,
               A isa AbstractPDMat ? A.mat : A,
               B isa AbstractPDMat ? B.mat : B,
               beta,
               C isa AbstractPDMat ? C.mat : C);

# generic mul! with the PDMat passed directly
yy = zeros(20, 20);
@benchmark mul!($yy, $xx, $aa', 1.0, 1.0)

# BLAS call with the dense field passed by hand
yy = zeros(20, 20);
@benchmark BLAS.gemm!('N', 'T', 1.0, $xx.mat, $aa, 1.0, $yy)

# proposed wrapper
yy = zeros(20, 20);
@benchmark pd_gemm!('N', 'T', 1.0, $xx, $aa, 1.0, $yy)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.384 μs … 331.486 μs  ┊ GC (min … max):  0.00% … 98.48%
 Time  (median):     3.176 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.285 μs ±  15.201 μs  ┊ GC (mean ± σ):  22.44% ±  6.26%

            █▂                                                 
  █▂▁▁▁▁▁▁▂▇██▇▆▅▄▃▃▂▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2.38 μs         Histogram: frequency by time         6.3 μs <

 Memory estimate: 20.58 KiB, allocs estimate: 4.

BenchmarkTools.Trial: 10000 samples with 193 evaluations.
 Range (min … max):  505.394 ns … 993.523 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.477 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   518.730 ns ±  28.264 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▂▁▁▁▁▄▃▁▁▁▂▂▁▂▁▁                                             ▁
  ██████████████████▇▇▇▆▆▆▇▇▆▆▆▆▆▆▇▆▇▆▆▆▆▆▆▆▆▅▆▆▅▆▆▅▅▄▄▅▃▄▅▄▅▄▅ █
  505 ns        Histogram: log(frequency) by time        642 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

BenchmarkTools.Trial: 10000 samples with 194 evaluations.
 Range (min … max):  505.371 ns … 909.577 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.443 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   510.625 ns ±  13.425 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▆▂▁      ▁  ▂▃▁                                              ▁
  ████████▇███▇███▇▆▆▆▇▇▇▇▇▇▇▆▆▆▆▅▆▆▆▆▆▅▄▅▅▄▅▅▄▄▄▃▅▄▃▄▅▅▃▄▄▄▂▃▃ █
  505 ns        Histogram: log(frequency) by time        573 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
@ararslan
Member

Rather than special-casing particular BLAS functions, I think it would make more sense to define methods for the auxiliary functions used by the existing methods. For example, gemm! works as-is with one or more PDMat arguments by defining

Base.strides(X::PDMat) = strides(X.mat)
Base.unsafe_convert(::Type{Ptr{T}}, X::PDMat{T}) where {T} = Base.unsafe_convert(Ptr{T}, X.mat)
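
To illustrate the idea without relying on PDMats itself, here is a minimal sketch using a hypothetical wrapper type, DenseWrapper (not part of PDMats), that keeps its data in a mat field the way PDMat does. It assumes BLAS.gemm! accepts any AbstractMatrix with a BLAS element type and only needs strides and a pointer conversion from it; that may differ between Julia versions.

```julia
using LinearAlgebra

# Hypothetical stand-in for PDMat: a thin wrapper around a dense matrix.
struct DenseWrapper{T} <: AbstractMatrix{T}
    mat::Matrix{T}
end
Base.size(W::DenseWrapper) = size(W.mat)
Base.getindex(W::DenseWrapper, i::Int, j::Int) = W.mat[i, j]

# Forward the memory layout and the data pointer to the wrapped matrix,
# mirroring the two definitions suggested for PDMat.
Base.strides(W::DenseWrapper) = strides(W.mat)
Base.unsafe_convert(::Type{Ptr{T}}, W::DenseWrapper{T}) where {T} =
    Base.unsafe_convert(Ptr{T}, W.mat)

X = randn(4, 4)
W = DenseWrapper(X'X + I)   # positive definite data, stored densely
B = randn(4, 4)
C = zeros(4, 4)             # plain Matrix output

# C = 1.0 * W * B' + 0.0 * C, with the wrapper passed directly
BLAS.gemm!('N', 'T', 1.0, W, B, 0.0, C)
```

Only the input is wrapped here; the output C is an ordinary Matrix.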

@harrisonritz
Author

Agree that's much better!
If we use the AbstractPDMat type, could we add this to generics.jl?
I'm happy to file a PR tomorrow.

@ararslan
Member

At least as written, what I wrote before isn't generic because the mat field access is specific to the PDMat type; other AbstractPDMats such as PDiagMat don't have such a field. There also isn't a way to generically retrieve a "conventional" AbstractMatrix given an AbstractPDMat without making a copy. The AbstractPDMat subtypes defined in this package do all have Matrix constructor methods, but that makes a copy.

So all that is to say, I think it would have to go in src/pdmat.jl and remain specific to PDMat, at least for now.
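
A small sketch of the difference, assuming PDMats is installed and that PDiagMat can be built from a vector of diagonal entries:

```julia
using LinearAlgebra, PDMats

D = PDiagMat([1.0, 2.0, 3.0])   # diagonal positive definite matrix

# PDiagMat stores only the diagonal, so there is no `mat` field to forward to
has_mat = hasfield(typeof(D), :mat)

# Matrix(D) materializes a new dense array; mutating it leaves D untouched
M = Matrix(D)
M[1, 1] = 10.0
```

After this, D[1, 1] is still 1.0, which is why a generic definition based on Matrix(X) would not let BLAS write through to the original object.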

@andreasnoack
Member

The product of two symmetric matrices is generally not symmetric

julia> A = randn(3,3) |> t -> t't
3×3 Matrix{Float64}:
  0.276842   -0.0227288  0.31778
 -0.0227288   1.82429    1.71879
  0.31778     1.71879    2.08469

julia> B = randn(3,3) |> t -> t't
3×3 Matrix{Float64}:
  1.17375  -1.0618   -2.01181
 -1.0618    1.01923   1.72527
 -2.01181   1.72527   5.20746

julia> A*B
3×3 Matrix{Float64}:
 -0.290238  0.23114   1.05866
 -5.4216    4.84889  12.1436
 -5.64603   5.01109  13.182

and therefore cannot be a PDMat so the proposed definition can't work.

Furthermore, the BLAS nomenclature "GE" means general, as in "no structure", so I don't think it would make sense to overload gemm for structured matrices like PDMat, even if the result matrix were unstructured such that the operation wasn't incorrect. Instead, we could maybe define some mul! methods.
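
As a deterministic sketch of both points, using only LinearAlgebra and plain matrices standing in for PDMat: the two inputs below are positive definite, the product is written into an ordinary Matrix with mul!, and the result is not symmetric.

```julia
using LinearAlgebra

A = [2.0 1.0; 1.0 2.0]   # symmetric positive definite
B = [1.0 0.0; 0.0 3.0]   # symmetric positive definite
C = zeros(2, 2)          # unstructured output buffer

mul!(C, A, B)            # C = A * B, written in place

inputs_posdef = isposdef(A) && isposdef(B)
product_symmetric = issymmetric(C)
```

Here C is [2 3; 1 6], so the output has to be a general matrix even though both inputs are positive definite.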

@ararslan
Member

Ah right, math...

I think mul! already works as-is though? At least judging by its use in the OP
